Skip to content

Commit

Permalink
Enable passing down dynamic dimensions from torch to XLA (#5790)
Browse files Browse the repository at this point in the history
* port sandeep unbounded dynamism change
* Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path

---------

Co-authored-by: Siyuan Liu <lsiyuan@google.coim>
  • Loading branch information
2 people authored and bhavya01 committed Apr 22, 2024
1 parent a565ef9 commit c7f9492
Show file tree
Hide file tree
Showing 15 changed files with 238 additions and 16 deletions.
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,10 @@ def bazel_build(self, ext):

bazel_argv = [
'bazel', 'build', ext.bazel_target,
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}",
'\n'.join(['--cxxopt=%s' % opt for opt in extra_compile_args])
f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}"
]
for opt in extra_compile_args:
bazel_argv.append("--cxxopt={}".format(opt))

# Debug build.
if DEBUG:
Expand Down
58 changes: 58 additions & 0 deletions test/stablehlo/test_unbounded_dynamism.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import sys
import unittest

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.stablehlo import exported_program_to_stablehlo

# Note: Unbounded dynamism is under development. It works with unmerged
# XLA changes. Experimental XLA branch: https://github.com/lsy323/openxla-xla/tree/lsiyuan/sandeep-dynamism-rebased

device = xm.xla_device()


class UnboundedDynamismExportTest(unittest.TestCase):

def test_simply_add(self):
a = torch.tensor([[1, 2], [2, 4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(a, 0)
b = torch.tensor([[1, 2], [2, 4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(b, 0)
c = a * b
hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c])
self.assertTrue(
"(p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2])" in hlo_content)

def test_export_dynamism(self):

class M(torch.nn.Module):

def __init__(self):
super().__init__()

def forward(self, x, y):
return x * y

example_args = (torch.tensor([[1, 2], [2, 4]], device=device),
torch.tensor([[1, 2], [2, 4]], device=device))
constraints = [
# First dimension of each input is a dynamic batch size
torch.export.dynamic_dim(example_args[0], 0),
torch.export.dynamic_dim(example_args[1], 0),
# The dynamic batch size between the inputs are equal
torch.export.dynamic_dim(example_args[0],
0) == torch.export.dynamic_dim(
example_args[1], 0),
]
ep = torch.export.export(M(), args=example_args, constraints=constraints)
shlo_module = exported_program_to_stablehlo(ep)
shlo_text = shlo_module.get_stablehlo_text("forward")
self.assertTrue(
"(%arg0: tensor<?x2xi64>, %arg1: tensor<?x2xi64>) -> tensor<?x2xi64>" in
shlo_text)


if __name__ == '__main__':
test = unittest.main()
sys.exit(0 if test.result.wasSuccessful() else 1)
13 changes: 11 additions & 2 deletions torch_xla/csrc/elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/random.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_lower_util.h"
Expand Down Expand Up @@ -66,8 +67,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output,

xla::XlaOp BuildRelu(xla::XlaOp input) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
return xla::Max(input, XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder()));
xla::XlaOp scalar = XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder());
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// xla::Max doesn't do implicit broadcasting for unbounded dynamism now.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scalar);
return xla::Max(promoted.first, promoted.second);
} else {
return xla::Max(input, scalar);
}
}

xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
Expand Down
65 changes: 64 additions & 1 deletion torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/shape_helper.h"
Expand All @@ -21,6 +20,9 @@
namespace torch_xla {
namespace {

// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
xla::XlaOp result) {
xla::PrimitiveType type1 = XlaHelpers::TypeOfXlaOp(op1);
Expand Down Expand Up @@ -63,6 +65,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
for (size_t i = 0; i < dimensions.size(); ++i) {
bcast_sizes.at(dimensions[i]) = sizes[i];
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
XLA_CHECK(sizes[i] != kUnboundedSize);
}
}
return xla::BroadcastInDim(input, bcast_sizes,
GetAllDimensions(bcast_sizes.size()));
Expand Down Expand Up @@ -322,6 +327,59 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
: xla::Reshape(input, shape.dimensions());
}

bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
<< "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded "
"dynamism workload.";
const absl::Span<const int64_t> dims = shape.dimensions();
return std::any_of(dims.begin(), dims.end(),
[](int64_t size) { return size == kUnboundedSize; });
}

xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes) {
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
<< "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded "
"dynamism workload.";
const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
<< "XlaHelpers::DynamicUnboundedReshape constrainled failed!";
std::vector<xla::XlaOp> get_dim_ops;
std::vector<xla::XlaOp> reshaped_ops;
bool all_static = true;
std::vector<bool> output_dynamic(output_sizes.size(), false);

for (int i = 0; i < output_sizes.size(); i++) {
if (output_sizes[i] == kUnboundedSize) {
output_dynamic[i] = true;
get_dim_ops.push_back(xla::GetDimensionSize(aux_input, i));
all_static = false;
} else {
get_dim_ops.push_back(XlaHelpers::ScalarValue<int32_t>(
output_sizes[i], aux_input.builder()));
}
}

if (all_static) {
return xla::Reshape(input, output_sizes);
}

// Create the reshape from scalar to 1-D vector
for (auto get_dim_op : get_dim_ops) {
reshaped_ops.push_back(xla::Reshape(get_dim_op, {1}));
}

// Create Concatenate op
auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, {0});
return xla::CustomCall(
aux_input.builder(), "stablehlo.dynamic_reshape", {input, concat_op},
xla::ShapeUtil::MakeShape(aux_input_shape.element_type(), output_sizes,
output_dynamic));

return input;
}

bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
const xla::Shape& shape2) {
return shape1.is_static() && shape2.is_static() &&
Expand Down Expand Up @@ -485,6 +543,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
runtime::util::ToVector<int64_t>(shape1.dimensions()),
runtime::util::ToVector<int64_t>(shape2.dimensions())));
}
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
!XlaHelpers::IsUnboundedDynamic(shape2))
<< "Unreachable for unbounded dynamic code\n";
}
return GetPromotedDynamicShape(shape1, shape2);
}

Expand Down
12 changes: 12 additions & 0 deletions torch_xla/csrc/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/util.h"
#include "tsl/platform/bfloat16.h"
#include "xla/client/xla_builder.h"
Expand Down Expand Up @@ -158,6 +159,17 @@ class XlaHelpers {
static xla::XlaOp DynamicReshape(xla::XlaOp input,
absl::Span<const int64_t> output_sizes);

static bool IsUnboundedDynamic(const xla::Shape& shape);

static bool IsUnboundedDynamismEnabled() {
return runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM",
false);
}

static xla::XlaOp DynamicUnboundedReshape(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes);

static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape);

static bool SameStaticDimensions(const xla::Shape& shape1,
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1936,6 +1936,12 @@ void InitXlaModuleBindings(py::module m) {
return handles;
});

m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) {
TORCH_LAZY_COUNTER("XlaMarkDynamic", 1);
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
xtensor->MarkDynamicDimension(dim);
});

