Skip to content

Commit

Permalink
rename unbounded dynamism private number, and str serializaiton
Browse files Browse the repository at this point in the history
  • Loading branch information
Siyuan Liu authored and lsy323 committed Nov 15, 2023
1 parent 6b8072d commit f8927e6
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
6 changes: 4 additions & 2 deletions torch_xla/csrc/helpers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ xla::XlaOp XlaHelpers::DynamicReshapeAs(xla::XlaOp input,

bool XlaHelpers::IsUnboundedDynamic(const xla::Shape& shape) {
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
<< "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded "
"dynamism workload.";
const absl::Span<const int64_t> dims = shape.dimensions();
return std::any_of(dims.begin(), dims.end(),
[](int64_t size) { return size == kUnboundedSize; });
Expand All @@ -339,7 +340,8 @@ xla::XlaOp XlaHelpers::DynamicUnboundedReshape(
xla::XlaOp input, xla::XlaOp aux_input,
absl::Span<const int64_t> output_sizes) {
XLA_CHECK(XlaHelpers::IsUnboundedDynamismEnabled())
<< "EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM needs to be turned on.";
<< "set EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM=1 to run any unbounded "
"dynamism workload.";
const xla::Shape& aux_input_shape = ShapeHelper::ShapeOfXlaOp(aux_input);
XLA_CHECK(output_sizes.size() == aux_input_shape.rank())
<< "XlaHelpers::DynamicUnboundedReshape constrainled failed!";
Expand Down
6 changes: 2 additions & 4 deletions torch_xla/csrc/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,8 @@ xla::Shape XlaNode::GetOpShape(
std::string XlaNode::ToString() const {
std::stringstream ss;
ss << torch::lazy::Node::ToString() << ", xla_shape=" << xla_shape_;
ss << ", dynamic_dims: ";
for (const auto dim : dynamic_dims_) {
ss << dim;
}
ss << ", dynamic_dims: (" << absl::StrJoin(unbounded_dynamic_dims_, ", ")
<< ')';
return ss.str();
}

Expand Down
8 changes: 5 additions & 3 deletions torch_xla/csrc/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,14 +138,16 @@ class XlaNode : public torch::lazy::Node {

std::string ToString() const override;

void MarkDynamicDimension(uint32_t dim) { dynamic_dims_.insert(dim); }
void MarkDynamicDimension(uint32_t dim) {
unbounded_dynamic_dims_.insert(dim);
}

const std::unordered_set<uint32_t>& dynamic_dims() const {
return dynamic_dims_;
return unbounded_dynamic_dims_;
}

protected:
std::unordered_set<uint32_t> dynamic_dims_;
std::unordered_set<uint32_t> unbounded_dynamic_dims_;

private:
xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/device_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ torch::lazy::NodePtr DeviceData::Clone(torch::lazy::OpList operands) const {
}

XlaOpVector DeviceData::Lower(LoweringContext* loctx) const {
return ReturnOp(loctx->GetParameter(data_, dynamic_dims_), loctx);
return ReturnOp(loctx->GetParameter(data_, unbounded_dynamic_dims_), loctx);
}

DeviceData* DeviceData::Cast(const torch::lazy::Node* node) {
Expand Down

0 comments on commit f8927e6

Please sign in to comment.