Enable unbounded dynamism using env var, add more guards for unbounded dynamism code path
Siyuan Liu committed Nov 13, 2023
1 parent 495f844 commit 23ee18c
Showing 8 changed files with 80 additions and 89 deletions.
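
Editor's note: this commit replaces the build-time -DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM define (dropped from setup.py below) with a runtime check of the EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM environment variable, read via runtime::sys_util::GetEnvBool in the C++ changes. A minimal standalone sketch of that idiom, using std::getenv as a stand-in for the repo's helper (the exact truthiness rules of GetEnvBool are an assumption here):

#include <cstdlib>
#include <cstring>
#include <iostream>

// Approximation of runtime::sys_util::GetEnvBool: treat an unset or empty
// variable as the default, and "0"/"false" as false, anything else as true.
static bool GetEnvBoolSketch(const char* name, bool defval) {
  const char* value = std::getenv(name);
  if (value == nullptr || *value == '\0') return defval;
  return std::strcmp(value, "0") != 0 && std::strcmp(value, "false") != 0;
}

static const bool experimental_unbounded_dynamism =
    GetEnvBoolSketch("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

int main() {
  if (experimental_unbounded_dynamism) {
    std::cout << "unbounded dynamism code paths enabled\n";
  } else {
    std::cout << "unbounded dynamism code paths disabled (default)\n";
  }
}
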
5 changes: 0 additions & 5 deletions setup.py
@@ -212,13 +212,8 @@ def run(self):
extra_compile_args = []
cxx_abi = os.getenv(
'CXX_ABI', default='') or getattr(torch._C, '_GLIBCXX_USE_CXX11_ABI', None)
experimental_dynamism = os.getenv(
'EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM', default=None)
if cxx_abi is not None:
extra_compile_args.append(f'-D_GLIBCXX_USE_CXX11_ABI={int(cxx_abi)}')
if experimental_dynamism is not None:
extra_compile_args.append(
f'-DEXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM={experimental_dynamism}')


class BazelExtension(Extension):
20 changes: 12 additions & 8 deletions torch_xla/csrc/elementwise.cpp
@@ -5,6 +5,7 @@
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/random.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "torch_xla/csrc/tensor_util.h"
#include "torch_xla/csrc/xla_lower_util.h"
@@ -14,6 +15,9 @@
namespace torch_xla {
namespace {

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

xla::XlaOp Between(xla::XlaOp input, const at::Scalar& min_val,
const at::Scalar& max_val) {
const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(input);
@@ -66,16 +70,16 @@ xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output,

xla::XlaOp BuildRelu(xla::XlaOp input) {
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
#ifndef EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
return xla::Max(input, XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder()));
#else
xla::XlaOp scalar = XlaHelpers::ScalarValue<float>(
0, input_shape.element_type(), input.builder());
auto promoted = XlaHelpers::Promote(input, scalar);

return xla::Max(promoted.first, promoted.second);
#endif
if (experimental_unbounded_dynamism) {
// xla::Max doesn't do implicit broadcasting for unbounded dynamism yet.
// TODO(lsy323): Remove this branch once the support is added in XLA.
auto promoted = XlaHelpers::Promote(input, scalar);
return xla::Max(promoted.first, promoted.second);
} else {
return xla::Max(input, scalar);
}
}

xla::XlaOp BuildHardshrink(xla::XlaOp input, xla::XlaOp lambda) {
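
Editor's note: for context, a hedged sketch of the same guarded pattern applied to another elementwise lowering. BuildReluSixSketch is hypothetical and not part of this commit; it assumes xla::Min needs the same explicit XlaHelpers::Promote as xla::Max when unbounded operands are involved, and the include paths are assumptions matching the rest of this file.

#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/shape_helper.h"
#include "xla/client/xla_builder.h"

namespace torch_xla {
namespace {

static const bool experimental_unbounded_dynamism =
    runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

// Hypothetical ReLU6-style lowering following the BuildRelu pattern above.
xla::XlaOp BuildReluSixSketch(xla::XlaOp input) {
  const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
  xla::XlaOp zero = XlaHelpers::ScalarValue<float>(
      0, input_shape.element_type(), input.builder());
  xla::XlaOp six = XlaHelpers::ScalarValue<float>(
      6, input_shape.element_type(), input.builder());
  if (experimental_unbounded_dynamism) {
    // Promote explicitly: Max/Min don't implicitly broadcast unbounded dims.
    auto lo = XlaHelpers::Promote(input, zero);
    xla::XlaOp clamped = xla::Max(lo.first, lo.second);
    auto hi = XlaHelpers::Promote(clamped, six);
    return xla::Min(hi.first, hi.second);
  }
  return xla::Min(xla::Max(input, zero), six);
}

}  // namespace
}  // namespace torch_xla
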
62 changes: 22 additions & 40 deletions torch_xla/csrc/helpers.cpp
@@ -21,9 +21,12 @@
namespace torch_xla {
namespace {

#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM

static const bool experimental_unbounded_dynamism =
runtime::sys_util::GetEnvBool("EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM", false);

// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();
#endif

xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2,
xla::XlaOp result) {
@@ -67,9 +70,9 @@ xla::XlaOp XlaHelpers::BroadcastDimensions(xla::XlaOp input,
std::vector<int64_t> bcast_sizes = SizesOfXlaOp(input);
for (size_t i = 0; i < dimensions.size(); ++i) {
bcast_sizes.at(dimensions[i]) = sizes[i];
#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
XLA_CHECK(sizes[i] != kUnboundedSize);
#endif
if (experimental_unbounded_dynamism) {
XLA_CHECK(sizes[i] != kUnboundedSize);
}
}
return xla::BroadcastInDim(input, bcast_sizes,
GetAllDimensions(bcast_sizes.size()));
@@ -329,9 +332,9 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,
: xla::Reshape(input, shape.dimensions());
}

#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM

bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
XLA_CHECK(experimental_unbounded_dynamism)
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
const absl::Span<const int64_t> dims = shape.dimensions();
return std::any_of(dims.begin(), dims.end(),
[](int64_t size) { return size == kUnboundedSize; });
@@ -340,6 +343,8 @@ bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes) {
XLA_CHECK(experimental_unbounded_dynamism)
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
<< "XlaHelpers::DynamicUnboundedReshape constrainled failed!";
@@ -381,13 +386,17 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> aux_input_dimensions) {
XLA_CHECK(experimental_unbounded_dynamism)
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
const xla::Shape& input_shape = ShapeHelper::ShapeOfXlaOp(input);
const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
bool all_static = true;
std::vector<int64_t> output_dimensions;
std::vector<bool> output_dynamic;
for (auto dim : aux_input_dimensions) {
if (aux_input_shape.dimensions(dim) == kUnboundedSize) all_static = false;
if (aux_input_shape.dimensions(dim) == kUnboundedSize) {
all_static = false;
}
output_dimensions.push_back(aux_input_shape.dimensions(dim));
output_dynamic.push_back(aux_input_shape.is_dynamic_dimension(dim));
}
@@ -432,13 +441,6 @@ xla::XlaOp XlaHelpers::DynamicUnboundedBroadcast(
output_dynamic));
}

void XlaHelpers::PrintXlaOp(xla::XlaOp op, const std::string& msg) {
std::cout << "Handle: " << msg << ": " << op << "\n";
const xla::Shape& shape = ShapeHelper::ShapeOfXlaOp(op);
std::cout << xla::ShapeUtil::HumanString(shape);
}
#endif

bool XlaHelpers::SameStaticDimensions(const xla::Shape& shape1,
const xla::Shape& shape2) {
return shape1.is_static() && shape2.is_static() &&
@@ -602,11 +604,11 @@ xla::Shape XlaHelpers::GetPromotedBinaryOpShape(const xla::Shape& shape1,
runtime::util::ToVector<int64_t>(shape1.dimensions()),
runtime::util::ToVector<int64_t>(shape2.dimensions())));
}
#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
!XlaHelpers::IsUnboundedDynamic(shape2))
<< "Unreachable for unbounded dynamic code\n";
#endif
if (experimental_unbounded_dynamism) {
XLA_CHECK(!XlaHelpers::IsUnboundedDynamic(shape1) &&
!XlaHelpers::IsUnboundedDynamic(shape2))
<< "Unreachable for unbounded dynamic code\n";
}
return GetPromotedDynamicShape(shape1, shape2);
}

@@ -700,7 +702,6 @@ std::pair<xla::XlaOp, xla::XlaOp> XlaHelpers::PromoteSecond(xla::XlaOp op1,
return PromoteShapes(vops.first, vops.second);
}

#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
xla::XlaOp op, const xla::Shape& op_shape, xla::XlaOp aux_op,
const xla::Shape& shape) {
@@ -721,7 +722,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
std::vector<xla::XlaOp> reshaped_ops;

if (size_delta > 0) {
std::cout << "\t size_delta > 0\n";
std::vector<int64_t> broadcast_sizes(shape_dims.begin(),
shape_dims.begin() + size_delta);
for (int i = 0; i < size_delta; i++) {
@@ -730,20 +730,15 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
XlaHelpers::ScalarValue<int32_t>(broadcast_sizes[i], op.builder()));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
<< " for size: " << broadcast_sizes[i] << "\n";
} else {
get_dim_ops.push_back(xla::GetDimensionSize(aux_op, i));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
<< " for size: ? of index: " << i << "\n";
}
}
}

if (size_delta == 0) {
std::cout << "\t size_delta == 0\n";
int sz = op_shape_dims.size() - aux_shape_dims.size();
std::vector<int64_t> broadcast_sizes(shape_dims.begin(),
shape_dims.begin() + sz);
@@ -753,14 +748,10 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
XlaHelpers::ScalarValue<int32_t>(broadcast_sizes[i], op.builder()));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
<< " for size: " << broadcast_sizes[i] << "\n";
} else {
get_dim_ops.push_back(xla::GetDimensionSize(op, i));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
<< " for size: ? of index: " << i << "\n";
}
}
}
@@ -789,31 +780,23 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(
get_dim_ops.push_back(ScalarValue<int32_t>(shape_dim, op.builder()));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
<< " for size: " << shape_dim << "\n";

} else if (op_shape_dim == 1 || aux_op_shape_dim == 1) {
if (op_shape_dim == 1) {
get_dim_ops.push_back(
xla::GetDimensionSize(aux_op, aux_op_shape_index));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
<< " for size: ? of index: " << aux_op_shape_index << "\n";

} else {
get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
<< " for size: ? of index: " << op_shape_index << "\n";
}
} else {
get_dim_ops.push_back(xla::GetDimensionSize(op, op_shape_index));

auto s = ShapeHelper::ShapeOfXlaOp(get_dim_ops.back());
std::cout << "implicitB shape: " << xla::ShapeUtil::HumanString(s)
<< " for size: ? of index: " << op_shape_index << "\n";
}
}

@@ -829,7 +812,6 @@ xla::XlaOp XlaHelpers::ImplicitBroadcastWithUnboundedDynamicShapes(

return new_op;
}
#endif

xla::XlaOp XlaHelpers::ImplicitBroadcast(xla::XlaOp op,
const xla::Shape& op_shape,
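
Editor's note: the PrintXlaOp helper and the ad-hoc std::cout tracing in ImplicitBroadcastWithUnboundedDynamicShapes are removed from the library. If similar inspection is needed during development, a local, hypothetical sketch of the removed helper could live in a single translation unit rather than the public header (include paths are assumptions):

#include <iostream>
#include <string>

#include "torch_xla/csrc/shape_helper.h"
#include "xla/shape_util.h"

namespace {

// Local debugging sketch mirroring the removed PrintXlaOp: prints the op
// handle and a human-readable rendering of its shape.
void DebugPrintXlaOp(xla::XlaOp op, const std::string& msg) {
  std::cout << "Handle: " << msg << ": " << op << "\n";
  const xla::Shape& shape = torch_xla::ShapeHelper::ShapeOfXlaOp(op);
  std::cout << xla::ShapeUtil::HumanString(shape) << "\n";
}

}  // namespace
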
4 changes: 0 additions & 4 deletions torch_xla/csrc/helpers.h
@@ -158,7 +158,6 @@ class XlaHelpers {
static xla::XlaOp DynamicReshape(xla::XlaOp input,
absl::Span<const int64_t> output_sizes);

#if EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM
static bool IsUnboundedDynamic(const xla::Shape& shape);

static xla::XlaOp DynamicUnboundedReshape(
@@ -169,9 +168,6 @@
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes);

static void PrintXlaOp(xla::XlaOp op, const std::string& msg);
#endif

static xla::XlaOp DynamicReshapeAs(xla::XlaOp input, const xla::Shape& shape);

static bool SameStaticDimensions(const xla::Shape& shape1,
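
Editor's note: with the #if guards removed, the unbounded-dynamism declarations are always visible; the implementations instead XLA_CHECK the runtime flag. A hedged caller-side sketch mirroring the DynamicReshapeAs pattern in helpers.cpp (ReshapeMaybeUnboundedSketch is hypothetical, and EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM is assumed to be enabled):

#include "torch_xla/csrc/helpers.h"
#include "xla/client/xla_builder.h"

namespace torch_xla {
namespace {

// Hypothetical caller-side sketch; assumes the env var is set, since
// IsUnboundedDynamic and DynamicUnboundedReshape check it.
xla::XlaOp ReshapeMaybeUnboundedSketch(xla::XlaOp input, xla::XlaOp aux_input,
                                       const xla::Shape& target) {
  if (XlaHelpers::IsUnboundedDynamic(target)) {
    // Runtime sizes come from aux_input; unbounded dims keep their sentinel.
    return XlaHelpers::DynamicUnboundedReshape(input, aux_input,
                                               target.dimensions());
  }
  return xla::Reshape(input, target.dimensions());
}

}  // namespace
}  // namespace torch_xla
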
9 changes: 5 additions & 4 deletions torch_xla/csrc/ir.h
@@ -9,9 +9,9 @@
#include <functional>
#include <iostream>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

@@ -139,12 +139,13 @@ class XlaNode : public torch::lazy::Node {
std::string ToString() const override;

void MarkDynamicDimension(uint32_t dim) {
dynamic_dims_.push_back(dim);
dynamic_dims_.insert(dim);
}
const std::vector<uint32_t>& dynamic_dims() const { return dynamic_dims_; }

const std::unordered_set<uint32_t>& dynamic_dims() const { return dynamic_dims_; }

protected:
std::vector<uint32_t> dynamic_dims_;
std::unordered_set<uint32_t> dynamic_dims_;

private:
xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;
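
Editor's note: a small standalone sketch of the effect of switching dynamic_dims_ from std::vector to std::unordered_set, which is that marking the same dimension twice no longer records a duplicate entry.

#include <cstdint>
#include <iostream>
#include <unordered_set>

int main() {
  std::unordered_set<uint32_t> dynamic_dims;
  // Marking dim 0 twice (e.g. from two call sites) stores it only once.
  dynamic_dims.insert(0);
  dynamic_dims.insert(0);
  dynamic_dims.insert(2);
  std::cout << "marked dims: " << dynamic_dims.size() << "\n";  // prints 2
}
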
36 changes: 21 additions & 15 deletions torch_xla/csrc/lowering_context.cpp
@@ -93,18 +93,19 @@ LoweringContext::LoweringContext(
}
}

// Import xla::Shape.h to include the following definition.
// TODO(lsy323): Get reserved number for unbounded dim after it's added in XLA.
static constexpr int64_t kUnboundedSize = std::numeric_limits<int64_t>::min();

xla::XlaOp LoweringContext::GetParameter(
const std::shared_ptr<torch::lazy::BackendData>& data,
const std::vector<uint32_t>& dynamic_dims) {
const std::unordered_set<uint32_t>& unbounded_dynamic_dims) {
torch::lazy::BackendData::Handle handle = data->GetHandle();
auto it = parameters_map_.find(handle);
if (it == parameters_map_.end()) {
xla::Shape shape =
std::dynamic_pointer_cast<runtime::ComputationClient::Data>(data)
->shape();
for (const int dim : dynamic_dims) {
for (const int dim : unbounded_dynamic_dims) {
shape.set_dynamic_dimension(dim, true);
shape.set_dimensions(dim, kUnboundedSize);
}
@@ -113,6 +114,9 @@ xla::XlaOp LoweringContext::GetParameter(
it = parameters_map_.emplace(handle, Parameter{param, parameters_.size()})
.first;
parameters_.push_back(data);
} else {
XLA_CHECK(unbounded_dynamic_dims.empty())
<< "The unbounded dynamic dims can only be set when Parameter is created.";
}
parameter_sequence_.push_back(it->second.index);
return it->second.param;
@@ -177,19 +181,21 @@ XlaOpVector LoweringContext::LowerNode(const torch::lazy::Node* node) {

const XlaNode* casted = dynamic_cast<const XlaNode*>(node);
result_ops = casted->Lower(this);
xla::internal::XlaBuilderFriend builder_friend;
auto* inst = builder_friend.GetInstruction(result_ops[0]);
auto* mutable_dynamic =
inst->mutable_shape()->mutable_is_dynamic_dimension();
if (mutable_dynamic->empty()) {
for (int i = 0; i < inst->dimensions_size(); i++) {
mutable_dynamic->Add(false);
if (!casted->dynamic_dims().empty()) {
xla::internal::XlaBuilderFriend builder_friend;
auto* inst = builder_friend.GetInstruction(result_ops[0]);
auto* mutable_dynamic =
inst->mutable_shape()->mutable_is_dynamic_dimension();
if (mutable_dynamic->empty()) {
for (int i = 0; i < inst->dimensions_size(); i++) {
mutable_dynamic->Add(false);
}
}
auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
for (const auto dim : casted->dynamic_dims()) {
mutable_dynamic->Set(dim, true);
mutable_dims->Set(dim, kUnboundedSize);
}
}
auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
for (const auto dim : casted->dynamic_dims()) {
mutable_dynamic->Set(dim, true);
mutable_dims->Set(dim, kUnboundedSize);
}
} catch (const std::exception& ex) {
ReportBuilderError(node, ex.what());
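
Editor's note: LowerNode now rewrites the lowered instruction's shape only when the node actually has dynamic dims. A hedged sketch of that guarded block factored into a helper; the function and constant names are hypothetical, while XlaBuilderFriend, the proto mutators, and the sentinel value are used exactly as in the diff above (include paths are assumptions).

#include <cstdint>
#include <limits>
#include <unordered_set>

#include "xla/client/xla_builder.h"

namespace {

// Same sentinel as defined in lowering_context.cpp above.
constexpr int64_t kUnboundedSizeSketch = std::numeric_limits<int64_t>::min();

// Hypothetical refactoring of the guarded block in LowerNode: mark the given
// dims of the lowered instruction's shape as unbounded dynamic.
void ApplyUnboundedDynamicDimsSketch(
    xla::XlaOp op, const std::unordered_set<uint32_t>& dynamic_dims) {
  if (dynamic_dims.empty()) return;  // nothing to rewrite for static nodes
  xla::internal::XlaBuilderFriend builder_friend;
  auto* inst = builder_friend.GetInstruction(op);
  auto* mutable_dynamic = inst->mutable_shape()->mutable_is_dynamic_dimension();
  if (mutable_dynamic->empty()) {
    for (int i = 0; i < inst->dimensions_size(); i++) {
      mutable_dynamic->Add(false);
    }
  }
  auto* mutable_dims = inst->mutable_shape()->mutable_dimensions();
  for (const auto dim : dynamic_dims) {
    mutable_dynamic->Set(dim, true);
    mutable_dims->Set(dim, kUnboundedSizeSketch);
  }
}

}  // namespace
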
2 changes: 1 addition & 1 deletion torch_xla/csrc/lowering_context.h
@@ -37,7 +37,7 @@ class LoweringContext : public torch::lazy::LoweringContext {
// returned. Otherwise a new one will be created, associated with the tensor
// held in data.
xla::XlaOp GetParameter(const std::shared_ptr<torch::lazy::BackendData>& data,
const std::vector<uint32_t>& dynamic_dims = {});
const std::unordered_set<uint32_t>& dynamic_dims = {});

// Retrieves the vector holding all the tensors associated with the parameter
// instructions which have been created.
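
Editor's note: a hedged usage sketch of the updated GetParameter signature (the function name and arguments are placeholders): unbounded dims may only be supplied on the call that creates the parameter, and passing a non-empty set for an existing parameter trips the XLA_CHECK added in lowering_context.cpp.

#include <memory>
#include <unordered_set>

#include "torch_xla/csrc/lowering_context.h"

// Hypothetical usage; `ctx` and `data` are created elsewhere.
xla::XlaOp LowerUnboundedParameterSketch(
    torch_xla::LoweringContext& ctx,
    const std::shared_ptr<torch::lazy::BackendData>& data) {
  std::unordered_set<uint32_t> unbounded_dims = {0};  // e.g. unbounded batch dim
  // The first call creates the parameter and records the unbounded dims.
  xla::XlaOp param = ctx.GetParameter(data, unbounded_dims);
  // Later lookups must pass an empty set; a non-empty set for an existing
  // parameter fails the new check.
  xla::XlaOp same_param = ctx.GetParameter(data);
  (void)same_param;
  return param;
}
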

0 comments on commit 23ee18c
