diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 61f634904d3f..995e43078b69 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -329,7 +329,8 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input, bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) { XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) - << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; + << "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded " + "dynamism workload."; const absl::Span dims = shape.dimensions(); return std::any_of(dims.begin(), dims.end(), [](int64_t size) { return size == kUnboundedSize; }); @@ -339,7 +340,8 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape( xla::XlaOp input, xla::XlaOp aux_input, absl::Span output_sizes) { XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled()) - << "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on."; + << "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded " + "dynamism workload."; const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input); XLA_CHECK(output_sizes.size() == aux_input_shape.rank()) << "XlaHelpers::DynamicUnboundedReshape constrainled failed!"; diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 2c0acd9927e1..82b746ab1813 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -174,10 +174,8 @@ xla::Shape XlaNode::GetOpShape( std::string XlaNode::ToString() const { std::stringstream ss; ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_; - ss << ", dynamic_dims: "; - for (const auto dim : dynamic_dims_) { - ss << dim; - } + ss << ", dynamic_dims: (" << absl::StrJoin(unbounded_dynamic_dims_, ", ") + << ')'; return ss.str(); } diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index 8c852573730c..d0619ef5c987 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -138,14 +138,16 @@ class XlaNode : public torch::lazy::Node { std::string ToString() const override; - void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.insert(dim); } + void MarkDynamicDimension(uint32_t dim) { + unbounded_dynamic_dims_.insert(dim); + } const std::unordered_set& dynamic_dims() const { - return dynamic_dims_; + return unbounded_dynamic_dims_; } protected: - std::unordered_set dynamic_dims_; + std::unordered_set unbounded_dynamic_dims_; private: xla::Shape GetOpShape(const std::function& shape_fn) const; diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index 54c455e3b663..e07fe3c4e769 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const { } XlaOpVector DeviceData::Lower(LoweringContext* loctx) const { - return ReturnOp(loctx->GetParameter(data_, dynamic_dims_), loctx); + return ReturnOp(loctx->GetParameter(data_, unbounded_dynamic_dims_), loctx); } DeviceData* DeviceData::Cast(const torch::lazy::Node* node) {