diff --git a/setup.py b/setup.py
index b0cca020701f..a8a04c4c286a 100644
--- a/setup.py
+++ b/setup.py
@@ -212,13 +212,8 @@ def run(self):
 extra_compile_args = []
 cxx_abi = os.getenv(
     'CXX_ABI', default='') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
-experimental_dynamism = os.getenv(
-    'EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM', default=None)
 if cxx_abi is not None:
   extra_compile_args.append(f'-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
-if experimental_dynamism is not None:
-  extra_compile_args.append(
-      f'-DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM={experimental_dynamism}')
 
 
 class BazelExtension(Extension):
diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp
index 0e742e2c81e6..dbc922fcd8f0 100644
--- a/torch_xla/csrc/elementwise.cpp
+++ b/torch_xla/csrc/elementwise.cpp
@@ -5,6 +5,7 @@
 #include "torch_xla/csrc/helpers.h"
 #include "torch_xla/csrc/random.h"
 #include "torch_xla/csrc/runtime/debug_macros.h"
+#include "torch_xla/csrc/runtime/sys_util.h"
 #include "torch_xla/csrc/shape_helper.h"
 #include "torch_xla/csrc/tensor_util.h"
 #include "torch_xla/csrc/xla_lower_util.h"
@@ -14,6 +15,9 @@
 namespace torch_xla {
 namespace {
 
+static const bool experimental_unbounded_dynamism =
+    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
+
 xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val,
                    const at::Scalar& max_val) {
   const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
@@ -66,16 +70,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output,
 
 xla::XlaOp BuildRelu(xla::XlaOp input) {
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
-#ifndef EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-  return xla::Max(input, XlaHelpers::ScalarValue(
-                             0, input_shape.element_type(), input.builder()));
-#else
   xla::XlaOp scalar = XlaHelpers::ScalarValue(
       0, input_shape.element_type(), input.builder());
-  auto promoted = XlaHelpers::Promote(input, scalar);
-
-  return xla::Max(promoted.first, promoted.second);
-#endif
+  if (experimental_unbounded_dynamism) {
+    // xla::Max doesn't do implicit broadcasting for unbounded dynamism yet.
+    // TODO(lsy323): Remove this branch once the support is added in XLA.
+    auto promoted = XlaHelpers::Promote(input, scalar);
+    return xla::Max(promoted.first, promoted.second);
+  } else {
+    return xla::Max(input, scalar);
+  }
 }
 
 xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
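The namespace-scope flag above is the pattern this patch repeats in each translation unit: the environment variable is read once, when static initializers run, instead of being baked in at build time via `-DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM`. A minimal standalone sketch of the idiom; `GetEnvBool` here is a hypothetical stand-in for `runtime::sys_util::GetEnvBool`, assumed to return the supplied default when the variable is unset:

```cpp
#include <cstdlib>
#include <cstring>
#include <iostream>

// Hypothetical stand-in for runtime::sys_util::GetEnvBool: unset -> defval,
// "0"/"false" -> false, anything else -> true.
static bool GetEnvBool(const char* name, bool defval) {
  const char* raw = std::getenv(name);
  if (raw == nullptr) return defval;
  return std::strcmp(raw, "0") != 0 && std::strcmp(raw, "false") != 0;
}

// Read once at static-initialization time, mirroring elementwise.cpp above.
static const bool experimental_unbounded_dynamism =
    GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

int main() {
  // The same binary now picks a code path per process, where the old #ifdef
  // fixed the choice per build.
  std::cout << (experimental_unbounded_dynamism ? "unbounded dynamism on"
                                                : "unbounded dynamism off")
            << "\n";
}
```

One consequence worth noting: because the flag is captured during static initialization, flipping the variable after the process has started has no effect.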
diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp
index 4f317aed2a77..8290da239930 100644
--- a/torch_xla/csrc/helpers.cpp
+++ b/torch_xla/csrc/helpers.cpp
@@ -21,9 +21,12 @@
 namespace torch_xla {
 namespace {
 
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
+
+static const bool experimental_unbounded_dynamism =
+    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
+
+// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
 static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
-#endif
 
 xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
                                  xla::XlaOp result) {
@@ -67,9 +70,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
   std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
   for (size_t i = 0; i < dimensions.size(); ++i) {
     bcast_sizes.at(dimensions[i]) = sizes[i];
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-    XLA_CHECK(sizes[i] != kUnboundedSize);
-#endif
+    if (experimental_unbounded_dynamism) {
+      XLA_CHECK(sizes[i] != kUnboundedSize);
+    }
   }
   return xla::BroadcastInDim(input, bcast_sizes,
                              GetAllDimensions(bcast_sizes.size()));
@@ -329,9 +332,9 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
                     : xla::Reshape(input, shape.dimensions());
 }
 
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-
 bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
+  XLA_CHECK(experimental_unbounded_dynamism)
+      << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
   const absl::Span<const int64_t> dims = shape.dimensions();
   return std::any_of(dims.begin(), dims.end(),
                      [](int64_t size) { return size == kUnboundedSize; });
@@ -340,6 +343,8 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
 xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
     xla::XlaOp input, xla::XlaOp aux_input,
     absl::Span<const int64_t> output_sizes) {
+  XLA_CHECK(experimental_unbounded_dynamism)
+      << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
   const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
   XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
       << "XlaHelpers::DynamicUnboundedReshape constrainled failed!";
@@ -381,13 +386,17 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
 xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast(
     xla::XlaOp input, xla::XlaOp aux_input,
     absl::Span<const int64_t> aux_input_dimensions) {
+  XLA_CHECK(experimental_unbounded_dynamism)
+      << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
   const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
   bool all_static = true;
   std::vector<int64_t> output_dimensions;
   std::vector<bool> output_dynamic;
   for (auto dim : aux_input_dimensions) {
-    if (aux_input_shape.dimensions(dim) == kUnboundedSize) all_static = false;
+    if (aux_input_shape.dimensions(dim) == kUnboundedSize) {
+      all_static = false;
+    }
     output_dimensions.push_back(aux_input_shape.dimensions(dim));
     output_dynamic.push_back(aux_input_shape.is_dynamic_dimension(dim));
   }
@@ -432,13 +441,6 @@ xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast(
       output_dynamic));
 }
 
-void XlaHelpers::PrintXlaOp(xla::XlaOp op, const std::string& msg) {
-  std::cout << "Handle: " << msg << ": " << op << "\n";
-  const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(op);
-  std::cout << xla::ShapeUtil::HumanString(shape);
-}
-#endif
-
 bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
                                       const xla::Shape& shape2) {
   return shape1.is_static() && shape2.is_static() &&
@@ -602,11 +604,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
         runtime::util::ToVector<int64_t>(shape1.dimensions()),
         runtime::util::ToVector<int64_t>(shape2.dimensions())));
   }
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-  XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
-            !XlaHelpers::IsUnboundedDynamic(shape2))
-      << "Unreachable for unbounded dynamic code\n";
-#endif
+  if (experimental_unbounded_dynamism) {
+    XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
+              !XlaHelpers::IsUnboundedDynamic(shape2))
+        << "Unreachable for unbounded dynamic code\n";
+  }
   return GetPromotedDynamicShape(shape1, shape2);
 }
 
@@ -700,7 +702,6 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteSecond(xla::XlaOp op1,
   return PromoteShapes(vops.first, vops.second);
 }
 
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
 xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
     xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op,
     const xla::Shape& shape) {
@@ -721,7 +722,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
   std::vector<xla::XlaOp> reshaped_ops;
 
   if (size_delta > 0) {
-    std::cout << "\t size_delta > 0\n";
     std::vector<int64_t> broadcast_sizes(shape_dims.begin(),
                                          shape_dims.begin() + size_delta);
     for (int i = 0; i < size_delta; i++) {
@@ -730,20 +730,15 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
             XlaHelpers::ScalarValue(broadcast_sizes[i], op.builder()));
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: " << broadcast_sizes[i] << "\n";
       } else {
         get_dim_ops.push_back(xla::GetDimensionSize(aux_op, i));
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: ? of index: " << i << "\n";
       }
     }
   }
 
   if (size_delta == 0) {
-    std::cout << "\t size_delta == 0\n";
     int sz = op_shape_dims.size() - aux_shape_dims.size();
     std::vector<int64_t> broadcast_sizes(shape_dims.begin(),
                                          shape_dims.begin() + sz);
@@ -753,14 +748,10 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
             XlaHelpers::ScalarValue(broadcast_sizes[i], op.builder()));
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: " << broadcast_sizes[i] << "\n";
       } else {
         get_dim_ops.push_back(xla::GetDimensionSize(op, i));
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: ? of index: " << i << "\n";
       }
     }
   }
@@ -789,8 +780,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
       get_dim_ops.push_back(ScalarValue(shape_dim, op.builder()));
       auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-      std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                << " for size: " << shape_dim << "\n";
     } else if (op_shape_dim == 1 || aux_op_shape_dim == 1) {
       if (op_shape_dim == 1) {
@@ -798,22 +787,16 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
             xla::GetDimensionSize(aux_op, aux_op_shape_index));
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: ? of index: " << aux_op_shape_index << "\n";
       } else {
         get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index));
         auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-        std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                  << " for size: ? of index: " << op_shape_index << "\n";
       }
     } else {
       get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index));
       auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
-      std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
-                << " for size: ? of index: " << op_shape_index << "\n";
     }
   }
 
@@ -829,7 +812,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
 
   return new_op;
 }
-#endif
 
 xla::XlaOp XlaHelpers::ImplicitBroadcast(xla::XlaOp op,
                                          const xla::Shape& op_shape,
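For intuition about the `kUnboundedSize` sentinel used throughout helpers.cpp: an unbounded dynamic dimension is encoded as `std::numeric_limits<int64_t>::min()`, and `IsUnboundedDynamic` is just a scan for that value. A self-contained sketch over a plain vector (the real code walks `xla::Shape::dimensions()`):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Same sentinel the patch uses in helpers.cpp and lowering_context.cpp.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

// Mirrors the std::any_of scan in XlaHelpers::IsUnboundedDynamic, but over a
// plain vector instead of an xla::Shape.
bool HasUnboundedDim(const std::vector<int64_t>& dims) {
  return std::any_of(dims.begin(), dims.end(),
                     [](int64_t size) { return size == kUnboundedSize; });
}

int main() {
  std::cout << HasUnboundedDim({2, 3, 4}) << "\n";               // prints 0
  std::cout << HasUnboundedDim({2, kUnboundedSize, 4}) << "\n";  // prints 1
}
```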
diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h
index 51dee64573b4..342e84eb9934 100644
--- a/torch_xla/csrc/helpers.h
+++ b/torch_xla/csrc/helpers.h
@@ -158,7 +158,6 @@ class XlaHelpers {
   static xla::XlaOp DynamicReshape(xla::XlaOp input,
                                    absl::Span<const int64_t> output_sizes);
 
-#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
   static bool IsUnboundedDynamic(const xla::Shape& shape);
 
   static xla::XlaOp DynamicUnboundedReshape(
@@ -169,9 +168,6 @@ class XlaHelpers {
       xla::XlaOp input, xla::XlaOp aux_input,
       absl::Span<const int64_t> output_sizes);
 
-  static void PrintXlaOp(xla::XlaOp op, const std::string& msg);
-#endif
-
   static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape);
 
   static bool SameStaticDimensions(const xla::Shape& shape1,
diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h
index f9d86696913e..e4c48a01b123 100644
--- a/torch_xla/csrc/ir.h
+++ b/torch_xla/csrc/ir.h
@@ -9,9 +9,9 @@
 #include
 #include
 #include
-#include
 #include
 #include
+#include <unordered_set>
 #include
 #include
@@ -139,12 +139,13 @@ class XlaNode : public torch::lazy::Node {
   std::string ToString() const override;
 
   void MarkDynamicDimension(uint32_t dim) {
-    dynamic_dims_.push_back(dim);
+    dynamic_dims_.insert(dim);
   }
-  const std::vector<uint32_t>& dynamic_dims() const { return dynamic_dims_; }
+
+  const std::unordered_set<uint32_t>& dynamic_dims() const { return dynamic_dims_; }
 
  protected:
-  std::vector<uint32_t> dynamic_dims_;
+  std::unordered_set<uint32_t> dynamic_dims_;
 
  private:
   xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;
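The `std::vector` to `std::unordered_set` switch in `XlaNode` makes `MarkDynamicDimension` idempotent: marking the same dimension twice no longer yields a duplicate entry for `LowerNode` to rewrite twice. A minimal illustration of the container behavior only, with no XLA types:

```cpp
#include <cstdint>
#include <iostream>
#include <unordered_set>
#include <vector>

int main() {
  // Old representation: repeated marks accumulate duplicates.
  std::vector<uint32_t> dims_vec;
  dims_vec.push_back(0);
  dims_vec.push_back(0);
  std::cout << "vector entries: " << dims_vec.size() << "\n";  // prints 2

  // New representation: repeated marks are deduplicated, so the
  // unbounded-size rewrite in LowerNode runs once per dimension.
  std::unordered_set<uint32_t> dims_set;
  dims_set.insert(0);
  dims_set.insert(0);
  std::cout << "set entries: " << dims_set.size() << "\n";  // prints 1
}
```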
diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp
index e980eaa26360..07fda8c425b0 100644
--- a/torch_xla/csrc/lowering_context.cpp
+++ b/torch_xla/csrc/lowering_context.cpp
@@ -93,18 +93,19 @@ LoweringContext::LoweringContext(
   }
 }
 
-// import xla::Shape.h to inlcude the following defintion.
+// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
 static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
+
 xla::XlaOp LoweringContext::GetParameter(
     const std::shared_ptr<torch::lazy::BackendData>& data,
-    const std::vector<uint32_t>& dynamic_dims) {
+    const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
   torch::lazy::BackendData::Handle handle = data->GetHandle();
   auto it = parameters_map_.find(handle);
   if (it == parameters_map_.end()) {
     xla::Shape shape =
         std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
             ->shape();
-    for (const int dim : dynamic_dims) {
+    for (const int dim : unbounded_dynamic_dims) {
       shape.set_dynamic_dimension(dim, true);
       shape.set_dimensions(dim, kUnboundedSize);
     }
@@ -113,6 +114,9 @@ xla::XlaOp LoweringContext::GetParameter(
     it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
             .first;
     parameters_.push_back(data);
+  } else {
+    XLA_CHECK(unbounded_dynamic_dims.empty())
+        << "The unbounded dynamic dims can only be set when Parameter is created.";
   }
   parameter_sequence_.push_back(it->second.index);
   return it->second.param;
@@ -177,19 +181,21 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {
     const XlaNode* casted = dynamic_cast<const XlaNode*>(node);
     result_ops = casted->Lower(this);
-    xla::internal::XlaBuilderFriend builder_friend;
-    auto* inst = builder_friend.GetInstruction(result_ops[0]);
-    auto* mutable_dynamic =
-        inst->mutable_shape()->mutable_is_dynamic_dimension();
-    if (mutable_dynamic->empty()) {
-      for (int i = 0; i < inst->dimensions_size(); i++) {
-        mutable_dynamic->Add(false);
+    if (!casted->dynamic_dims().empty()) {
+      xla::internal::XlaBuilderFriend builder_friend;
+      auto* inst = builder_friend.GetInstruction(result_ops[0]);
+      auto* mutable_dynamic =
+          inst->mutable_shape()->mutable_is_dynamic_dimension();
+      if (mutable_dynamic->empty()) {
+        for (int i = 0; i < inst->dimensions_size(); i++) {
+          mutable_dynamic->Add(false);
+        }
+      }
+      auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
+      for (const auto dim : casted->dynamic_dims()) {
+        mutable_dynamic->Set(dim, true);
+        mutable_dims->Set(dim, kUnboundedSize);
       }
-    }
-    auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
-    for (const auto dim : casted->dynamic_dims()) {
-      mutable_dynamic->Set(dim, true);
-      mutable_dims->Set(dim, kUnboundedSize);
     }
   } catch (const std::exception& ex) {
     ReportBuilderError(node, ex.what());
diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h
index 5213b8b0bf31..c09fb4cd9fb3 100644
--- a/torch_xla/csrc/lowering_context.h
+++ b/torch_xla/csrc/lowering_context.h
@@ -37,7 +37,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
   // returned. Otherwise a new one will be created, associated with the tensor
   // held in data.
   xla::XlaOp GetParameter(const std::shared_ptr<torch::lazy::BackendData>& data,
-                          const std::vector<uint32_t>& dynamic_dims = {});
+                          const std::unordered_set<uint32_t>& dynamic_dims = {});
 
   // Retrieves the vector holding all the tensors associated with the parameter
   // instructions which have been created.
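The new `else` branch in `GetParameter` pins down a contract that was previously implicit: unbounded dynamic dimensions are honored only when the parameter is first created; on a cache hit the caller must pass an empty set. A toy model of that contract, using hypothetical simplified types rather than the real `LoweringContext` API:

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <unordered_map>
#include <unordered_set>

struct Param {
  size_t index;
};

std::unordered_map<int64_t, Param> parameters_map;

Param GetParameter(int64_t handle,
                   const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
  auto it = parameters_map.find(handle);
  if (it == parameters_map.end()) {
    // First lookup: this is where the shape would get its dynamic dims.
    it = parameters_map.emplace(handle, Param{parameters_map.size()}).first;
  } else {
    // Cache hit: the shape is already fixed, so late dims are an error.
    assert(unbounded_dynamic_dims.empty() &&
           "unbounded dynamic dims can only be set at creation");
  }
  return it->second;
}

int main() {
  GetParameter(42, {0});  // creates the parameter, dim 0 unbounded
  GetParameter(42, {});   // cache hit with no dims: fine
  // GetParameter(42, {1}) would trip the assert, like the new XLA_CHECK.
  std::cout << "parameters: " << parameters_map.size() << "\n";  // prints 1
}
```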
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp
index 17d981bc8d3a..d5f243cf5dad 100644
--- a/torch_xla/csrc/reduction.cpp
+++ b/torch_xla/csrc/reduction.cpp
@@ -31,6 +31,9 @@ struct SummationResult {
   xla::XlaOp result;
 };
 
+static const bool experimental_unbounded_dynamism =
+    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
+
 ReductionInfo GetReductionInfo(xla::XlaOp input, const xla::Shape& shape,
                                absl::Span<const int64_t> dimensions,
                                bool keep_reduced_dimensions) {
@@ -81,12 +84,15 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count,
   xla::XlaOp scale = xla::Select(xla::Ne(count, zero),
                                  one / xla::ConvertElementType(count, type),
                                  xla::NanValue(input.builder(), type));
-#if !EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-  return input * scale;
-#else
-  auto promoted = XlaHelpers::Promote(input, scale);
-  return promoted.first * promoted.second;
-#endif
+
+  if (experimental_unbounded_dynamism) {
+    // XLA Multiply doesn't do implicit broadcasting for unbounded dynamism yet.
+    // TODO(lsy323): Remove this branch once the support is added in XLA.
+    auto promoted = XlaHelpers::Promote(input, scale);
+    return promoted.first * promoted.second;
+  } else {
+    return input * scale;
+  }
 }
 
 xla::XlaOp AverageValue(xla::XlaOp input, xla::XlaOp reduced) {
@@ -114,13 +120,14 @@ SummationResult CreateSummation(xla::XlaOp input,
         result.result, result.rinfo.element_count.size, shape.element_type());
   }
   if (keep_reduced_dimensions) {
-#if !EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
-    result.result =
-        XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions);
-#else
-    result.result = XlaHelpers::DynamicUnboundedReshape(
+    if (experimental_unbounded_dynamism) {
+      // TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is added.
+      result.result = XlaHelpers::DynamicUnboundedReshape(
         result.result, input, result.rinfo.new_dimensions);
-#endif
+    } else {
+      result.result =
+          XlaHelpers::DynamicReshape(result.result, result.rinfo.new_dimensions);
+    }
   }
   return result;
 }
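Finally, the recurring `XlaHelpers::Promote` fallback (`BuildRelu`, `GetScaleValue`): because ops like `xla::Max` and multiply cannot yet infer an implicit broadcast when a dimension is unbounded, both operands are first promoted to one common shape and then combined element-wise. A toy analogue with flat vectors standing in for tensors, illustrative only:

```cpp
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>

using Tensor = std::vector<double>;  // rank-1 stand-in for an XLA tensor

// Toy Promote: expand a single-element operand to the other operand's size,
// so the caller's element-wise op needs no implicit broadcasting at all.
std::pair<Tensor, Tensor> Promote(const Tensor& a, const Tensor& b) {
  if (b.size() == 1) return {a, Tensor(a.size(), b[0])};
  if (a.size() == 1) return {Tensor(b.size(), a[0]), b};
  return {a, b};
}

int main() {
  Tensor input = {-1.0, 2.0, -3.0};
  Tensor zero = {0.0};  // the scalar operand in BuildRelu
  auto promoted = Promote(input, zero);
  for (size_t i = 0; i < promoted.first.size(); ++i) {
    // Same-shaped operands, the property the unbounded-dynamism branch needs.
    std::cout << std::max(promoted.first[i], promoted.second[i]) << ' ';
  }
  std::cout << '\n';  // prints: 0 2 0
}
```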