// -------------Dynamo Integration API Start-------------------------
/*
* Return tensor ids and at::tensors for all DeviceData nodes that is needed
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,8 @@ xla::Shape XlaNode::GetOpShape(
std::string XlaNode::ToString() const {
std::stringstream ss;
ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_;
ss << ", dynamic_dims: (" << absl::StrJoin(unbounded_dynamic_dims_, ", ")
<< ')';
return ss.str();
}

Expand Down
13 changes: 12 additions & 1 deletion torch_xla/csrc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -138,6 +138,17 @@ class XlaNode : public torch::lazy::Node {

std::string ToString() const override;

void MarkDynamicDimension(uint32_t dim) {
unbounded_dynamic_dims_.insert(dim);
}

const std::unordered_set<uint32_t>& dynamic_dims() const {
return unbounded_dynamic_dims_;
}

protected:
std::unordered_set<uint32_t> unbounded_dynamic_dims_;

private:
xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;

Expand Down
38 changes: 33 additions & 5 deletions torch_xla/csrc/lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,19 +93,31 @@ LoweringContext::LoweringContext(
}
}

// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp LoweringContext::GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data) {
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
torch::lazy::BackendData::Handle handle = data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
xla::XlaOp param = xla::Parameter(
builder(), parameters_.size(),
xla::Shape shape =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
->shape(),
absl::StrCat("p", parameters_.size()));
->shape();
for (const int dim : unbounded_dynamic_dims) {
shape.set_dynamic_dimension(dim, true);
shape.set_dimensions(dim, kUnboundedSize);
}
xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape,
absl::StrCat("p", parameters_.size()));
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
.first;
parameters_.push_back(data);
} else {
XLA_CHECK(unbounded_dynamic_dims.empty())
<< "The unbounded dynamic dims can only be set when Parameter is "
"created.";
}
parameter_sequence_.push_back(it->second.index);
return it->second.param;
Expand Down Expand Up @@ -170,6 +182,22 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {

const XlaNode* casted = dynamic_cast<const XlaNode*>(node);
result_ops = casted->Lower(this);
if (!casted->dynamic_dims().empty()) {
xla::internal::XlaBuilderFriend builder_friend;
auto* inst = builder_friend.GetInstruction(result_ops[0]);
auto* mutable_dynamic =
inst->mutable_shape()->mutable_is_dynamic_dimension();
if (mutable_dynamic->empty()) {
for (int i = 0; i < inst->dimensions_size(); i++) {
mutable_dynamic->Add(false);
}
}
auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
for (const auto dim : casted->dynamic_dims()) {
mutable_dynamic->Set(dim, true);
mutable_dims->Set(dim, kUnboundedSize);
}
}
} catch (const std::exception& ex) {
ReportBuilderError(node, ex.what());
}
Expand Down
3 changes: 2 additions & 1 deletion torch_xla/csrc/lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class LoweringContext : public torch::lazy::LoweringContext {
// returned. Otherwise a new one will be created, associated with the tensor
// held in data.
xla::XlaOp GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data);
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::unordered_set<uint32_t>& dynamic_dims = {});

// Retrieves the vector holding all the tensors associated with the parameter
// instructions which have been created.
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/device_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const {
}

XlaOpVector DeviceData::Lower(LoweringContext* loctx) const {
return ReturnOp(loctx->GetParameter(data_), loctx);
return ReturnOp(loctx->GetParameter(data_, unbounded_dynamic_dims_), loctx);
}

DeviceData* DeviceData::Cast(const torch::lazy::Node* node) {
Expand Down
21 changes: 18 additions & 3 deletions torch_xla/csrc/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,15 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count,
xla::XlaOp scale = xla::Select(xla::Ne(count, zero),
one / xla::ConvertElementType(count, type),
xla::NanValue(input.builder(), type));
return input * scale;

if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scale);
return promoted.first * promoted.second;
} else {
return input * scale;
}
}

xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) {
Expand Down Expand Up @@ -109,8 +117,15 @@ SummationResult CreateSummation(xla::XlaOp input,
result.result, result.rinfo.element_count.size, shape.element_type());
}
if (keep_reduced_dimensions) {
result.result =
XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions);
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is
// added.
result.result = XlaHelpers::DynamicUnboundedReshape(
result.result, input, result.rinfo.new_dimensions);
} else {
result.result = XlaHelpers::DynamicReshape(result.result,
result.rinfo.new_dimensions);
}
}
return result;
}
Expand Down
5 changes: 5 additions & 0 deletions torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -892,4 +892,9 @@ int64_t XLATensor::GetHandle() const {
}
}

void XLATensor::MarkDynamicDimension(uint32_t dim) {
auto* xla_node = dynamic_cast<XlaNode*>(GetIrValue().node.get());
xla_node->MarkDynamicDimension(dim);
}

} // namespace torch_xla
Loading

0 comments on commit c7f9492

Please sign in to comment.