From 56f698d68ab3db40ad9eebe79ee0a0f2e491c869 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sat, 11 Nov 2023 03:26:58 +0000 Subject: [PATCH 01/10] port sandeep unbounded dynamism change --- setup.py | 8 +- torch_xla/csrc/elementwise.cpp | 8 + torch_xla/csrc/helpers.cpp | 251 ++++++++++++++++++++++++ torch_xla/csrc/helpers.h | 23 +++ torch_xla/csrc/init_python_bindings.cpp | 14 ++ torch_xla/csrc/ir.cpp | 5 + torch_xla/csrc/ir.h | 11 ++ torch_xla/csrc/lowering_context.cpp | 31 ++- torch_xla/csrc/lowering_context.h | 3 +- torch_xla/csrc/ops/device_data.cpp | 2 +- torch_xla/csrc/reduction.cpp | 10 + torch_xla/csrc/tensor.cpp | 11 ++ torch_xla/csrc/tensor.h | 2 + torch_xla/stablehlo.py | 9 + 14 files changed, 379 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index cb3e0fe7f30..b299caca482 100644 --- a/setup.py +++ b/setup.py @@ -212,8 +212,11 @@ def run(self): extra_compile_args = [] cxx_abi = os.getenv( 'CXX_ABI', default='') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None) +experimental_dynamism = os.getenv('EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM', default=None) if cxx_abi is not None: extra_compile_args.append(f'-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}') +if experimental_dynamism is not None: + extra_compile_args.append(f'-DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM={experimental_dynamism}') class BazelExtension(Extension): @@ -244,9 +247,10 @@ def bazel_build(self, ext): bazel_argv = [ 'bazel', 'build', ext.bazel_target, - f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}", - '\n'.join(['--cxxopt=%s' % opt for opt in extra_compile_args]) + f"--symlink_prefix={os.path.join(self.build_temp, 'bazel-')}" ] + for opt in extra_compile_args: + bazel_argv.append("--cxxopt={}".format(opt)) # Debug build. if DEBUG: diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 88ce96cab99..0e742e2c81e 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -66,8 +66,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output, xla::XlaOp BuildRelu(xla::XlaOp input) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); +#ifndef EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM return xla::Max(input, XlaHelpers::ScalarValue( 0, input_shape.element_type(), input.builder())); +#else + xla::XlaOp scalar = XlaHelpers::ScalarValue( + 0, input_shape.element_type(), input.builder()); + auto promoted = XlaHelpers::Promote(input, scalar); + + return xla::Max(promoted.first, promoted.second); +#endif } xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) { diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index af9ff6ba49b..a9003f5540a 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -21,6 +21,10 @@ namespace torch_xla { namespace { +#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM +static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); +#endif + xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2, xla::XlaOp result) { xla::PrimitiveType type1 = XlaHelpers::TypeOfXlaOp(op1); @@ -63,6 +67,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input, std::vector bcast_sizes = SizesOfXlaOp(input); for (size_t i = 0; i < dimensions.size(); ++i) { bcast_sizes.at(dimensions[i]) = sizes[i]; +#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM + XLA_CHECK(sizes[i] != kUnboundedSize); +#endif } return xla::BroadcastInDim(input, bcast_sizes, GetAllDimensions(bcast_sizes.size())); @@ -322,6 +329,116 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input, : xla::Reshape(input, 
shape.dimensions()); } +#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM + +bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { + const absl::Span dims = shape.dimensions(); + return std::any_of(dims.begin(), dims.end(), + [](int64_t size) { return size == kUnboundedSize; }); +} + +xla::XlaOp XlaHelpers::DynamicUnboundedReshape( + xla::XlaOp input, xla::XlaOp aux_input, + absl::Span output_sizes) { + const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); + XLA_CHECK(output_sizes.size() == aux_input_shape.rank()) + << "XlaHelpers::DynamicUnboundedReshape constrainled failed!"; + std::vector get_dim_ops; + std::vector reshaped_ops; + bool all_static = true; + std::vector output_dynamic(output_sizes.size(), false); + + for (int i = 0; i < output_sizes.size(); i++) { + if (output_sizes[i] == kUnboundedSize) { + output_dynamic[i] = true; + get_dim_ops.push_back(xla::GetDimensionSize(aux_input, i)); + all_static = false; + } else { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + output_sizes[i], aux_input.builder())); + } + } + + if (all_static) { + return xla::Reshape(input, output_sizes); + } + + // Create the reshape from scalar to 1-D vector + for (auto get_dim_op : get_dim_ops) { + reshaped_ops.push_back(xla::Reshape(get_dim_op, {1})); + } + + // Create Concatenate op + auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, {0}); + return xla::CustomCall( + aux_input.builder(), "stablehlo.dynamic_reshape", {input, concat_op}, + xla::ShapeUtil::MakeShape(aux_input_shape.element_type(), output_sizes, + output_dynamic)); + + return input; +} + +xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast( + xla::XlaOp input, xla::XlaOp aux_input, + absl::Span aux_input_dimensions) { + const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); + const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); + bool all_static = true; + std::vector output_dimensions; + std::vector output_dynamic; + for (auto dim : aux_input_dimensions) { + if (aux_input_shape.dimensions(dim) == kUnboundedSize) all_static = false; + output_dimensions.push_back(aux_input_shape.dimensions(dim)); + output_dynamic.push_back(aux_input_shape.is_dynamic_dimension(dim)); + } + + if (all_static) { + return xla::Broadcast(input, output_dimensions); + } + + std::vector get_dim_ops; + std::vector reshaped_ops; + for (auto dim : aux_input_dimensions) { + if (aux_input_shape.dimensions(dim) != kUnboundedSize) { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + aux_input_shape.dimensions(dim), aux_input.builder())); + } else { + get_dim_ops.push_back(xla::GetDimensionSize(aux_input, dim)); + } + } + + for (int dim = 0; dim < input_shape.rank(); dim++) { + output_dimensions.push_back(input_shape.dimensions(dim)); + output_dynamic.push_back(input_shape.is_dynamic_dimension(dim)); + if (input_shape.dimensions(dim) != kUnboundedSize) { + get_dim_ops.push_back(XlaHelpers::ScalarValue( + input_shape.dimensions(dim), input.builder())); + } else { + get_dim_ops.push_back(xla::GetDimensionSize(input, dim)); + } + } + + // Create the reshape from scalar to 1-D vector + for (auto get_dim_op : get_dim_ops) { + reshaped_ops.push_back(xla::Reshape(get_dim_op, {1})); + } + + // Create Concatenate op + auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, {0}); + return xla::CustomCall( + aux_input.builder(), "stablehlo.dynamic_broadcast_in_dim", + {input, concat_op}, + xla::ShapeUtil::MakeShape(input_shape.element_type(), output_dimensions, + output_dynamic)); +} + +void 
XlaHelpers::PrintXlaOp(xla::XlaOp op, const std::string& msg) { + std::cout << "Handle: " << msg << ": " << op << "\n"; + const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(op); + std::cout << xla::ShapeUtil::HumanString(shape); +} +#endif + bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1, const xla::Shape& shape2) { return shape1.is_static() && shape2.is_static() && @@ -485,6 +602,10 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1, runtime::util::ToVector(shape1.dimensions()), runtime::util::ToVector(shape2.dimensions()))); } +#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM + XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && !XlaHelpers::IsUnboundedDynamic(shape2)) + << "Unreachable for unbounded dynamic code\n"; +#endif return GetPromotedDynamicShape(shape1, shape2); } @@ -578,6 +699,136 @@ std::pair XlaHelpers::PromoteSecond(xla::XlaOp op1, return PromoteShapes(vops.first, vops.second); } +#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM +xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( + xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op, + const xla::Shape& shape) { + XLA_CHECK(XlaHelpers::IsUnboundedDynamic(shape) || XlaHelpers::IsUnboundedDynamic(op_shape)); + + const xla::Shape& aux_shape = ShapeHelper::ShapeOfXlaOp(aux_op); + const auto& op_shape_dims = op_shape.dimensions(); + const auto& aux_shape_dims = aux_shape.dimensions(); + const auto& shape_dims = shape.dimensions(); + + XLA_CHECK_GE(shape_dims.size(), op_shape_dims.size()) + << shape << " vs " << op_shape; + + int64_t size_delta = shape_dims.size() - op_shape_dims.size(); + xla::XlaOp new_op = op; + std::vector get_dim_ops; + std::vector reshaped_ops; + + if (size_delta > 0) { + std::cout << "\t size_delta > 0\n"; + std::vector broadcast_sizes(shape_dims.begin(), + shape_dims.begin() + size_delta); + for (int i = 0; i < size_delta; i++) { + if (broadcast_sizes[i] != kUnboundedSize) { + get_dim_ops.push_back( + XlaHelpers::ScalarValue(broadcast_sizes[i], op.builder())); + + auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); + std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) + << " for size: " << broadcast_sizes[i] << "\n"; + } else { + get_dim_ops.push_back(xla::GetDimensionSize(aux_op, i)); + + auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); + std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) + << " for size: ? of index: " << i << "\n"; + } + } + } + + if (size_delta == 0) { + std::cout << "\t size_delta == 0\n"; + int sz = op_shape_dims.size() - aux_shape_dims.size(); + std::vector broadcast_sizes(shape_dims.begin(), + shape_dims.begin() + sz); + for (int i = 0; i < sz; i++) { + if (broadcast_sizes[i] != kUnboundedSize) { + get_dim_ops.push_back( + XlaHelpers::ScalarValue(broadcast_sizes[i], op.builder())); + + auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); + std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) + << " for size: " << broadcast_sizes[i] << "\n"; + } else { + get_dim_ops.push_back(xla::GetDimensionSize(op, i)); + + auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); + std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) + << " for size: ? 
of index: " << i << "\n"; + } + } + } + + size_t min_size = std::min(op_shape_dims.size(), aux_shape_dims.size()); + for (int i = 0; i < min_size; i++) { + int op_shape_index = op_shape_dims.size() - min_size + i; + int aux_op_shape_index = aux_shape_dims.size() - min_size + i; + int shape_index = shape_dims.size() - min_size + i; + + int64_t op_shape_dim = op_shape_dims[op_shape_index]; + int64_t aux_op_shape_dim = aux_shape_dims[aux_op_shape_index]; + int64_t shape_dim = shape_dims[shape_index]; + + // op_shape aux_op_shape shape + // 1 X X + // X 1 X + // X X X + // 1 ? ? (from aux_op_shape) + // ? 1 ? (from op_shape) + // X ? X + // ? X X + // ? ? ? (from any, let's select op_shape) + // where X != kUnboundedSize && X != 1 + if (shape_dim != kUnboundedSize) { + get_dim_ops.push_back(ScalarValue(shape_dim, op.builder())); + + auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); + std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) + << " for size: " << shape_dim << "\n"; + + } else if (op_shape_dim == 1 || aux_op_shape_dim == 1) { + if (op_shape_dim == 1) { + get_dim_ops.push_back( + xla::GetDimensionSize(aux_op, aux_op_shape_index)); + + auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); + std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) + << " for size: ? of index: " << aux_op_shape_index << "\n"; + + } else { + get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index)); + + auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); + std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) + << " for size: ? of index: " << op_shape_index << "\n"; + } + } else { + get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index)); + + auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); + std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) + << " for size: ? of index: " << op_shape_index << "\n"; + } + } + + // Create the reshape from scalar to 1-D vector + for (auto get_dim_op : get_dim_ops) { + reshaped_ops.push_back(xla::Reshape(get_dim_op, {1})); + } + + // Create Concatenate op + auto concat_op = xla::ConcatInDim(op.builder(), reshaped_ops, {0}); + new_op = xla::CustomCall(op.builder(), "stablehlo.dynamic_broadcast_in_dim", + {op, concat_op}, shape); + + return new_op; +} +#endif + xla::XlaOp XlaHelpers::ImplicitBroadcast(xla::XlaOp op, const xla::Shape& op_shape, const xla::Shape& shape) { diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 817566159ed..51dee64573b 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -158,6 +158,20 @@ class XlaHelpers { static xla::XlaOp DynamicReshape(xla::XlaOp input, absl::Span output_sizes); +#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM + static bool IsUnboundedDynamic(const xla::Shape& shape); + + static xla::XlaOp DynamicUnboundedReshape( + xla::XlaOp input, xla::XlaOp aux_input, + absl::Span output_sizes); + + static xla::XlaOp DynamicUnboundedBroadcast( + xla::XlaOp input, xla::XlaOp aux_input, + absl::Span output_sizes); + + static void PrintXlaOp(xla::XlaOp op, const std::string& msg); +#endif + static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape); static bool SameStaticDimensions(const xla::Shape& shape1, @@ -288,6 +302,15 @@ class XlaHelpers { static xla::XlaOp ImplicitBroadcast(xla::XlaOp op, const xla::Shape& op_shape, const xla::Shape& shape); + // Returns a new operations which broadcast the input operation with unbounded + // dynamic dimensions into the shape. 
The op_shape is the shape of the op + // operation, while shape should be one that op is broadcast-able to (usually + // the result of a GetPromotedShape() call). If op_shape matches shape, the op + // itself is returned. + static xla::XlaOp ImplicitBroadcastWithUnboundedDynamicShapes( + xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op, + const xla::Shape& shape); + // Performs the bin_op binary operation by promoting types and shapes of the // two input operands. static xla::XlaOp PromotedBinaryOp( diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index b81f0978d27..f4dfff79be6 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1968,6 +1968,20 @@ void InitXlaModuleBindings(py::module m) { return handles; }); + m.def("_xla_mark_dynamic", + [](const at::Tensor& input, uint32_t dim) { + TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + xtensor->MarkDynamicDimension(dim); + }); + + m.def("_xla_set_tag", + [](const at::Tensor& input, const std::string& tag) { + TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + xtensor->SetTag(tag); + }); + // -------------Dynamo Integration API Start------------------------- /* * Return tensor ids and at::tensors for all DeviceData nodes that is needed diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 0fc2d77a47f..41d449dc55b 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -174,6 +174,11 @@ xla::Shape XlaNode::GetOpShape( std::string XlaNode::ToString() const { std::stringstream ss; ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_; + ss << ", dynamic_dims: "; + for (const auto dim : dynamic_dims_) { + ss << dim; + } + ss << ", " << "tags: " << experimental_tag_; return ss.str(); } diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index c63fe289b9d..d8ea10f0540 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -138,6 +138,17 @@ class XlaNode : public torch::lazy::Node { std::string ToString() const override; + void MarkDynamicDimension(uint32_t dim) { + dynamic_dims_.push_back(dim); + } + void SetTag(const std::string& tag) { experimental_tag_ = tag; } + const std::string& experimental_tag() const { return experimental_tag_; } + const std::vector& dynamic_dims() const { return dynamic_dims_; } + + protected: + std::string experimental_tag_; + std::vector dynamic_dims_; + private: xla::Shape GetOpShape(const std::function& shape_fn) const; diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index dbb1fd69cb3..17742ff07e4 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -93,16 +93,24 @@ LoweringContext::LoweringContext( } } +// import xla::Shape.h to inlcude the following defintion. 
+static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); xla::XlaOp LoweringContext::GetParameter( - const std::shared_ptr& data) { + const std::shared_ptr& data, + const std::vector& dynamic_dims + ) { torch::lazy::BackendData::Handle handle = data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { + xla::Shape shape = std::dynamic_pointer_cast(data) + ->shape(); + for (const int dim : dynamic_dims) { + shape.set_dynamic_dimension(dim, true); + shape.set_dimensions(dim, kUnboundedSize); + } xla::XlaOp param = xla::Parameter( - builder(), parameters_.size(), - std::dynamic_pointer_cast(data) - ->shape(), - absl::StrCat("p", parameters_.size())); + builder(), parameters_.size(), + shape, absl::StrCat("p", parameters_.size())); it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) .first; parameters_.push_back(data); @@ -170,6 +178,19 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { const XlaNode* casted = dynamic_cast(node); result_ops = casted->Lower(this); + xla::internal::XlaBuilderFriend builder_friend; + auto* inst = builder_friend.GetInstruction(result_ops[0]); + auto* mutable_dynamic = inst->mutable_shape()->mutable_is_dynamic_dimension(); + if (mutable_dynamic->empty()) { + for (int i = 0; i < inst->dimensions_size(); i++) { + mutable_dynamic->Add(false); + } + } + auto* mutable_dims = inst->mutable_shape()->mutable_dimensions(); + for (const auto dim : casted->dynamic_dims()) { + mutable_dynamic->Set(dim, true); + mutable_dims->Set(dim, kUnboundedSize); + } } catch (const std::exception& ex) { ReportBuilderError(node, ex.what()); } diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 76684326326..0bbced7bc45 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -37,7 +37,8 @@ class LoweringContext : public torch::lazy::LoweringContext { // returned. Otherwise a new one will be created, associated with the tensor // held in data. xla::XlaOp GetParameter( - const std::shared_ptr& data); + const std::shared_ptr& data, + const std::vector& dynamic_dims={}); // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. 
diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index 07956843a7d..54c455e3b66 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const { } XlaOpVector DeviceData::Lower(LoweringContext* loctx) const { - return ReturnOp(loctx->GetParameter(data_), loctx); + return ReturnOp(loctx->GetParameter(data_, dynamic_dims_), loctx); } DeviceData* DeviceData::Cast(const torch::lazy::Node* node) { diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index 1b0d47a3735..a6082dbbe26 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -81,7 +81,12 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count, xla::XlaOp scale = xla::Select(xla::Ne(count, zero), one / xla::ConvertElementType(count, type), xla::NanValue(input.builder(), type)); +#if !EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM return input * scale; +#else + auto promoted = XlaHelpers::Promote(input, scale); + return promoted.first * promoted.second; +#endif } xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) { @@ -109,8 +114,13 @@ SummationResult CreateSummation(xla::XlaOp input, result.result, result.rinfo.element_count.size, shape.element_type()); } if (keep_reduced_dimensions) { +#if !EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM result.result = XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions); +#else + result.result = XlaHelpers::DynamicUnboundedReshape(result.result, input, + result.rinfo.new_dimensions); +#endif } return result; } diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 6f334c76894..d3fa030d99e 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -893,4 +893,15 @@ int64_t XLATensor::GetHandle() const { } } +void XLATensor::MarkDynamicDimension(uint32_t dim) { + // auto* xla_node = dynamic_cast(CurrentIrValue().node.get()); + auto* xla_node = dynamic_cast(GetIrValue().node.get()); + xla_node->MarkDynamicDimension(dim); +} + +void XLATensor::SetTag(const std::string& tag) { + auto* xla_node = dynamic_cast(CurrentIrValue().node.get()); + xla_node->SetTag(tag); +} + } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 8564729bb71..75a15463206 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -201,6 +201,8 @@ class XLATensor : public torch::lazy::LazyTensor { // Set logical_element_type which is visible to upstream PyTorch. void SetScalarType(c10::optional logical_element_type); + void MarkDynamicDimension(uint32_t dim); + void SetTag(const std::string& tag); // We don't use the upstream shape to provide xla::shape. 
runtime::util::MaybeRef shape() const; diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index e6916e08fb8..de2f55de81f 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -212,6 +212,15 @@ def call_function(self, target, args: Tuple, kwargs: Dict) -> Any: new_kwargs['device'] = self._device return super().call_function(target, args, new_kwargs) + def run_node(self, n) -> Any: + if n.op == 'placeholder': + fake_t = n.meta['val'] + res = super().run_node(n) + for i, x in enumerate(fake_t.shape): + if not isinstance(x, int): + torch_xla._XLAC._xla_mark_dynamic(res, i) + return res + return super().run_node(n) def _extract_input_args(exported_model, options): if options.override_tracing_arguments is not None: From 24f42c4858990a5ff5e3758138a15e1512aa00da Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sat, 11 Nov 2023 03:33:44 +0000 Subject: [PATCH 02/10] format --- setup.py | 6 ++++-- torch_xla/csrc/helpers.cpp | 8 +++++--- torch_xla/csrc/init_python_bindings.cpp | 22 ++++++++++------------ torch_xla/csrc/ir.cpp | 3 ++- torch_xla/csrc/ir.h | 4 +--- torch_xla/csrc/lowering_context.cpp | 16 ++++++++-------- torch_xla/csrc/lowering_context.h | 5 ++--- torch_xla/csrc/reduction.cpp | 4 ++-- 8 files changed, 34 insertions(+), 34 deletions(-) diff --git a/setup.py b/setup.py index b299caca482..b0cca020701 100644 --- a/setup.py +++ b/setup.py @@ -212,11 +212,13 @@ def run(self): extra_compile_args = [] cxx_abi = os.getenv( 'CXX_ABI', default='') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None) -experimental_dynamism = os.getenv('EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM', default=None) +experimental_dynamism = os.getenv( + 'EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM', default=None) if cxx_abi is not None: extra_compile_args.append(f'-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}') if experimental_dynamism is not None: - extra_compile_args.append(f'-DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM={experimental_dynamism}') + extra_compile_args.append( + f'-DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM={experimental_dynamism}') class BazelExtension(Extension): diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index a9003f5540a..4f317aed2a7 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -334,7 +334,7 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input, bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { const absl::Span dims = shape.dimensions(); return std::any_of(dims.begin(), dims.end(), - [](int64_t size) { return size == kUnboundedSize; }); + [](int64_t size) { return size == kUnboundedSize; }); } xla::XlaOp XlaHelpers::DynamicUnboundedReshape( @@ -603,7 +603,8 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1, runtime::util::ToVector(shape2.dimensions()))); } #if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM - XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && !XlaHelpers::IsUnboundedDynamic(shape2)) + XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && + !XlaHelpers::IsUnboundedDynamic(shape2)) << "Unreachable for unbounded dynamic code\n"; #endif return GetPromotedDynamicShape(shape1, shape2); @@ -703,7 +704,8 @@ std::pair XlaHelpers::PromoteSecond(xla::XlaOp op1, xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op, const xla::Shape& shape) { - XLA_CHECK(XlaHelpers::IsUnboundedDynamic(shape) || XlaHelpers::IsUnboundedDynamic(op_shape)); + XLA_CHECK(XlaHelpers::IsUnboundedDynamic(shape) || + XlaHelpers::IsUnboundedDynamic(op_shape)); const 
xla::Shape& aux_shape = ShapeHelper::ShapeOfXlaOp(aux_op); const auto& op_shape_dims = op_shape.dimensions(); diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f4dfff79be6..f149b1dac37 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1968,19 +1968,17 @@ void InitXlaModuleBindings(py::module m) { return handles; }); - m.def("_xla_mark_dynamic", - [](const at::Tensor& input, uint32_t dim) { - TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - xtensor->MarkDynamicDimension(dim); - }); + m.def("_xla_mark_dynamic", [](const at::Tensor& input, uint32_t dim) { + TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + xtensor->MarkDynamicDimension(dim); + }); - m.def("_xla_set_tag", - [](const at::Tensor& input, const std::string& tag) { - TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - xtensor->SetTag(tag); - }); + m.def("_xla_set_tag", [](const at::Tensor& input, const std::string& tag) { + TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); + XLATensorPtr xtensor = bridge::GetXlaTensor(input); + xtensor->SetTag(tag); + }); // -------------Dynamo Integration API Start------------------------- /* diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 41d449dc55b..b928fd6488f 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -178,7 +178,8 @@ std::string XlaNode::ToString() const { for (const auto dim : dynamic_dims_) { ss << dim; } - ss << ", " << "tags: " << experimental_tag_; + ss << ", " + << "tags: " << experimental_tag_; return ss.str(); } diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index d8ea10f0540..2ca090195c1 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -138,9 +138,7 @@ class XlaNode : public torch::lazy::Node { std::string ToString() const override; - void MarkDynamicDimension(uint32_t dim) { - dynamic_dims_.push_back(dim); - } + void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.push_back(dim); } void SetTag(const std::string& tag) { experimental_tag_ = tag; } const std::string& experimental_tag() const { return experimental_tag_; } const std::vector& dynamic_dims() const { return dynamic_dims_; } diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index 17742ff07e4..e980eaa2636 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -96,21 +96,20 @@ LoweringContext::LoweringContext( // import xla::Shape.h to inlcude the following defintion. 
static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); xla::XlaOp LoweringContext::GetParameter( - const std::shared_ptr& data, - const std::vector& dynamic_dims - ) { + const std::shared_ptr& data, + const std::vector& dynamic_dims) { torch::lazy::BackendData::Handle handle = data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { - xla::Shape shape = std::dynamic_pointer_cast(data) + xla::Shape shape = + std::dynamic_pointer_cast(data) ->shape(); for (const int dim : dynamic_dims) { shape.set_dynamic_dimension(dim, true); shape.set_dimensions(dim, kUnboundedSize); } - xla::XlaOp param = xla::Parameter( - builder(), parameters_.size(), - shape, absl::StrCat("p", parameters_.size())); + xla::XlaOp param = xla::Parameter(builder(), parameters_.size(), shape, + absl::StrCat("p", parameters_.size())); it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) .first; parameters_.push_back(data); @@ -180,7 +179,8 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { result_ops = casted->Lower(this); xla::internal::XlaBuilderFriend builder_friend; auto* inst = builder_friend.GetInstruction(result_ops[0]); - auto* mutable_dynamic = inst->mutable_shape()->mutable_is_dynamic_dimension(); + auto* mutable_dynamic = + inst->mutable_shape()->mutable_is_dynamic_dimension(); if (mutable_dynamic->empty()) { for (int i = 0; i < inst->dimensions_size(); i++) { mutable_dynamic->Add(false); diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 0bbced7bc45..5213b8b0bf3 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -36,9 +36,8 @@ class LoweringContext : public torch::lazy::LoweringContext { // If a parameter associated with data has already been declared, it will be // returned. Otherwise a new one will be created, associated with the tensor // held in data. - xla::XlaOp GetParameter( - const std::shared_ptr& data, - const std::vector& dynamic_dims={}); + xla::XlaOp GetParameter(const std::shared_ptr& data, + const std::vector& dynamic_dims = {}); // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. 
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index a6082dbbe26..17d981bc8d3 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -118,8 +118,8 @@ SummationResult CreateSummation(xla::XlaOp input, result.result = XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions); #else - result.result = XlaHelpers::DynamicUnboundedReshape(result.result, input, - result.rinfo.new_dimensions); + result.result = XlaHelpers::DynamicUnboundedReshape( + result.result, input, result.rinfo.new_dimensions); #endif } return result; From 4acf02dad8116c8cbbd9344ad2c3cd48b091819d Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Sat, 11 Nov 2023 18:20:18 +0000 Subject: [PATCH 03/10] fix linter --- torch_xla/stablehlo.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_xla/stablehlo.py b/torch_xla/stablehlo.py index de2f55de81f..407cfeb972f 100644 --- a/torch_xla/stablehlo.py +++ b/torch_xla/stablehlo.py @@ -222,6 +222,7 @@ def run_node(self, n) -> Any: return res return super().run_node(n) + def _extract_input_args(exported_model, options): if options.override_tracing_arguments is not None: args = options.override_tracing_arguments From 495f844a5e952a2cde3cfd6dcc5b14aa06466353 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 13 Nov 2023 05:59:48 +0000 Subject: [PATCH 04/10] remove set tag --- torch_xla/csrc/init_python_bindings.cpp | 6 ------ torch_xla/csrc/ir.cpp | 2 -- torch_xla/csrc/ir.h | 7 +++---- torch_xla/csrc/tensor.cpp | 5 ----- torch_xla/csrc/tensor.h | 1 - 5 files changed, 3 insertions(+), 18 deletions(-) diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index f149b1dac37..be0cbad991f 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -1974,12 +1974,6 @@ void InitXlaModuleBindings(py::module m) { xtensor->MarkDynamicDimension(dim); }); - m.def("_xla_set_tag", [](const at::Tensor& input, const std::string& tag) { - TORCH_LAZY_COUNTER("XlaMarkDynamic", 1); - XLATensorPtr xtensor = bridge::GetXlaTensor(input); - xtensor->SetTag(tag); - }); - // -------------Dynamo Integration API Start------------------------- /* * Return tensor ids and at::tensors for all DeviceData nodes that is needed diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index b928fd6488f..2c0acd9927e 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -178,8 +178,6 @@ std::string XlaNode::ToString() const { for (const auto dim : dynamic_dims_) { ss << dim; } - ss << ", " - << "tags: " << experimental_tag_; return ss.str(); } diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index 2ca090195c1..f9d86696913 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -138,13 +138,12 @@ class XlaNode : public torch::lazy::Node { std::string ToString() const override; - void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.push_back(dim); } - void SetTag(const std::string& tag) { experimental_tag_ = tag; } - const std::string& experimental_tag() const { return experimental_tag_; } + void MarkDynamicDimension(uint32_t dim) { + dynamic_dims_.push_back(dim); + } const std::vector& dynamic_dims() const { return dynamic_dims_; } protected: - std::string experimental_tag_; std::vector dynamic_dims_; private: diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index d3fa030d99e..c4ea8192719 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -899,9 +899,4 @@ void XLATensor::MarkDynamicDimension(uint32_t dim) { 
xla_node->MarkDynamicDimension(dim); } -void XLATensor::SetTag(const std::string& tag) { - auto* xla_node = dynamic_cast(CurrentIrValue().node.get()); - xla_node->SetTag(tag); -} - } // namespace torch_xla diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 75a15463206..f73aed5ce5f 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -202,7 +202,6 @@ class XLATensor : public torch::lazy::LazyTensor { void SetScalarType(c10::optional logical_element_type); void MarkDynamicDimension(uint32_t dim); - void SetTag(const std::string& tag); // We don't use the upstream shape to provide xla::shape. runtime::util::MaybeRef shape() const; From af1441580a3110dc464cff1b4664bcbd307446b1 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 13 Nov 2023 21:52:02 +0000 Subject: [PATCH 05/10] Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path --- setup.py | 5 --- torch_xla/csrc/elementwise.cpp | 20 ++++++---- torch_xla/csrc/helpers.cpp | 61 ++++++++++------------------- torch_xla/csrc/helpers.h | 4 -- torch_xla/csrc/ir.h | 11 +++--- torch_xla/csrc/lowering_context.cpp | 37 ++++++++++------- torch_xla/csrc/lowering_context.h | 5 ++- torch_xla/csrc/reduction.cpp | 34 ++++++++++------ 8 files changed, 85 insertions(+), 92 deletions(-) diff --git a/setup.py b/setup.py index b0cca020701..a8a04c4c286 100644 --- a/setup.py +++ b/setup.py @@ -212,13 +212,8 @@ def run(self): extra_compile_args = [] cxx_abi = os.getenv( 'CXX_ABI', default='') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None) -experimental_dynamism = os.getenv( - 'EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM', default=None) if cxx_abi is not None: extra_compile_args.append(f'-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}') -if experimental_dynamism is not None: - extra_compile_args.append( - f'-DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM={experimental_dynamism}') class BazelExtension(Extension): diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 0e742e2c81e..027b80a6170 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -5,6 +5,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/random.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/shape_helper.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_lower_util.h" @@ -14,6 +15,9 @@ namespace torch_xla { namespace { +static const bool experimental_unbounded_dynamism = + runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); + xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val, const at::Scalar& max_val) { const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input); @@ -66,16 +70,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output, xla::XlaOp BuildRelu(xla::XlaOp input) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); -#ifndef EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM - return xla::Max(input, XlaHelpers::ScalarValue( - 0, input_shape.element_type(), input.builder())); -#else xla::XlaOp scalar = XlaHelpers::ScalarValue( 0, input_shape.element_type(), input.builder()); - auto promoted = XlaHelpers::Promote(input, scalar); - - return xla::Max(promoted.first, promoted.second); -#endif + if (experimental_unbounded_dynamism) { + // xla::Max doesn't do implicit broadcasting for unbounded dynamism now. + // TODO(lsy323): Remove this branch once the support is added in XLA. 
+ auto promoted = XlaHelpers::Promote(input, scalar); + return xla::Max(promoted.first, promoted.second); + } else { + return xla::Max(input, scalar); + } } xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) { diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 4f317aed2a7..21f3b285790 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -21,9 +21,11 @@ namespace torch_xla { namespace { -#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM +static const bool experimental_unbounded_dynamism = + runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); + +// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA. static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); -#endif xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2, xla::XlaOp result) { @@ -67,9 +69,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input, std::vector bcast_sizes = SizesOfXlaOp(input); for (size_t i = 0; i < dimensions.size(); ++i) { bcast_sizes.at(dimensions[i]) = sizes[i]; -#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM - XLA_CHECK(sizes[i] != kUnboundedSize); -#endif + if (experimental_unbounded_dynamism) { + XLA_CHECK(sizes[i] != kUnboundedSize); + } } return xla::BroadcastInDim(input, bcast_sizes, GetAllDimensions(bcast_sizes.size())); @@ -329,9 +331,9 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input, : xla::Reshape(input, shape.dimensions()); } -#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM - bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { + XLA_CHECK(experimental_unbounded_dynamism) + << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; const absl::Span dims = shape.dimensions(); return std::any_of(dims.begin(), dims.end(), [](int64_t size) { return size == kUnboundedSize; }); @@ -340,6 +342,8 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { xla::XlaOp XlaHelpers::DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes) { + XLA_CHECK(experimental_unbounded_dynamism) + << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); XLA_CHECK(output_sizes.size() == aux_input_shape.rank()) << "XlaHelpers::DynamicUnboundedReshape constrainled failed!"; @@ -381,13 +385,17 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape( xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast( xla::XlaOp input, xla::XlaOp aux_input, absl::Span aux_input_dimensions) { + XLA_CHECK(experimental_unbounded_dynamism) + << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); bool all_static = true; std::vector output_dimensions; std::vector output_dynamic; for (auto dim : aux_input_dimensions) { - if (aux_input_shape.dimensions(dim) == kUnboundedSize) all_static = false; + if (aux_input_shape.dimensions(dim) == kUnboundedSize) { + all_static = false; + } output_dimensions.push_back(aux_input_shape.dimensions(dim)); output_dynamic.push_back(aux_input_shape.is_dynamic_dimension(dim)); } @@ -432,13 +440,6 @@ xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast( output_dynamic)); } -void XlaHelpers::PrintXlaOp(xla::XlaOp op, const std::string& msg) { - std::cout << "Handle: " << msg << ": " << op << "\n"; - const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(op); - std::cout << xla::ShapeUtil::HumanString(shape); -} -#endif - bool 
XlaHelpers::SameStaticDimensions(const xla::Shape& shape1, const xla::Shape& shape2) { return shape1.is_static() && shape2.is_static() && @@ -602,11 +603,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1, runtime::util::ToVector(shape1.dimensions()), runtime::util::ToVector(shape2.dimensions()))); } -#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM - XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && - !XlaHelpers::IsUnboundedDynamic(shape2)) - << "Unreachable for unbounded dynamic code\n"; -#endif + if (experimental_unbounded_dynamism) { + XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && + !XlaHelpers::IsUnboundedDynamic(shape2)) + << "Unreachable for unbounded dynamic code\n"; + } return GetPromotedDynamicShape(shape1, shape2); } @@ -700,7 +701,6 @@ std::pair XlaHelpers::PromoteSecond(xla::XlaOp op1, return PromoteShapes(vops.first, vops.second); } -#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op, const xla::Shape& shape) { @@ -721,7 +721,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( std::vector reshaped_ops; if (size_delta > 0) { - std::cout << "\t size_delta > 0\n"; std::vector broadcast_sizes(shape_dims.begin(), shape_dims.begin() + size_delta); for (int i = 0; i < size_delta; i++) { @@ -730,20 +729,15 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( XlaHelpers::ScalarValue(broadcast_sizes[i], op.builder())); auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) - << " for size: " << broadcast_sizes[i] << "\n"; } else { get_dim_ops.push_back(xla::GetDimensionSize(aux_op, i)); auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) - << " for size: ? of index: " << i << "\n"; } } } if (size_delta == 0) { - std::cout << "\t size_delta == 0\n"; int sz = op_shape_dims.size() - aux_shape_dims.size(); std::vector broadcast_sizes(shape_dims.begin(), shape_dims.begin() + sz); @@ -753,14 +747,10 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( XlaHelpers::ScalarValue(broadcast_sizes[i], op.builder())); auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) - << " for size: " << broadcast_sizes[i] << "\n"; } else { get_dim_ops.push_back(xla::GetDimensionSize(op, i)); auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) - << " for size: ? of index: " << i << "\n"; } } } @@ -789,8 +779,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( get_dim_ops.push_back(ScalarValue(shape_dim, op.builder())); auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) - << " for size: " << shape_dim << "\n"; } else if (op_shape_dim == 1 || aux_op_shape_dim == 1) { if (op_shape_dim == 1) { @@ -798,22 +786,16 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( xla::GetDimensionSize(aux_op, aux_op_shape_index)); auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) - << " for size: ? 
of index: " << aux_op_shape_index << "\n"; } else { get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index)); auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) - << " for size: ? of index: " << op_shape_index << "\n"; } } else { get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index)); auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s) - << " for size: ? of index: " << op_shape_index << "\n"; } } @@ -829,7 +811,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( return new_op; } -#endif xla::XlaOp XlaHelpers::ImplicitBroadcast(xla::XlaOp op, const xla::Shape& op_shape, diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 51dee64573b..342e84eb993 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -158,7 +158,6 @@ class XlaHelpers { static xla::XlaOp DynamicReshape(xla::XlaOp input, absl::Span output_sizes); -#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM static bool IsUnboundedDynamic(const xla::Shape& shape); static xla::XlaOp DynamicUnboundedReshape( @@ -169,9 +168,6 @@ class XlaHelpers { xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes); - static void PrintXlaOp(xla::XlaOp op, const std::string& msg); -#endif - static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape); static bool SameStaticDimensions(const xla::Shape& shape1, diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index f9d86696913..8c852573730 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -9,9 +9,9 @@ #include #include #include -#include #include #include +#include #include #include @@ -138,13 +138,14 @@ class XlaNode : public torch::lazy::Node { std::string ToString() const override; - void MarkDynamicDimension(uint32_t dim) { - dynamic_dims_.push_back(dim); + void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.insert(dim); } + + const std::unordered_set& dynamic_dims() const { + return dynamic_dims_; } - const std::vector& dynamic_dims() const { return dynamic_dims_; } protected: - std::vector dynamic_dims_; + std::unordered_set dynamic_dims_; private: xla::Shape GetOpShape(const std::function& shape_fn) const; diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index e980eaa2636..404fa82ea7b 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -93,18 +93,19 @@ LoweringContext::LoweringContext( } } -// import xla::Shape.h to inlcude the following defintion. +// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA. 
static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); + xla::XlaOp LoweringContext::GetParameter( const std::shared_ptr& data, - const std::vector& dynamic_dims) { + const std::unordered_set& unbounded_dynamic_dims) { torch::lazy::BackendData::Handle handle = data->GetHandle(); auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { xla::Shape shape = std::dynamic_pointer_cast(data) ->shape(); - for (const int dim : dynamic_dims) { + for (const int dim : unbounded_dynamic_dims) { shape.set_dynamic_dimension(dim, true); shape.set_dimensions(dim, kUnboundedSize); } @@ -113,6 +114,10 @@ xla::XlaOp LoweringContext::GetParameter( it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()}) .first; parameters_.push_back(data); + } else { + XLA_CHECK(unbounded_dynamic_dims.empty()) + << "The unbounded dynamic dims can only be set when Parameter is " + "created."; } parameter_sequence_.push_back(it->second.index); return it->second.param; @@ -177,19 +182,21 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) { const XlaNode* casted = dynamic_cast(node); result_ops = casted->Lower(this); - xla::internal::XlaBuilderFriend builder_friend; - auto* inst = builder_friend.GetInstruction(result_ops[0]); - auto* mutable_dynamic = - inst->mutable_shape()->mutable_is_dynamic_dimension(); - if (mutable_dynamic->empty()) { - for (int i = 0; i < inst->dimensions_size(); i++) { - mutable_dynamic->Add(false); + if (!casted->dynamic_dims().empty()) { + xla::internal::XlaBuilderFriend builder_friend; + auto* inst = builder_friend.GetInstruction(result_ops[0]); + auto* mutable_dynamic = + inst->mutable_shape()->mutable_is_dynamic_dimension(); + if (mutable_dynamic->empty()) { + for (int i = 0; i < inst->dimensions_size(); i++) { + mutable_dynamic->Add(false); + } + } + auto* mutable_dims = inst->mutable_shape()->mutable_dimensions(); + for (const auto dim : casted->dynamic_dims()) { + mutable_dynamic->Set(dim, true); + mutable_dims->Set(dim, kUnboundedSize); } - } - auto* mutable_dims = inst->mutable_shape()->mutable_dimensions(); - for (const auto dim : casted->dynamic_dims()) { - mutable_dynamic->Set(dim, true); - mutable_dims->Set(dim, kUnboundedSize); } } catch (const std::exception& ex) { ReportBuilderError(node, ex.what()); diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index 5213b8b0bf3..b46d91874b0 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -36,8 +36,9 @@ class LoweringContext : public torch::lazy::LoweringContext { // If a parameter associated with data has already been declared, it will be // returned. Otherwise a new one will be created, associated with the tensor // held in data. - xla::XlaOp GetParameter(const std::shared_ptr& data, - const std::vector& dynamic_dims = {}); + xla::XlaOp GetParameter( + const std::shared_ptr& data, + const std::unordered_set& dynamic_dims = {}); // Retrieves the vector holding all the tensors associated with the parameter // instructions which have been created. 
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index 17d981bc8d3..336123a542e 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -31,6 +31,9 @@ struct SummationResult { xla::XlaOp result; }; +static const bool experimental_unbounded_dynamism = + runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); + ReductionInfo GetReductionInfo(xla::XlaOp input, const xla::Shape& shape, absl::Span dimensions, bool keep_reduced_dimensions) { @@ -81,12 +84,15 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count, xla::XlaOp scale = xla::Select(xla::Ne(count, zero), one / xla::ConvertElementType(count, type), xla::NanValue(input.builder(), type)); -#if !EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM - return input * scale; -#else - auto promoted = XlaHelpers::Promote(input, scale); - return promoted.first * promoted.second; -#endif + + if (experimental_unbounded_dynamism) { + // XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now. + // TODO(lsy323): Remove this branch once the support is added in XLA. + auto promoted = XlaHelpers::Promote(input, scale); + return promoted.first * promoted.second; + } else { + return input * scale; + } } xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) { @@ -114,13 +120,15 @@ SummationResult CreateSummation(xla::XlaOp input, result.result, result.rinfo.element_count.size, shape.element_type()); } if (keep_reduced_dimensions) { -#if !EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM - result.result = - XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions); -#else - result.result = XlaHelpers::DynamicUnboundedReshape( - result.result, input, result.rinfo.new_dimensions); -#endif + if (experimental_unbounded_dynamism) { + // TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is + // added. 
+ result.result = XlaHelpers::DynamicUnboundedReshape( + result.result, input, result.rinfo.new_dimensions); + } else { + result.result = XlaHelpers::DynamicReshape(result.result, + result.rinfo.new_dimensions); + } } return result; } From 73ec1c58d08d83e2f84c944af3322c2070f25f6b Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Mon, 13 Nov 2023 22:40:00 +0000 Subject: [PATCH 06/10] remove unused util --- torch_xla/csrc/helpers.cpp | 169 ------------------------------------- torch_xla/csrc/helpers.h | 13 --- torch_xla/csrc/tensor.cpp | 1 - 3 files changed, 183 deletions(-) diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 21f3b285790..eeec940b8b9 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -382,64 +382,6 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape( return input; } -xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast( - xla::XlaOp input, xla::XlaOp aux_input, - absl::Span aux_input_dimensions) { - XLA_CHECK(experimental_unbounded_dynamism) - << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; - const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); - const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); - bool all_static = true; - std::vector output_dimensions; - std::vector output_dynamic; - for (auto dim : aux_input_dimensions) { - if (aux_input_shape.dimensions(dim) == kUnboundedSize) { - all_static = false; - } - output_dimensions.push_back(aux_input_shape.dimensions(dim)); - output_dynamic.push_back(aux_input_shape.is_dynamic_dimension(dim)); - } - - if (all_static) { - return xla::Broadcast(input, output_dimensions); - } - - std::vector get_dim_ops; - std::vector reshaped_ops; - for (auto dim : aux_input_dimensions) { - if (aux_input_shape.dimensions(dim) != kUnboundedSize) { - get_dim_ops.push_back(XlaHelpers::ScalarValue( - aux_input_shape.dimensions(dim), aux_input.builder())); - } else { - get_dim_ops.push_back(xla::GetDimensionSize(aux_input, dim)); - } - } - - for (int dim = 0; dim < input_shape.rank(); dim++) { - output_dimensions.push_back(input_shape.dimensions(dim)); - output_dynamic.push_back(input_shape.is_dynamic_dimension(dim)); - if (input_shape.dimensions(dim) != kUnboundedSize) { - get_dim_ops.push_back(XlaHelpers::ScalarValue( - input_shape.dimensions(dim), input.builder())); - } else { - get_dim_ops.push_back(xla::GetDimensionSize(input, dim)); - } - } - - // Create the reshape from scalar to 1-D vector - for (auto get_dim_op : get_dim_ops) { - reshaped_ops.push_back(xla::Reshape(get_dim_op, {1})); - } - - // Create Concatenate op - auto concat_op = xla::ConcatInDim(input.builder(), reshaped_ops, {0}); - return xla::CustomCall( - aux_input.builder(), "stablehlo.dynamic_broadcast_in_dim", - {input, concat_op}, - xla::ShapeUtil::MakeShape(input_shape.element_type(), output_dimensions, - output_dynamic)); -} - bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1, const xla::Shape& shape2) { return shape1.is_static() && shape2.is_static() && @@ -701,117 +643,6 @@ std::pair XlaHelpers::PromoteSecond(xla::XlaOp op1, return PromoteShapes(vops.first, vops.second); } -xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes( - xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op, - const xla::Shape& shape) { - XLA_CHECK(XlaHelpers::IsUnboundedDynamic(shape) || - XlaHelpers::IsUnboundedDynamic(op_shape)); - - const xla::Shape& aux_shape = ShapeHelper::ShapeOfXlaOp(aux_op); - const auto& op_shape_dims = op_shape.dimensions(); - const auto& 
aux_shape_dims = aux_shape.dimensions(); - const auto& shape_dims = shape.dimensions(); - - XLA_CHECK_GE(shape_dims.size(), op_shape_dims.size()) - << shape << " vs " << op_shape; - - int64_t size_delta = shape_dims.size() - op_shape_dims.size(); - xla::XlaOp new_op = op; - std::vector get_dim_ops; - std::vector reshaped_ops; - - if (size_delta > 0) { - std::vector broadcast_sizes(shape_dims.begin(), - shape_dims.begin() + size_delta); - for (int i = 0; i < size_delta; i++) { - if (broadcast_sizes[i] != kUnboundedSize) { - get_dim_ops.push_back( - XlaHelpers::ScalarValue(broadcast_sizes[i], op.builder())); - - auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - } else { - get_dim_ops.push_back(xla::GetDimensionSize(aux_op, i)); - - auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - } - } - } - - if (size_delta == 0) { - int sz = op_shape_dims.size() - aux_shape_dims.size(); - std::vector broadcast_sizes(shape_dims.begin(), - shape_dims.begin() + sz); - for (int i = 0; i < sz; i++) { - if (broadcast_sizes[i] != kUnboundedSize) { - get_dim_ops.push_back( - XlaHelpers::ScalarValue(broadcast_sizes[i], op.builder())); - - auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - } else { - get_dim_ops.push_back(xla::GetDimensionSize(op, i)); - - auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - } - } - } - - size_t min_size = std::min(op_shape_dims.size(), aux_shape_dims.size()); - for (int i = 0; i < min_size; i++) { - int op_shape_index = op_shape_dims.size() - min_size + i; - int aux_op_shape_index = aux_shape_dims.size() - min_size + i; - int shape_index = shape_dims.size() - min_size + i; - - int64_t op_shape_dim = op_shape_dims[op_shape_index]; - int64_t aux_op_shape_dim = aux_shape_dims[aux_op_shape_index]; - int64_t shape_dim = shape_dims[shape_index]; - - // op_shape aux_op_shape shape - // 1 X X - // X 1 X - // X X X - // 1 ? ? (from aux_op_shape) - // ? 1 ? (from op_shape) - // X ? X - // ? X X - // ? ? ? 
(from any, let's select op_shape) - // where X != kUnboundedSize && X != 1 - if (shape_dim != kUnboundedSize) { - get_dim_ops.push_back(ScalarValue(shape_dim, op.builder())); - - auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - - } else if (op_shape_dim == 1 || aux_op_shape_dim == 1) { - if (op_shape_dim == 1) { - get_dim_ops.push_back( - xla::GetDimensionSize(aux_op, aux_op_shape_index)); - - auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - - } else { - get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index)); - - auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - } - } else { - get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index)); - - auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back()); - } - } - - // Create the reshape from scalar to 1-D vector - for (auto get_dim_op : get_dim_ops) { - reshaped_ops.push_back(xla::Reshape(get_dim_op, {1})); - } - - // Create Concatenate op - auto concat_op = xla::ConcatInDim(op.builder(), reshaped_ops, {0}); - new_op = xla::CustomCall(op.builder(), "stablehlo.dynamic_broadcast_in_dim", - {op, concat_op}, shape); - - return new_op; -} - xla::XlaOp XlaHelpers::ImplicitBroadcast(xla::XlaOp op, const xla::Shape& op_shape, const xla::Shape& shape) { diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 342e84eb993..fdbdd4287d0 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -164,10 +164,6 @@ class XlaHelpers { xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes); - static xla::XlaOp DynamicUnboundedBroadcast( - xla::XlaOp input, xla::XlaOp aux_input, - absl::Span output_sizes); - static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape); static bool SameStaticDimensions(const xla::Shape& shape1, @@ -298,15 +294,6 @@ class XlaHelpers { static xla::XlaOp ImplicitBroadcast(xla::XlaOp op, const xla::Shape& op_shape, const xla::Shape& shape); - // Returns a new operations which broadcast the input operation with unbounded - // dynamic dimensions into the shape. The op_shape is the shape of the op - // operation, while shape should be one that op is broadcast-able to (usually - // the result of a GetPromotedShape() call). If op_shape matches shape, the op - // itself is returned. - static xla::XlaOp ImplicitBroadcastWithUnboundedDynamicShapes( - xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op, - const xla::Shape& shape); - // Performs the bin_op binary operation by promoting types and shapes of the // two input operands. 
static xla::XlaOp PromotedBinaryOp( diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index c4ea8192719..672f7fe0c8e 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -894,7 +894,6 @@ int64_t XLATensor::GetHandle() const { } void XLATensor::MarkDynamicDimension(uint32_t dim) { - // auto* xla_node = dynamic_cast(CurrentIrValue().node.get()); auto* xla_node = dynamic_cast(GetIrValue().node.get()); xla_node->MarkDynamicDimension(dim); } From f2e65aa4b00fa8df43e70e23f3d52029b0d1cfa0 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 14 Nov 2023 22:39:41 +0000 Subject: [PATCH 07/10] Get unbounded dynamism flag from function --- WORKSPACE | 46 +++++++++++++++++----------------- torch_xla/csrc/elementwise.cpp | 5 +--- torch_xla/csrc/helpers.cpp | 12 +++------ torch_xla/csrc/helpers.h | 6 +++++ torch_xla/csrc/reduction.cpp | 7 ++---- 5 files changed, 36 insertions(+), 40 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index ace66355416..587f623f8f8 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -30,25 +30,25 @@ python_configure( # b) get the sha256 hash of the commit by running: # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update the sha256 with the result. -http_archive( - name = "xla", - patch_args = [ - "-l", - "-p1", - ], - patch_tool = "patch", - patches = [ - "//openxla_patches:cache_urls.diff", - "//openxla_patches:constexpr_return.diff", - "//openxla_patches:gpu_race_condition.diff", - "//openxla_patches:f16_abi_clang.diff", - "//openxla_patches:gpu_topk_rewriter.diff", - ], - strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478", - urls = [ - "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz", - ], -) +# http_archive( +# name = "xla", +# patch_args = [ +# "-l", +# "-p1", +# ], +# patch_tool = "patch", +# patches = [ +# "//openxla_patches:cache_urls.diff", +# "//openxla_patches:constexpr_return.diff", +# "//openxla_patches:gpu_race_condition.diff", +# "//openxla_patches:f16_abi_clang.diff", +# "//openxla_patches:gpu_topk_rewriter.diff", +# ], +# strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478", +# urls = [ +# "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz", +# ], +# ) # For development, one often wants to make changes to the OpenXLA repository as well # as the PyTorch/XLA repository. You can override the pinned repository above with a @@ -58,10 +58,10 @@ http_archive( # bazel --override_repository=xla=/path/to/openxla # or # b) by commenting out the http_archive above and uncommenting the following: -# local_repository( -# name = "xla", -# path = "/path/to/openxla", -# ) +local_repository( + name = "xla", + path = "/home/lsiyuan/work/xla", +) # Initialize OpenXLA's external dependencies. 
load("@xla//:workspace4.bzl", "xla_workspace4") diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 027b80a6170..62a906d98ea 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -15,9 +15,6 @@ namespace torch_xla { namespace { -static const bool experimental_unbounded_dynamism = - runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); - xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val, const at::Scalar& max_val) { const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input); @@ -72,7 +69,7 @@ xla::XlaOp BuildRelu(xla::XlaOp input) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); xla::XlaOp scalar = XlaHelpers::ScalarValue( 0, input_shape.element_type(), input.builder()); - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { // xla::Max doesn't do implicit broadcasting for unbounded dynamism now. // TODO(lsy323): Remove this branch once the support is added in XLA. auto promoted = XlaHelpers::Promote(input, scalar); diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index eeec940b8b9..61f634904d3 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -9,7 +9,6 @@ #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/shape_helper.h" @@ -21,9 +20,6 @@ namespace torch_xla { namespace { -static const bool experimental_unbounded_dynamism = - runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); - // TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA. 
static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); @@ -69,7 +65,7 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input, std::vector bcast_sizes = SizesOfXlaOp(input); for (size_t i = 0; i < dimensions.size(); ++i) { bcast_sizes.at(dimensions[i]) = sizes[i]; - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { XLA_CHECK(sizes[i] != kUnboundedSize); } } @@ -332,7 +328,7 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input, } bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { - XLA_CHECK(experimental_unbounded_dynamism) + XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; const absl::Span dims = shape.dimensions(); return std::any_of(dims.begin(), dims.end(), @@ -342,7 +338,7 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { xla::XlaOp XlaHelpers::DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes) { - XLA_CHECK(experimental_unbounded_dynamism) + XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); XLA_CHECK(output_sizes.size() == aux_input_shape.rank()) @@ -545,7 +541,7 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1, runtime::util::ToVector(shape1.dimensions()), runtime::util::ToVector(shape2.dimensions()))); } - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && !XlaHelpers::IsUnboundedDynamic(shape2)) << "Unreachable for unbounded dynamic code\n"; diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index fdbdd4287d0..66c01588b57 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -13,6 +13,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/util.h" #include "tsl/platform/bfloat16.h" #include "xla/client/xla_builder.h" @@ -160,6 +161,11 @@ class XlaHelpers { static bool IsUnboundedDynamic(const xla::Shape& shape); + static bool IsUnboundedDynamismEnabled() { + return runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", + false); + } + static xla::XlaOp DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes); diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index 336123a542e..aff46743410 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -31,9 +31,6 @@ struct SummationResult { xla::XlaOp result; }; -static const bool experimental_unbounded_dynamism = - runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); - ReductionInfo GetReductionInfo(xla::XlaOp input, const xla::Shape& shape, absl::Span dimensions, bool keep_reduced_dimensions) { @@ -85,7 +82,7 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count, one / xla::ConvertElementType(count, type), xla::NanValue(input.builder(), type)); - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { // XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now. // TODO(lsy323): Remove this branch once the support is added in XLA. 
      auto promoted = XlaHelpers::Promote(input, scale);
@@ -120,7 +117,7 @@ SummationResult CreateSummation(xla::XlaOp input,
         result.result, result.rinfo.element_count.size, shape.element_type());
   }
   if (keep_reduced_dimensions) {
-    if (experimental_unbounded_dynamism) {
+    if (XlaHelpers::IsUnboundedDynamismEnabled()) {
       // TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is
       // added.
       result.result = XlaHelpers::DynamicUnboundedReshape(

From 6b8072d41d35e0084913349bf8ad123c730a7b9f Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Wed, 15 Nov 2023 04:17:42 +0000
Subject: [PATCH 08/10] add test

---
 test/stablehlo/test_unbounded_dynamism.py | 58 +++++++++++++++++++++++
 1 file changed, 58 insertions(+)
 create mode 100644 test/stablehlo/test_unbounded_dynamism.py

diff --git a/test/stablehlo/test_unbounded_dynamism.py b/test/stablehlo/test_unbounded_dynamism.py
new file mode 100644
index 00000000000..a4223b799aa
--- /dev/null
+++ b/test/stablehlo/test_unbounded_dynamism.py
@@ -0,0 +1,58 @@
+import sys
+import unittest
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+from torch_xla.stablehlo import exported_program_to_stablehlo
+
+# Note: Unbounded dynamism is under development. It works with unmerged
+# XLA changes. Experimental XLA branch: https://github.com/lsy323/openxla-xla/tree/lsiyuan/sandeep-dynamism-rebased
+
+device = xm.xla_device()
+
+
+class UnboundedDynamismExportTest(unittest.TestCase):
+
+  def test_simply_add(self):
+    a = torch.tensor([[1, 2], [2, 4]], device=device)
+    torch_xla._XLAC._xla_mark_dynamic(a, 0)
+    b = torch.tensor([[1, 2], [2, 4]], device=device)
+    torch_xla._XLAC._xla_mark_dynamic(b, 0)
+    c = a * b
+    hlo_content = torch_xla._XLAC._get_xla_tensors_hlo([c])
+    self.assertTrue(
+        "(p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2])" in hlo_content)
+
+  def test_export_dynamism(self):
+
+    class M(torch.nn.Module):
+
+      def __init__(self):
+        super().__init__()
+
+      def forward(self, x, y):
+        return x * y
+
+    example_args = (torch.tensor([[1, 2], [2, 4]], device=device),
+                    torch.tensor([[1, 2], [2, 4]], device=device))
+    constraints = [
+        # First dimension of each input is a dynamic batch size
+        torch.export.dynamic_dim(example_args[0], 0),
+        torch.export.dynamic_dim(example_args[1], 0),
+        # The dynamic batch sizes of the inputs are equal
+        torch.export.dynamic_dim(example_args[0],
+                                 0) == torch.export.dynamic_dim(
+                                     example_args[1], 0),
+    ]
+    ep = torch.export.export(M(), args=example_args, constraints=constraints)
+    shlo_module = exported_program_to_stablehlo(ep)
+    shlo_text = shlo_module.get_stablehlo_text("forward")
+    self.assertTrue(
+        "(%arg0: tensor, %arg1: tensor) -> tensor" in
+        shlo_text)
+
+
+if __name__ == '__main__':
+  test = unittest.main(exit=False)
+  sys.exit(0 if test.result.wasSuccessful() else 1)

From 07d2b43badb6a7b5703ef2ac24ad6875a08f6c59 Mon Sep 17 00:00:00 2001
From: Siyuan Liu
Date: Wed, 15 Nov 2023 04:24:57 +0000
Subject: [PATCH 09/10] rename unbounded dynamism private member, and fix str
 serialization

---
 torch_xla/csrc/helpers.cpp         | 6 ++++--
 torch_xla/csrc/ir.cpp              | 6 ++----
 torch_xla/csrc/ir.h                | 8 +++++---
 torch_xla/csrc/ops/device_data.cpp | 2 +-
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index 61f634904d3..995e43078b6 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -329,7 +329,8 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
 
 bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) - << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; + << "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded " + "dynamism workload."; const absl::Span dims = shape.dimensions(); return std::any_of(dims.begin(), dims.end(), [](int64_t size) { return size == kUnboundedSize; }); @@ -339,7 +340,8 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes) { XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) - << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; + << "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded " + "dynamism workload."; const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); XLA_CHECK(output_sizes.size() == aux_input_shape.rank()) << "XlaHelpers::DynamicUnboundedReshape constrainled failed!"; diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 2c0acd9927e..82b746ab181 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -174,10 +174,8 @@ xla::Shape XlaNode::GetOpShape( std::string XlaNode::ToString() const { std::stringstream ss; ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_; - ss << ", dynamic_dims: "; - for (const auto dim : dynamic_dims_) { - ss << dim; - } + ss << ", dynamic_dims: (" << absl::StrJoin(unbounded_dynamic_dims_, ", ") + << ')'; return ss.str(); } diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index 8c852573730..d0619ef5c98 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -138,14 +138,16 @@ class XlaNode : public torch::lazy::Node { std::string ToString() const override; - void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.insert(dim); } + void MarkDynamicDimension(uint32_t dim) { + unbounded_dynamic_dims_.insert(dim); + } const std::unordered_set& dynamic_dims() const { - return dynamic_dims_; + return unbounded_dynamic_dims_; } protected: - std::unordered_set dynamic_dims_; + std::unordered_set unbounded_dynamic_dims_; private: xla::Shape GetOpShape(const std::function& shape_fn) const; diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index 54c455e3b66..e07fe3c4e76 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const { } XlaOpVector DeviceData::Lower(LoweringContext* loctx) const { - return ReturnOp(loctx->GetParameter(data_, dynamic_dims_), loctx); + return ReturnOp(loctx->GetParameter(data_, unbounded_dynamic_dims_), loctx); } DeviceData* DeviceData::Cast(const torch::lazy::Node* node) { From 7c0ac9bb10e40d188f6f6e5af4e242c793a15e57 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Wed, 15 Nov 2023 04:25:54 +0000 Subject: [PATCH 10/10] recover WORKSPACE --- WORKSPACE | 46 +++++++++++++++++++++++----------------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/WORKSPACE b/WORKSPACE index 587f623f8f8..ace66355416 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -30,25 +30,25 @@ python_configure( # b) get the sha256 hash of the commit by running: # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update the sha256 with the result. 
-# http_archive( -# name = "xla", -# patch_args = [ -# "-l", -# "-p1", -# ], -# patch_tool = "patch", -# patches = [ -# "//openxla_patches:cache_urls.diff", -# "//openxla_patches:constexpr_return.diff", -# "//openxla_patches:gpu_race_condition.diff", -# "//openxla_patches:f16_abi_clang.diff", -# "//openxla_patches:gpu_topk_rewriter.diff", -# ], -# strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478", -# urls = [ -# "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz", -# ], -# ) +http_archive( + name = "xla", + patch_args = [ + "-l", + "-p1", + ], + patch_tool = "patch", + patches = [ + "//openxla_patches:cache_urls.diff", + "//openxla_patches:constexpr_return.diff", + "//openxla_patches:gpu_race_condition.diff", + "//openxla_patches:f16_abi_clang.diff", + "//openxla_patches:gpu_topk_rewriter.diff", + ], + strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478", + urls = [ + "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz", + ], +) # For development, one often wants to make changes to the OpenXLA repository as well # as the PyTorch/XLA repository. You can override the pinned repository above with a @@ -58,10 +58,10 @@ python_configure( # bazel --override_repository=xla=/path/to/openxla # or # b) by commenting out the http_archive above and uncommenting the following: -local_repository( - name = "xla", - path = "/home/lsiyuan/work/xla", -) +# local_repository( +# name = "xla", +# path = "/path/to/openxla", +# ) # Initialize OpenXLA's external dependencies. load("@xla//:workspace4.bzl", "xla_workspace4")
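
A minimal usage sketch (not part of any patch above) of the workflow this series targets, assuming torch_xla is built against the experimental XLA branch referenced in test/stablehlo/test_unbounded_dynamism.py. EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 is the runtime switch read by XlaHelpers::IsUnboundedDynamismEnabled(); the tensors and shapes below are illustrative only.

import os

# Runtime switch checked by XlaHelpers::IsUnboundedDynamismEnabled(); set it
# before importing torch_xla so every code path observes the same value.
os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Mark dim 0 of both operands as unbounded dynamic, then inspect the lowered
# HLO: the parameters should carry a dynamic leading dimension, e.g.
# (p0.1: s64[?,2], p1.2: s64[?,2]) -> (s64[?,2]).
a = torch.tensor([[1, 2], [2, 4]], device=device)
b = torch.tensor([[1, 2], [2, 4]], device=device)
torch_xla._XLAC._xla_mark_dynamic(a, 0)
torch_xla._XLAC._xla_mark_dynamic(b, 0)
c = a * b
print(torch_xla._XLAC._get_xla_tensors_hlo([c]))

The torch.export path (dynamic_dim constraints followed by exported_program_to_stablehlo) is exercised by test_export_dynamism in the new test file.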