Skip to content

Commit

Permalink
NFC: Attribute cleanup (remove references of attributes) (llvm#286)
Browse files Browse the repository at this point in the history
* Define krnl.permute op.

* Support krnl.permute operation.

* Properly remove loop references.

* Re-push, Github was down.

* Need to debug interpretOp error.

* Fix lowering bug by erasing ops after full krnl IR interpretation is done, and clean up & comment code.

* Introduce permute, unroll operations.

* More debug.

* Remove std::set.

* krnl.terminate fails to be converted.

* Pass all tests, need to add legal ops as well as part of the conversion target.

* Change test format to new permute spec.

* Bug fix for nested iterate op lowering.

* Simplify error reporting.

* Fix compilation error.

* Increase comments coverage.

* Remove unnecessary imports.

* Re-trigger Jenkins

* Add permute/unroll tests.

* Retrigger Jenkins

* remove & (ref) for Attributes

Co-authored-by: Tian Jin <tjingrant@gmail.com>
  • Loading branch information
AlexandreEichenberger and tjingrant authored Aug 31, 2020
1 parent 8bfde7d commit c1262c1
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class FrontendGenImpl {
}

mlir::NamedAttribute convertOnnxAttributeProtoToMlirNamedAttribute(
onnx::AttributeProto &attr) {
onnx::AttributeProto attr) {
mlir::Attribute mlirAttr;
switch (attr.type()) {
case onnx::AttributeProto::FLOAT:
Expand Down
4 changes: 2 additions & 2 deletions src/Transform/ONNX/Combine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace {

/// Compute the combined permute pattern from a pair of permute patterns.
ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,
ArrayAttr &firstPermAttr, ArrayAttr &secondPermAttr) {
ArrayAttr firstPermAttr, ArrayAttr secondPermAttr) {
// Read first permute vectors.
SmallVector<int64_t, 4> initialPerm;
for (auto firstPermVal : firstPermAttr.getValue())
Expand All @@ -44,7 +44,7 @@ ArrayAttr CombinedTransposePattern(PatternRewriter &rewriter,

/// Test if the permute pattern correspond to an identity pattern.
/// Identity patterns are {0, 1, 2, ... , rank -1}.
bool IsIdentityPermuteVector(ArrayAttr &permAttr) {
bool IsIdentityPermuteVector(ArrayAttr permAttr) {
int64_t currentIndex = 0;
for (auto permVal : permAttr.getValue())
if (permVal.cast<IntegerAttr>().getInt() != currentIndex++)
Expand Down
40 changes: 20 additions & 20 deletions src/Transform/ONNX/ConstProp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ namespace {

template <typename OP>
Attribute ComputeConstPropElementwiseBinary(PatternRewriter &rewriter,
Type elementType, Attribute &lhsAttr, Attribute &secondAttr) {
Type elementType, Attribute lhsAttr, Attribute secondAttr) {
llvm_unreachable("unkonwn operation");
}

template <>
Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
Attribute secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
Expand All @@ -86,8 +86,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXAddOp>(

template <>
Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
Attribute secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
Expand All @@ -105,8 +105,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXSubOp>(

template <>
Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
Attribute secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
Expand All @@ -124,8 +124,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXMulOp>(

template <>
Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(
PatternRewriter &rewriter, Type elementType, Attribute &lhsAttr,
Attribute &secondAttr) {
PatternRewriter &rewriter, Type elementType, Attribute lhsAttr,
Attribute secondAttr) {
if (elementType.isa<FloatType>()) {
double lhsVal = lhsAttr.cast<FloatAttr>().getValueAsDouble();
double rhsVal = secondAttr.cast<FloatAttr>().getValueAsDouble();
Expand Down Expand Up @@ -154,8 +154,8 @@ Attribute ComputeConstPropElementwiseBinary<ONNXDivOp>(

template <typename ElementwiseBinaryOp>
void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &lhsAttr,
DenseElementsAttr &rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
std::vector<Attribute> &resVector, DenseElementsAttr lhsAttr,
DenseElementsAttr rhsAttr, SmallVector<uint64_t, 4> &lhsIndices,
SmallVector<uint64_t, 4> &rhsIndices, int lhsFreeRank, int rhsFreeRank) {
if (lhsFreeRank == 0) {
// Fully defined ranks.
Expand Down Expand Up @@ -222,7 +222,7 @@ void RecurseConstPropElementwiseBinary(PatternRewriter &rewriter,
// generate the new constant operation.
template <typename ElementwiseBinaryOp>
DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,
Value resOperand, Attribute &lhsAttr, Attribute &rhsAttr) {
Value resOperand, Attribute lhsAttr, Attribute rhsAttr) {
DenseElementsAttr lhsDenseAttr =
lhsAttr.dyn_cast_or_null<mlir::DenseElementsAttr>();
DenseElementsAttr rhsDenseAttr =
Expand All @@ -248,13 +248,13 @@ DenseElementsAttr ConstPropElementwiseBinary(PatternRewriter &rewriter,

template <typename OP>
Attribute ComputeConstPropElementwiseUnary(
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
PatternRewriter &rewriter, Type elementType, Attribute attr) {
llvm_unreachable("unkonwn operation");
}

template <>
Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
PatternRewriter &rewriter, Type elementType, Attribute attr) {
if (elementType.isa<FloatType>()) {
double val = attr.cast<FloatAttr>().getValueAsDouble();
double res = -val;
Expand All @@ -270,7 +270,7 @@ Attribute ComputeConstPropElementwiseUnary<ONNXNegOp>(

template <>
Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(
PatternRewriter &rewriter, Type elementType, Attribute &attr) {
PatternRewriter &rewriter, Type elementType, Attribute attr) {
if (elementType.isa<FloatType>()) {
double val = attr.cast<FloatAttr>().getValueAsDouble();
double res = sqrt(val);
Expand All @@ -281,7 +281,7 @@ Attribute ComputeConstPropElementwiseUnary<ONNXSqrtOp>(

template <typename ElementwiseUnaryOp>
void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
std::vector<Attribute> &resVector, DenseElementsAttr attr,
SmallVector<uint64_t, 4> &indices, int freeRank) {
if (freeRank == 0) {
// Fully defined ranks.
Expand All @@ -308,7 +308,7 @@ void RecurseConstPropElementwiseUnary(PatternRewriter &rewriter,
// generate the new constant operation.
template <typename ElementwiseUnaryOp>
DenseElementsAttr ConstPropElementwiseUnary(
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
PatternRewriter &rewriter, Value resOperand, Attribute attr) {
DenseElementsAttr denseAttr =
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
assert(denseAttr && "expected dense attribute");
Expand All @@ -329,7 +329,7 @@ DenseElementsAttr ConstPropElementwiseUnary(
//===----------------------------------------------------------------------===//

void RecurseConstPropTranspose(PatternRewriter &rewriter,
std::vector<Attribute> &resVector, DenseElementsAttr &attr,
std::vector<Attribute> &resVector, DenseElementsAttr attr,
SmallVector<uint64_t, 4> &indices, SmallVector<uint64_t, 4> &perm,
int freeRank) {
if (freeRank == 0) {
Expand All @@ -351,7 +351,7 @@ void RecurseConstPropTranspose(PatternRewriter &rewriter,
}

DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
Value resOperand, Attribute &attr, ArrayAttr &permAttr) {
Value resOperand, Attribute attr, ArrayAttr permAttr) {
// Read dense attribute, the constant tensor we are transforming.
DenseElementsAttr denseAttr =
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
Expand All @@ -378,7 +378,7 @@ DenseElementsAttr ConstPropTranspose(PatternRewriter &rewriter,
//===----------------------------------------------------------------------===//

DenseElementsAttr ConstPropUnsqueeze(
PatternRewriter &rewriter, Value resOperand, Attribute &attr) {
PatternRewriter &rewriter, Value resOperand, Attribute attr) {
// Read dense attribute, the constant tensor we are transforming.
DenseElementsAttr denseAttr =
attr.dyn_cast_or_null<mlir::DenseElementsAttr>();
Expand Down

0 comments on commit c1262c1

Please sign in to comment.