Get unbounded dynamism flag from function
lsy323 committed Nov 14, 2023
1 parent 73ec1c5 commit f2e65aa
Showing 5 changed files with 36 additions and 40 deletions.
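In short, the change removes the file-local copy of the flag that elementwise.cpp, helpers.cpp, and reduction.cpp each declared,

    static const bool experimental_unbounded_dynamism =
        runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

and routes every call site through a single inline helper, XlaHelpers::IsUnboundedDynamismEnabled(), declared once in helpers.h. One behavioral note (an observation, assuming GetEnvBool queries the environment on each call): the old static const captured the value once at process start, whereas the helper re-reads it on every call, so the flag can be toggled at runtime. The WORKSPACE change separately swaps the pinned OpenXLA http_archive for a developer-local local_repository checkout.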
46 changes: 23 additions & 23 deletions WORKSPACE
@@ -30,25 +30,25 @@ python_configure(
 # b) get the sha256 hash of the commit by running:
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update the sha256 with the result.
-http_archive(
-    name = "xla",
-    patch_args = [
-        "-l",
-        "-p1",
-    ],
-    patch_tool = "patch",
-    patches = [
-        "//openxla_patches:cache_urls.diff",
-        "//openxla_patches:constexpr_return.diff",
-        "//openxla_patches:gpu_race_condition.diff",
-        "//openxla_patches:f16_abi_clang.diff",
-        "//openxla_patches:gpu_topk_rewriter.diff",
-    ],
-    strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478",
-    urls = [
-        "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz",
-    ],
-)
+# http_archive(
+#     name = "xla",
+#     patch_args = [
+#         "-l",
+#         "-p1",
+#     ],
+#     patch_tool = "patch",
+#     patches = [
+#         "//openxla_patches:cache_urls.diff",
+#         "//openxla_patches:constexpr_return.diff",
+#         "//openxla_patches:gpu_race_condition.diff",
+#         "//openxla_patches:f16_abi_clang.diff",
+#         "//openxla_patches:gpu_topk_rewriter.diff",
+#     ],
+#     strip_prefix = "xla-4f8381651977dff16b1d86bb4b198eb733c5f478",
+#     urls = [
+#         "https://github.com/openxla/xla/archive/4f8381651977dff16b1d86bb4b198eb733c5f478.tar.gz",
+#     ],
+# )
 
 # For development, one often wants to make changes to the OpenXLA repository as well
 # as the PyTorch/XLA repository. You can override the pinned repository above with a
@@ -58,10 +58,10 @@ http_archive(
 # bazel --override_repository=xla=/path/to/openxla
 # or
 # b) by commenting out the http_archive above and uncommenting the following:
-# local_repository(
-#    name = "xla",
-#    path = "/path/to/openxla",
-# )
+local_repository(
+    name = "xla",
+    path = "/home/lsiyuan/work/xla",
+)
 
 # Initialize OpenXLA's external dependencies.
 load("@xla//:workspace4.bzl", "xla_workspace4")
5 changes: 1 addition & 4 deletions torch_xla/csrc/elementwise.cpp
@@ -15,9 +15,6 @@
 namespace torch_xla {
 namespace {
 
-static const bool experimental_unbounded_dynamism =
-    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
-
 xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val,
                    const at::Scalar& max_val) {
   const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
@@ -72,7 +69,7 @@ xla::XlaOp BuildRelu(xla::XlaOp input) {
   const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
   xla::XlaOp scalar = XlaHelpers::ScalarValue<float>(
       0, input_shape.element_type(), input.builder());
-  if (experimental_unbounded_dynamism) {
+  if (XlaHelpers::IsUnboundedDynamismEnabled()) {
     // xla::Max doesn't do implicit broadcasting for unbounded dynamism now.
     // TODO(lsy323): Remove this branch once the support is added in XLA.
     auto promoted = XlaHelpers::Promote(input, scalar);
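For context, the elided tail of this branch plausibly completes as below. This is a sketch reconstructed from the visible lines, not the verbatim file; it assumes XlaHelpers::Promote returns the two operands already converted and broadcast to a common shape and element type, as a std::pair:

    if (XlaHelpers::IsUnboundedDynamismEnabled()) {
      // xla::Max doesn't do implicit broadcasting for unbounded dynamism now,
      // so promote the operands to a common shape/type explicitly first.
      auto promoted = XlaHelpers::Promote(input, scalar);
      return xla::Max(promoted.first, promoted.second);
    }
    return xla::Max(input, scalar);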
12 changes: 4 additions & 8 deletions torch_xla/csrc/helpers.cpp
@@ -9,7 +9,6 @@
#include "torch_xla/csrc/convert_ops.h"
#include "torch_xla/csrc/dtype.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/tf_logging.h"
#include "torch_xla/csrc/runtime/util.h"
#include "torch_xla/csrc/shape_helper.h"
@@ -21,9 +20,6 @@
 namespace torch_xla {
 namespace {
 
-static const bool experimental_unbounded_dynamism =
-    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
-
 // TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
 static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
 
@@ -69,7 +65,7 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
   std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
   for (size_t i = 0; i < dimensions.size(); ++i) {
     bcast_sizes.at(dimensions[i]) = sizes[i];
-    if (experimental_unbounded_dynamism) {
+    if (XlaHelpers::IsUnboundedDynamismEnabled()) {
       XLA_CHECK(sizes[i] != kUnboundedSize);
     }
   }
@@ -332,7 +328,7 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
 }
 
 bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
-  XLA_CHECK(experimental_unbounded_dynamism)
+  XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
       << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
   const absl::Span<const int64_t> dims = shape.dimensions();
   return std::any_of(dims.begin(), dims.end(),
@@ -342,7 +338,7 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
 xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
     xla::XlaOp input, xla::XlaOp aux_input,
     absl::Span<const int64_t> output_sizes) {
-  XLA_CHECK(experimental_unbounded_dynamism)
+  XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
       << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
   const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
   XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
@@ -545,7 +541,7 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
         runtime::util::ToVector<int64_t>(shape1.dimensions()),
         runtime::util::ToVector<int64_t>(shape2.dimensions())));
   }
-  if (experimental_unbounded_dynamism) {
+  if (XlaHelpers::IsUnboundedDynamismEnabled()) {
     XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
               !XlaHelpers::IsUnboundedDynamic(shape2))
         << "Unreachable for unbounded dynamic code\n";
6 changes: 6 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -13,6 +13,7 @@
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/util.h"
#include "tsl/platform/bfloat16.h"
#include "xla/client/xla_builder.h"
@@ -160,6 +161,11 @@ class XlaHelpers {

   static bool IsUnboundedDynamic(const xla::Shape& shape);
 
+  static bool IsUnboundedDynamismEnabled() {
+    return runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM",
+                                         false);
+  }
+
   static xla::XlaOp DynamicUnboundedReshape(
       xla::XlaOp input, xla::XlaOp aux_input,
       absl::Span<const int64_t> output_sizes);
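With the helper in place, every guarded path in this commit follows the same pattern (a usage sketch based on the call sites above, not new API surface):

    if (XlaHelpers::IsUnboundedDynamismEnabled()) {
      // unbounded-dynamism-specific lowering, e.g. explicit broadcasting
    } else {
      // default lowering
    }

The flag is driven entirely by the environment: set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 before starting the process to enable these paths; it defaults to false.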
7 changes: 2 additions & 5 deletions torch_xla/csrc/reduction.cpp
@@ -31,9 +31,6 @@ struct SummationResult {
   xla::XlaOp result;
 };
 
-static const bool experimental_unbounded_dynamism =
-    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);
-
 ReductionInfo GetReductionInfo(xla::XlaOp input, const xla::Shape& shape,
                                absl::Span<const int64_t> dimensions,
                                bool keep_reduced_dimensions) {
@@ -85,7 +82,7 @@ xla::XlaOp GetScaleValue(xla::XlaOp input, xla::XlaOp count,
                           one / xla::ConvertElementType(count, type),
                           xla::NanValue(input.builder(), type));
 
-  if (experimental_unbounded_dynamism) {
+  if (XlaHelpers::IsUnboundedDynamismEnabled()) {
     // XLA Multiply doesn't do implicit broadcasting for unbounded dynamism now.
     // TODO(lsy323): Remove this branch once the support is added in XLA.
     auto promoted = XlaHelpers::Promote(input, scale);
@@ -120,7 +117,7 @@ SummationResult CreateSummation(xla::XlaOp input,
         result.result, result.rinfo.element_count.size, shape.element_type());
   }
   if (keep_reduced_dimensions) {
-    if (experimental_unbounded_dynamism) {
+    if (XlaHelpers::IsUnboundedDynamismEnabled()) {
       // TODO(lsy323): Use XLA DynamicReshape once unbounded dynamism support is
       // added.
       result.result = XlaHelpers::DynamicUnboundedReshape(
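Given the DynamicUnboundedReshape signature declared in helpers.h (input, aux_input, output_sizes), the truncated call above plausibly completes along these lines; a reconstruction from context, not the verbatim file, with output_dimensions as a hypothetical name for the keep-dims shape:

    // Reshape the reduced result back to a keep-dims shape, with the original
    // input serving as the auxiliary source of dynamic dimension sizes.
    result.result = XlaHelpers::DynamicUnboundedReshape(
        result.result, input, output_dimensions);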
