Skip to content

Commit

Permalink
[PIR] Add ShapeAnalysis, and save() into SymDimMgr. (PaddlePaddle#57029)
Browse files Browse the repository at this point in the history
* add ShapeAnalysis、Mgr::save

* add funcOp.

* access ShapedType.

* remove const_cast.
  • Loading branch information
liuruyan authored Sep 14, 2023
1 parent 4698b3d commit b1ad1ec
Show file tree
Hide file tree
Showing 7 changed files with 757 additions and 119 deletions.
6 changes: 6 additions & 0 deletions paddle/pir/core/builtin_type_interfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ class ShapedTypeInterface : public pir::TypeInterfaceBase<ShapedTypeInterface> {
[](int64_t dSize) { return isDynamic(dSize); });
}

/// Check whether shape has any size indicating a dynamic dimension.
bool hasStaticShape() const {
return (*this).hasRank() &&
!pir::ShapedTypeInterface::isDynamicShape((*this).getShape());
}

/// Check whether the given dimension has a dynamic size.
/// Aborts for unranked types.
bool isDynamicDim(unsigned idx) const {
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/dialect/shape/ir/shape_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ShapeDialect::ShapeDialect(IrContext *context)
}

void ShapeDialect::initialize() {
RegisterOps<SymbolicDim, DimOp, TieProductEqualOp>();
RegisterOps<SymbolicDim, DimOp, TieProductEqualOp, TieShapeOp, FuncOp>();
}

} // namespace dialect
Expand Down
59 changes: 48 additions & 11 deletions paddle/pir/dialect/shape/ir/shape_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@ const char *SymbolicDim::attributes_name[attributes_num] = {"knownNegativeOne",
"sym_name",
"value"}; // NOLINT

void SymbolicDim::Build(
Builder &builder,
OperationArgument &argument,
const std::string &sym_name,
int64_t value, // TODO(zhangbo) value = ShapedType::kDynamic
bool knownNonNegative,
bool knownNegativeOne,
bool knownNonSizeOne,
bool knownNonSizeZero) {
void SymbolicDim::Build(Builder &builder,
OperationArgument &argument,
const std::string &sym_name,
int64_t value,
bool knownNonNegative,
bool knownNegativeOne,
bool knownNonSizeOne,
bool knownNonSizeZero) {
pir::Attribute attr_sym_name =
pir::StrAttribute::get(pir::IrContext::Instance(), sym_name);
argument.AddAttribute("sym_name", attr_sym_name);
Expand Down Expand Up @@ -106,8 +105,8 @@ void SymbolicDim::updateKnownNonSizeZero(bool attrValue) {
}

bool SymbolicDim::isDynamic() {
return getValue() == -100000;
} // TODO(zhangbo): getValue() == ShapedType::kDynamic;
return getValue() == ShapedTypeInterface::kDynamic;
}

bool SymbolicDim::merge(SymbolicDim other) {
if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue())
Expand Down Expand Up @@ -173,6 +172,21 @@ void TieProductEqualOp::Build(Builder &builder,
argument.inputs = inputs;
}

void TieProductEqualOp::Build(Builder &builder,
OperationArgument &argument,
const std::vector<pir::OpResult> &lhs,
const std::vector<pir::OpResult> &rhs) {
pir::Attribute attr_lhs_len =
pir::Int64Attribute::get(pir::IrContext::Instance(), lhs.size());
argument.AddAttribute("lhs_len", attr_lhs_len);
pir::Attribute attr_rhs_len =
pir::Int64Attribute::get(pir::IrContext::Instance(), rhs.size());
argument.AddAttribute("rhs_len", attr_rhs_len);

argument.inputs = lhs;
argument.inputs.insert(argument.inputs.end(), rhs.begin(), rhs.end());
}

std::vector<pir::Value> TieProductEqualOp::getLhs() {
int64_t lhs_len = attribute<pir::Int64Attribute>("lhs_len").data();
std::vector<pir::Value> res;
Expand All @@ -191,9 +205,32 @@ std::vector<pir::Value> TieProductEqualOp::getRhs() {
return res;
}

const char *TieShapeOp::attributes_name[attributes_num] = {
SymbolicDim::getSymbolicDimAttrName().c_str()}; // NOLINT

void TieShapeOp::Build(Builder &builder,
OperationArgument &argument,
const pir::OpResult &input) {
argument.inputs = {input};
}

pir::Value TieShapeOp::getValue() { return operand_source(0); }

void FuncOp::Build(Builder &builder, OperationArgument &argument) {
argument.num_regions = 1;
}

pir::Block *FuncOp::block() {
pir::Region &region = (*this)->region(0);
if (region.empty()) region.emplace_back();
return region.front();
}

} // namespace dialect
} // namespace pir

IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::SymbolicDim)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::DimOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp)
58 changes: 48 additions & 10 deletions paddle/pir/dialect/shape/ir/shape_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/builtin_type_interfaces.h"
#include "paddle/pir/core/op_base.h"

namespace pir {
Expand All @@ -28,15 +29,14 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {
static constexpr uint32_t attributes_num = 6;
static const char *attributes_name[attributes_num];

static void Build(
Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::string &sym_name,
int64_t value = -100000, // TODO(zhangbo): value = ShapedType::kDynamic
bool knownNonNegative = false,
bool knownNegativeOne = false,
bool knownNonSizeOne = false,
bool knownNonSizeZero = false);
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::string &sym_name,
int64_t value = ShapedTypeInterface::kDynamic,
bool knownNonNegative = false,
bool knownNegativeOne = false,
bool knownNonSizeOne = false,
bool knownNonSizeZero = false);
const std::string getSymName();
int64_t getValue();
bool getKnownNonNegative();
Expand All @@ -54,6 +54,10 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {
bool isDynamic();
bool merge(SymbolicDim other);

static const std::string getSymbolicDimAttrName() {
return "SymbolicDimAttr";
}

void Verify() {}
};

Expand Down Expand Up @@ -82,20 +86,54 @@ class IR_API TieProductEqualOp : public Op<TieProductEqualOp> {

static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
// attr operand_segment_sizes

static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
int64_t lhs_len,
int64_t rhs_len,
const std::vector<pir::OpResult> &inputs);
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const std::vector<pir::OpResult> &lhs,
const std::vector<pir::OpResult> &rhs);
std::vector<pir::Value> getLhs();
std::vector<pir::Value> getRhs();
void Verify() {}
};

class IR_API TieShapeOp : public Op<TieShapeOp> {
public:
using Op::Op;
static const char *name() { return "shape.tie_shape"; }

static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];

static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const pir::OpResult &input);
pir::Value getValue();
void Verify() {}
};

class IR_API FuncOp : public Op<FuncOp> {
public:
using Op::Op;
static const char *name() { return "shape.func"; }

static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;

static void Build(Builder &builder, // NOLINT
OperationArgument &argument); // NOLINT
pir::Block *block();
void Verify() {}
};
} // namespace dialect
} // namespace pir

IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::SymbolicDim);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::DimOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp);
Loading

0 comments on commit b1ad1ec

Please sign in to comment.