From 54599508db565a30dcaa1d27dd089d0fb4b897b4 Mon Sep 17 00:00:00 2001 From: Siyuan Liu Date: Tue, 14 Nov 2023 22:43:46 +0000 Subject: [PATCH] Get unbounded dynamism flag from function --- torch_xla/csrc/elementwise.cpp | 5 +---- torch_xla/csrc/helpers.cpp | 12 ++++-------- torch_xla/csrc/helpers.h | 6 ++++++ torch_xla/csrc/reduction.cpp | 7 ++----- 4 files changed, 13 insertions(+), 17 deletions(-) diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index 027b80a61707..62a906d98eac 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -15,9 +15,6 @@ namespace torch_xla { namespace { -static const bool experimental_unbounded_dynamism = - runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); - xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val, const at::Scalar& max_val) { const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input); @@ -72,7 +69,7 @@ xla::XlaOp BuildRelu(xla::XlaOp input) { const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input); xla::XlaOp scalar = XlaHelpers::ScalarValue( 0, input_shape.element_type(), input.builder()); - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { // xla::Max doesn't do implicit broadcasting for unbounded dynamism now. // TODO(lsy323): Remove this branch once the support is added in XLA. auto promoted = XlaHelpers::Promote(input, scalar); diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index eeec940b8b9f..61f634904d3f 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -9,7 +9,6 @@ #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/dtype.h" #include "torch_xla/csrc/runtime/debug_macros.h" -#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/tf_logging.h" #include "torch_xla/csrc/runtime/util.h" #include "torch_xla/csrc/shape_helper.h" @@ -21,9 +20,6 @@ namespace torch_xla { namespace { -static const bool experimental_unbounded_dynamism = - runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); - // TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA. static constexpr int64_t kUnboundedSize = std::numeric_limits::min(); @@ -69,7 +65,7 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input, std::vector bcast_sizes = SizesOfXlaOp(input); for (size_t i = 0; i < dimensions.size(); ++i) { bcast_sizes.at(dimensions[i]) = sizes[i]; - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { XLA_CHECK(sizes[i] != kUnboundedSize); } } @@ -332,7 +328,7 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input, } bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { - XLA_CHECK(experimental_unbounded_dynamism) + XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; const absl::Span dims = shape.dimensions(); return std::any_of(dims.begin(), dims.end(), @@ -342,7 +338,7 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { xla::XlaOp XlaHelpers::DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes) { - XLA_CHECK(experimental_unbounded_dynamism) + XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); XLA_CHECK(output_sizes.size() == aux_input_shape.rank()) @@ -545,7 +541,7 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1, runtime::util::ToVector(shape1.dimensions()), runtime::util::ToVector(shape2.dimensions()))); } - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) && !XlaHelpers::IsUnboundedDynamic(shape2)) << "Unreachable for unbounded dynamic code\n"; diff --git a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index fdbdd4287d04..66c01588b57f 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -13,6 +13,7 @@ #include "absl/types/optional.h" #include "absl/types/span.h" #include "torch_xla/csrc/runtime/debug_macros.h" +#include "torch_xla/csrc/runtime/sys_util.h" #include "torch_xla/csrc/runtime/util.h" #include "tsl/platform/bfloat16.h" #include "xla/client/xla_builder.h" @@ -160,6 +161,11 @@ class XlaHelpers { static bool IsUnboundedDynamic(const xla::Shape& shape); + static bool IsUnboundedDynamismEnabled() { + return runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", + false); + } + static xla::XlaOp DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes); diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index 336123a542e3..aff467434103 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -31,9 +31,6 @@ struct SummationResult { xla::XlaOp result; }; -static const bool experimental_unbounded_dynamism = - runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false); - ReductionInfo GetReductionInfo(xla::XlaOp input, const xla::Shape& shape, absl::Span dimensions, bool keep_reduced_dimensions) { @@ -85,7 +82,7 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count, one / xla::ConvertElementType(count, type), xla::NanValue(input.builder(), type)); - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { // XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now. // TODO(lsy323): Remove this branch once the support is added in XLA. auto promoted = XlaHelpers::Promote(input, scale); @@ -120,7 +117,7 @@ SummationResult CreateSummation(xla::XlaOp input, result.result, result.rinfo.element_count.size, shape.element_type()); } if (keep_reduced_dimensions) { - if (experimental_unbounded_dynamism) { + if (XlaHelpers::IsUnboundedDynamismEnabled()) { // TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is // added. result.result = XlaHelpers::DynamicUnboundedReshape(