diff --git a/setup.py b/setup.py index cb3e0fe7f30..a8a04c4c286 100644 --- a/setup.py +++ b/setup.py @@ -244,9 +244,10 @@ def bazel_build(self, ext): bazel_argv = [ 'bazel', 'build', ext.bazel_target, - f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}", - '\n'.join(['--cxxopt=%s' % opt for opt in extra_compile_args]) + f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}" ] + for opt in extra_compile_args: + bazel_argv.append("--cxxopt={}".format(opt)) # Debug build. if DEBUG: diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py new file mode 100644 index 00000000000..a4223b799aa --- /dev/null +++ b/test/stablehlo/test_unbounded_dynamism.py @@ -0,0 +1,58 @@ +import sys +import unittest + +import torch +import torch_xla +import torch_xla.core.xla_model as xm +from torch_xla.stablehlo import exported_program_to_stablehlo + +# Note: Unbounded dynamism is under development. It works with unmerged +# XLA changes. Experimental XLA branch: https://github.com/lsy323/openxla-xla/tree/lsiyuan/sandeep-dynamism-rebased + +device = xm.xla_device() + + +class UnboundedDynamismExportTest(unittest.TestCase): + + def test_simply_add(self): + a = torch.tensor([[1, 2], [2, 4]], device=device) + torch_xla._XLAC._xla_mark_dynamic(a, 0) + b = torch.tensor([[1, 2], [2, 4]], device=device) + torch_xla._XLAC._xla_mark_dynamic(b, 0) + c = a * b + hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c]) + self.assertTrue( + "(p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2])" in hlo_content) + + def test_export_dynamism(self): + + class M(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x, y): + return x * y + + example_args = (torch.tensor([[1, 2], [2, 4]], device=device), + torch.tensor([[1, 2], [2, 4]], device=device)) + constraints = [ + # First dimension of each input is a dynamic batch size + torch.export.dynamic_dim(example_args[0], 0), + torch.export.dynamic_dim(example_args[1], 0), + # The dynamic batch size between the inputs are equal + torch.export.dynamic_dim(example_args[0], + 0) == torch.export.dynamic_dim( + example_args[1], 0), + ] + ep = torch.export.export(M(), args=example_args, constraints=constraints) + shlo_module = exported_program_to_stablehlo(ep) + shlo_text = shlo_module.get_stablehlo_text("forward") + self.assertTrue( + "(%arg0: tensor, %arg1: tensor) -> tensor" in + shlo_text) + + +if __name__ == '__main__': + test = unittest.main() + sys.exit(0 if test.result.wasSuccessful() else 1) diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 88ce96cab99..62a906d98ea 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -5,6 +5,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/random.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_lower_util.h" @@ -66,8 +67,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output, xla::XlaOp BuildRelu(xla::XlaOp input) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - return xla::Max(input, XlaHelpers::ScalarValue( - 0, input_shape.element_type(), input.builder())); + xla::XlaOp scalar = XlaHelpers::ScalarValue( + 0, input_shape.element_type(), input.builder()); + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + // xla::Max doesn't do implicit broadcasting for unbounded dynamism now. + // TODO(lsy323): Remove this branch once the support is added in XLA. + auto promoted = XlaHelpers::Promote(input, scalar); + return xla::Max(promoted.first, promoted.second); + } else { + return xla::Max(input, scalar); + } } xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) { diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index af9ff6ba49b..995e43078b6 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -9,7 +9,6 @@ #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/shape_helper.h" @@ -21,6 +20,9 @@ namespace torch_xla { namespace { +// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA. +static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); + xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2, xla::XlaOp result) { xla::PrimitiveType type1 = XlaHelpers::TypeOfXlaOp(op1); @@ -63,6 +65,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input, std::vector bcast_sizes = SizesOfXlaOp(input); for (size_t i = 0; i < dimensions.size(); ++i) { bcast_sizes.at(dimensions[i]) = sizes[i]; + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + XLA_CHECK(sizes[i] != kUnboundedSize); + } } return xla::BroadcastInDim(input, bcast_sizes, GetAllDimensions(bcast_sizes.size())); @@ -322,6 +327,59 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input, : xla::Reshape(input, shape.dimensions()); } +bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { + XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) + << "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded " + "dynamism workload."; + const absl::Span dims = shape.dimensions(); + return std::any_of(dims.begin(), dims.end(), + [](int64_t size) { return size == kUnboundedSize; }); +} + +xla::XlaOp XlaHelpers::DynamicUnboundedReshape( + xla::XlaOp input, xla::XlaOp aux_input, + absl::Span output_sizes) { + XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) + << "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded " + "dynamism workload."; + const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); + XLA_CHECK(output_sizes.size() == aux_input_shape.rank()) + << "XlaHelpers::DynamicUnboundedReshape constrainled failed!"; + std::vector get_dim_ops; + std::vector reshaped_ops; + bool all_static = true; + std::vector output_dynamic(output_sizes.size(), false); + + for (int i = 0; i < output_sizes.size(); i++) { + if (output_sizes[i] == kUnboundedSize) { + output_dynamic[i] = true; + get_dim_ops.push_back(xla::GetDimensionSize(aux_input, i)); + all_static = false; + } else { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + output_sizes[i], aux_input.builder())); + } + } + + if (all_static) { + return xla::Reshape(input, output_sizes); + } + + // Create the reshape from scalar to 1-D vector + for (auto get_dim_op : get_dim_ops) { + reshaped_ops.push_back(xla::Reshape(get_dim_op, {1})); + } + + // Create Concatenate op + auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, {0}); + return xla::CustomCall( + aux_input.builder(), "stablehlo.dynamic_reshape", {input, concat_op}, + xla::ShapeUtil::MakeShape(aux_input_shape.element_type(), output_sizes, + output_dynamic)); + + return input; +} + bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1, const xla::Shape& shape2) { return shape1.is_static() && shape2.is_static() && @@ -485,6 +543,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1, runtime::util::ToVector(shape1.dimensions()), runtime::util::ToVector(shape2.dimensions()))); } + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && + !XlaHelpers::IsUnboundedDynamic(shape2)) + << "Unreachable for unbounded dynamic code\n"; + } return GetPromotedDynamicShape(shape1, shape2); } diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 817566159ed..66c01588b57 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -13,6 +13,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/util.h" #include "tsl/platform/bfloat16.h" #include "xla/client/xla_builder.h" @@ -158,6 +159,17 @@ class XlaHelpers { static xla::XlaOp DynamicReshape(xla::XlaOp input, absl::Span output_sizes); + static bool IsUnboundedDynamic(const xla::Shape& shape); + + static bool IsUnboundedDynamismEnabled() { + return runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", + false); + } + + static xla::XlaOp DynamicUnboundedReshape( + xla::XlaOp input, xla::XlaOp aux_input, + absl::Span output_sizes); + static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape); static bool SameStaticDimensions(const xla::Shape& shape1, diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b81f0978d27..be0cbad991f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1968,6 +1968,12 @@ void InitXlaModuleBindings(py::module m) { return handles; }); + m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) { + TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + xtensor->MarkDynamicDimension(dim); + }); + // -------------Dynamo Integration API Start------------------------- /* * Return tensor ids and at::tensors for all DeviceData nodes that is needed diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 0fc2d77a47f..82b746ab181 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -174,6 +174,8 @@ xla::Shape XlaNode::GetOpShape( std::string XlaNode::ToString() const { std::stringstream ss; ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_; + ss << ", dynamic_dims: (" << absl::StrJoin(unbounded_dynamic_dims_, ", ") + << ')'; return ss.str(); } diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index c63fe289b9d..d0619ef5c98 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -9,9 +9,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -138,6 +138,17 @@ class XlaNode : public torch::lazy::Node { std::string ToString() const override; + void MarkDynamicDimension(uint32_t dim) { + unbounded_dynamic_dims_.insert(dim); + } + + const std::unordered_set& dynamic_dims() const { + return unbounded_dynamic_dims_; + } + + protected: + std::unordered_set unbounded_dynamic_dims_; + private: xla::Shape GetOpShape(const std::function& shape_fn) const; diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index dbb1fd69cb3..404fa82ea7b 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -93,19 +93,31 @@ LoweringContext::LoweringContext( } } +// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA. +static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); + xla::XlaOp LoweringContext::GetParameter( - const std::shared_ptr& data) { + const std::shared_ptr& data, + const std::unordered_set& unbounded_dynamic_dims) { torch::lazy::BackendData::Handle handle = data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { - xla::XlaOp param = xla::Parameter( - builder(), parameters_.size(), + xla::Shape shape = std::dynamic_pointer_cast(data) - ->shape(), - absl::StrCat("p", parameters_.size())); + ->shape(); + for (const int dim : unbounded_dynamic_dims) { + shape.set_dynamic_dimension(dim, true); + shape.set_dimensions(dim, kUnboundedSize); + } + xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape, + absl::StrCat("p", parameters_.size())); it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) .first; parameters_.push_back(data); + } else { + XLA_CHECK(unbounded_dynamic_dims.empty()) + << "The unbounded dynamic dims can only be set when Parameter is " + "created."; } parameter_sequence_.push_back(it->second.index); return it->second.param; @@ -170,6 +182,22 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { const XlaNode* casted = dynamic_cast(node); result_ops = casted->Lower(this); + if (!casted->dynamic_dims().empty()) { + xla::internal::XlaBuilderFriend builder_friend; + auto* inst = builder_friend.GetInstruction(result_ops[0]); + auto* mutable_dynamic = + inst->mutable_shape()->mutable_is_dynamic_dimension(); + if (mutable_dynamic->empty()) { + for (int i = 0; i < inst->dimensions_size(); i++) { + mutable_dynamic->Add(false); + } + } + auto* mutable_dims = inst->mutable_shape()->mutable_dimensions(); + for (const auto dim : casted->dynamic_dims()) { + mutable_dynamic->Set(dim, true); + mutable_dims->Set(dim, kUnboundedSize); + } + } } catch (const std::exception& ex) { ReportBuilderError(node, ex.what()); } diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 76684326326..b46d91874b0 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -37,7 +37,8 @@ class LoweringContext : public torch::lazy::LoweringContext { // returned. Otherwise a new one will be created, associated with the tensor // held in data. xla::XlaOp GetParameter( - const std::shared_ptr& data); + const std::shared_ptr& data, + const std::unordered_set& dynamic_dims = {}); // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index 07956843a7d..e07fe3c4e76 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const { } XlaOpVector DeviceData::Lower(LoweringContext* loctx) const { - return ReturnOp(loctx->GetParameter(data_), loctx); + return ReturnOp(loctx->GetParameter(data_, unbounded_dynamic_dims_), loctx); } DeviceData* DeviceData::Cast(const torch::lazy::Node* node) { diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index 1b0d47a3735..aff46743410 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -81,7 +81,15 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count, xla::XlaOp scale = xla::Select(xla::Ne(count, zero), one / xla::ConvertElementType(count, type), xla::NanValue(input.builder(), type)); - return input * scale; + + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + // XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now. + // TODO(lsy323): Remove this branch once the support is added in XLA. + auto promoted = XlaHelpers::Promote(input, scale); + return promoted.first * promoted.second; + } else { + return input * scale; + } } xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) { @@ -109,8 +117,15 @@ SummationResult CreateSummation(xla::XlaOp input, result.result, result.rinfo.element_count.size, shape.element_type()); } if (keep_reduced_dimensions) { - result.result = - XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions); + if (XlaHelpers::IsUnboundedDynamismEnabled()) { + // TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is + // added. + result.result = XlaHelpers::DynamicUnboundedReshape( + result.result, input, result.rinfo.new_dimensions); + } else { + result.result = XlaHelpers::DynamicReshape(result.result, + result.rinfo.new_dimensions); + } } return result; } diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 6f334c76894..672f7fe0c8e 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -893,4 +893,9 @@ int64_t XLATensor::GetHandle() const { } } +void XLATensor::MarkDynamicDimension(uint32_t dim) { + auto* xla_node = dynamic_cast(GetIrValue().node.get()); + xla_node->MarkDynamicDimension(dim); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 8564729bb71..f73aed5ce5f 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -201,6 +201,7 @@ class XLATensor : public torch::lazy::LazyTensor { // Set logical_element_type which is visible to upstream PyTorch. void SetScalarType(c10::optional logical_element_type); + void MarkDynamicDimension(uint32_t dim); // We don't use the upstream shape to provide xla::shape. runtime::util::MaybeRef shape() const; diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index e6916e08fb8..407cfeb972f 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -212,6 +212,16 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: new_kwargs['device'] = self._device return super().call_function(target, args, new_kwargs) + def run_node(self, n) -> Any: + if n.op == 'placeholder': + fake_t = n.meta['val'] + res = super().run_node(n) + for i, x in enumerate(fake_t.shape): + if not isinstance(x, int): + torch_xla._XLAC._xla_mark_dynamic(res, i) + return res + return super().run_node(n) + def _extract_input_args(exported_model, options): if options.override_tracing_arguments is not None: