Skip to content

Commit

Permalink
Simplify AtenSource
Browse files Browse the repository at this point in the history
  • Loading branch information
will-cromar committed Nov 8, 2023
1 parent 8894096 commit 4225deb
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 37 deletions.
5 changes: 2 additions & 3 deletions torch_xla/csrc/runtime/pjrt_computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,6 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
for (auto& tensor : tensors) {
xla::PjRtDevice* pjrt_device = StringToPjRtDevice(tensor->device());

total_size += xla::ShapeUtil::ByteSizeOf(tensor->shape());

std::shared_ptr<xla::PjRtBuffer> buffer =
std::move(client_
->BufferFromHostBuffer(
Expand All @@ -310,7 +308,8 @@ std::vector<ComputationClient::DataPtr> PjRtComputationClient::TransferToServer(
.value());

ComputationClient::DataPtr data =
std::make_shared<PjRtData>(tensor->device(), tensor->shape(), buffer);
std::make_shared<PjRtData>(tensor->device(), buffer);
total_size += xla::ShapeUtil::ByteSizeOf(data->shape());
datas.push_back(data);
}
OutboundDataMetric()->AddSample(total_size);
Expand Down
56 changes: 30 additions & 26 deletions torch_xla/csrc/runtime/tensor_source.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,34 +60,22 @@ class TensorSource {

virtual const void* data() const = 0;

virtual const xla::Shape& shape() const = 0;
virtual xla::PrimitiveType primitive_type() const = 0;

const std::string& device() const { return device_; }

virtual std::vector<int64_t> byte_strides() const {
std::vector<int64_t> byte_strides(shape().dimensions_size());
XLA_CHECK_OK(
xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
return byte_strides;
}
virtual std::vector<int64_t> dimensions() const = 0;

virtual std::vector<int64_t> dimensions() const {
auto dimensions = shape().dimensions();
return {dimensions.begin(), dimensions.end()};
}
virtual std::vector<int64_t> byte_strides() const = 0;

virtual xla::PrimitiveType primitive_type() const {
return shape().element_type();
}
const std::string& device() const { return device_; }

private:
std::string device_;
};

class AtenSource : public TensorSource {
public:
AtenSource(const at::Tensor& tensor, xla::Shape shape, std::string device)
: TensorSource(std::move(device)), shape_(std::move(shape)) {
AtenSource(const at::Tensor& tensor, xla::PrimitiveType target_type, std::string device)
: TensorSource(std::move(device)), target_type_(target_type_) {
at::ScalarType target_torch_type = TorchTypeFromXlaType(primitive_type());
if (target_torch_type != tensor.type().scalarType()) {
TORCH_LAZY_COUNTER("AtenSourceDowncasts", 1);
Expand All @@ -99,7 +87,12 @@ class AtenSource : public TensorSource {

const void* data() const override { return tensor_.const_data_ptr(); }

const xla::Shape& shape() const override { return shape_; }
xla::PrimitiveType primitive_type() const override { return target_type_; }

std::vector<int64_t> dimensions() const override {
auto sizes = tensor_.sizes();
return {sizes.begin(), sizes.end()};
}

std::vector<int64_t> byte_strides() const override {
std::vector<int64_t> strides;
Expand All @@ -109,14 +102,9 @@ class AtenSource : public TensorSource {
return strides;
}

std::vector<int64_t> dimensions() const override {
auto sizes = tensor_.sizes();
return {sizes.begin(), sizes.end()};
}

private:
at::Tensor tensor_;
xla::Shape shape_;
xla::PrimitiveType target_type_;
};

class LiteralSource : public TensorSource {
Expand All @@ -126,7 +114,23 @@ class LiteralSource : public TensorSource {

const void* data() const override { return literal_.untyped_data(); }

const xla::Shape& shape() const override { return literal_.shape(); }
const xla::Shape& shape() const { return literal_.shape(); }

xla::PrimitiveType primitive_type() const override {
return shape().element_type();
}

std::vector<int64_t> dimensions() const override {
auto dimensions = shape().dimensions();
return {dimensions.begin(), dimensions.end()};
}

std::vector<int64_t> byte_strides() const override {
std::vector<int64_t> byte_strides(shape().dimensions_size());
XLA_CHECK_OK(
xla::ShapeUtil::ByteStrides(shape(), absl::MakeSpan(byte_strides)));
return byte_strides;
}

private:
xla::Literal literal_;
Expand Down
8 changes: 3 additions & 5 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,7 @@ torch::lazy::BackendDataPtr TensorToXlaData(

std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
source_tensors.push_back(
std::make_shared<runtime::AtenSource>(tensor, shape, device.toString()));
std::make_shared<runtime::AtenSource>(tensor, shape.element_type(), device.toString()));

auto handles =
runtime::GetComputationClient()->TransferToServer(source_tensors);
Expand Down Expand Up @@ -705,9 +705,8 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
std::vector<std::shared_ptr<const runtime::TensorSource>> source_tensors;
for (size_t i = 0; i < tensors.size(); ++i) {
torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);
source_tensors.push_back(std::make_shared<runtime::AtenSource>(
tensors[i], std::move(shape), devices[i]));
tensors[i], MaybeDowncastForDevice(tensors[i].type().scalarType(), device), devices[i]));
}
return WrapXlaData(
runtime::GetComputationClient()->TransferToServer(source_tensors));
Expand All @@ -724,7 +723,6 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
std::vector<runtime::ComputationClient::DataPtr> handles;
for (size_t i = 0; i < tensors.size(); ++i) {
torch::lazy::BackendDevice device = ParseDeviceString(devices[i]);
xla::Shape shape = CreateComputationShapeFromTensor(tensors[i], &device);

std::vector<std::shared_ptr<const runtime::TensorSource>>
source_tensors; // in
Expand All @@ -744,7 +742,7 @@ std::vector<torch::lazy::BackendDataPtr> CreateTensorsData(
local_shards, local_devices, shardings[i]));
} else {
source_tensors.push_back(std::make_shared<runtime::AtenSource>(
tensors[i], std::move(shape), devices[i]));
tensors[i], MaybeDowncastForDevice(tensors[i].type().scalarType(), device), devices[i]));
new_handles =
runtime::GetComputationClient()->TransferToServer(source_tensors);
}
Expand Down
4 changes: 1 addition & 3 deletions torch_xla/csrc/xla_sharding_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -727,10 +727,8 @@ runtime::ComputationClient::DataPtr ShardingUtil::CreateShardedData(
}
for (int64_t j = 0; j < devices.size(); ++j) {
auto shard_device = ParseDeviceString(devices[j]);
auto shard_shape =
CreateComputationShapeFromTensor(local_shards[j], &shard_device);
source_tensors.push_back(std::make_shared<runtime::AtenSource>(
local_shards[j], shard_shape, devices[j]));
local_shards[j], MaybeDowncastForDevice(local_shards[j].type().scalarType(), shard_device), devices[j]));
}
return runtime::GetComputationClient()->TransferShardsToServer(
source_tensors, GetVirtualDevice().toString(), global_shape, sharding);
Expand Down

0 comments on commit 4225deb

Please sign in to comment.