Skip to content

Commit

Permalink
Get unbounded dynamism flag from function
Browse files Browse the repository at this point in the history
  • Loading branch information
lsy323 committed Nov 14, 2023
1 parent 73ec1c5 commit 5459950
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 17 deletions.
5 changes: 1 addition & 4 deletions torch_xla/csrc/elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@
namespace torch_xla {
namespace {

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val,
const at::Scalar& max_val) {
const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
Expand Down Expand Up @@ -72,7 +69,7 @@ xla::XlaOp BuildRelu(xla::XlaOp input) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
xla::XlaOp scalar = XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder());
if (experimental_unbounded_dynamism) {
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// xla::Max doesn't do implicit broadcasting for unbounded dynamism now.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scalar);
Expand Down
12 changes: 4 additions & 8 deletions torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/shape_helper.h"
Expand All @@ -21,9 +20,6 @@
namespace torch_xla {
namespace {

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

Expand Down Expand Up @@ -69,7 +65,7 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
for (size_t i = 0; i < dimensions.size(); ++i) {
bcast_sizes.at(dimensions[i]) = sizes[i];
if (experimental_unbounded_dynamism) {
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
XLA_CHECK(sizes[i] != kUnboundedSize);
}
}
Expand Down Expand Up @@ -332,7 +328,7 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
}

bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
XLA_CHECK(experimental_unbounded_dynamism)
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
const absl::Span<const int64_t> dims = shape.dimensions();
return std::any_of(dims.begin(), dims.end(),
Expand All @@ -342,7 +338,7 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes) {
XLA_CHECK(experimental_unbounded_dynamism)
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
Expand Down Expand Up @@ -545,7 +541,7 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
runtime::util::ToVector<int64_t>(shape1.dimensions()),
runtime::util::ToVector<int64_t>(shape2.dimensions())));
}
if (experimental_unbounded_dynamism) {
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
!XlaHelpers::IsUnboundedDynamic(shape2))
<< "Unreachable for unbounded dynamic code\n";
Expand Down
6 changes: 6 additions & 0 deletions torch_xla/csrc/helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/util.h"
#include "tsl/platform/bfloat16.h"
#include "xla/client/xla_builder.h"
Expand Down Expand Up @@ -160,6 +161,11 @@ class XlaHelpers {

static bool IsUnboundedDynamic(const xla::Shape& shape);

static bool IsUnboundedDynamismEnabled() {
return runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM",
false);
}

static xla::XlaOp DynamicUnboundedReshape(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes);
Expand Down
7 changes: 2 additions & 5 deletions torch_xla/csrc/reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ struct SummationResult {
xla::XlaOp result;
};

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

ReductionInfo GetReductionInfo(xla::XlaOp input, const xla::Shape& shape,
absl::Span<const int64_t> dimensions,
bool keep_reduced_dimensions) {
Expand Down Expand Up @@ -85,7 +82,7 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count,
one / xla::ConvertElementType(count, type),
xla::NanValue(input.builder(), type));

if (experimental_unbounded_dynamism) {
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scale);
Expand Down Expand Up @@ -120,7 +117,7 @@ SummationResult CreateSummation(xla::XlaOp input,
result.result, result.rinfo.element_count.size, shape.element_type());
}
if (keep_reduced_dimensions) {
if (experimental_unbounded_dynamism) {
if (XlaHelpers::IsUnboundedDynamismEnabled()) {
// TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is
// added.
result.result = XlaHelpers::DynamicUnboundedReshape(
Expand Down

0 comments on commit 5459950

Please sign in to comment.