Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Moore] [Canonicalizer] Lower struct-related assignOp #7341

Open
wants to merge 24 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
c423d1e
Using pass canonicalizers to lower variableOp to structCreateOp
mingzheTerapines Jul 18, 2024
65c065d
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Jul 22, 2024
775b3ed
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Jul 24, 2024
35a57e6
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Jul 25, 2024
d461289
Lower to structCreatOp and structInjectOp
mingzheTerapines Jul 25, 2024
808f546
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Jul 26, 2024
05aacc1
Modify structCreateOp result type to unpacked type
mingzheTerapines Jul 26, 2024
a6e507b
Add default constant value 0 for structCreateOp
mingzheTerapines Jul 26, 2024
d54fdd4
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Jul 29, 2024
79c1db1
Merge remote-tracking branch 'upstream/main' into mingzhe-structCreate
mingzheTerapines Jul 31, 2024
3b4ab12
Merge main
mingzheTerapines Jul 31, 2024
ab3ed8f
Fit new version of struct
mingzheTerapines Jul 31, 2024
db8f16b
Remove typeswitch
mingzheTerapines Aug 2, 2024
c4ccda7
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Aug 2, 2024
732a412
Support union type
mingzheTerapines Aug 2, 2024
0729f70
support uniont type2
mingzheTerapines Aug 2, 2024
3b1e302
Merge remote-tracking branch 'upstream/main' into mingzhe-structCreate
mingzheTerapines Aug 22, 2024
5b91dc8
Revert some codes.
mingzheTerapines Aug 23, 2024
cb5ffe8
Merge remote-tracking branch 'upstream/main' into mingzhe-structCreate
mingzheTerapines Aug 23, 2024
6027b9d
Lower strcutextractref to structinject
mingzheTerapines Aug 23, 2024
2d2511f
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Sep 3, 2024
8446fd3
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Sep 11, 2024
0ac4eea
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Sep 18, 2024
4aafd70
Merge branch 'llvm:main' into mingzhe-structCreate
mingzheTerapines Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/circt/Dialect/Moore/MooreDialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def MooreDialect : Dialect {
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 0;
let hasConstantMaterializer = 1;
let dependentDialects = ["hw::HWDialect"];
let dependentDialects = ["hw::HWDialect", "mlir::func::FuncDialect"];
}

#endif // CIRCT_DIALECT_MOORE_MOOREDIALECT
1 change: 1 addition & 0 deletions include/circt/Dialect/Moore/MooreOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "circt/Dialect/Moore/MooreAttributes.h"
#include "circt/Dialect/Moore/MooreDialect.h"
#include "circt/Dialect/Moore/MooreTypes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/RegionKindInterface.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
Expand Down
1 change: 1 addition & 0 deletions include/circt/Dialect/Moore/MooreOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ def BlockingAssignOp : AssignOpBase<"blocking_assign", [
Arg<RefType, "", [MemWrite]>:$dst,
UnpackedType:$src
);
let hasCanonicalizeMethod = true;
}

def NonBlockingAssignOp : AssignOpBase<"nonblocking_assign"> {
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Moore/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ add_circt_dialect_library(CIRCTMoore
CIRCTHW
CIRCTSupport
MLIRIR
MLIRFuncDialect
MLIRInferTypeOpInterface
MLIRMemorySlotInterfaces
)
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Moore/MooreDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "circt/Dialect/HW/HWDialect.h"
#include "circt/Dialect/Moore/MooreOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

using namespace circt;
using namespace circt::moore;
Expand Down
111 changes: 57 additions & 54 deletions lib/Dialect/Moore/MooreOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "circt/Support/CustomDirectiveImpl.h"
#include "mlir/IR/Builders.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace circt;
using namespace circt::moore;
Expand Down Expand Up @@ -723,7 +722,11 @@ static std::optional<uint32_t> getStructFieldIndex(Type type, StringAttr name) {
return structType.getFieldIndex(name);
if (auto structType = dyn_cast<UnpackedStructType>(type))
return structType.getFieldIndex(name);
assert(0 && "expected StructType or UnpackedStructType");
if (auto unionType = dyn_cast<UnionType>(type))
return unionType.getFieldIndex(name);
if (auto unionType = dyn_cast<UnpackedUnionType>(type))
return unionType.getFieldIndex(name);
assert(0 && "expected Struct-Like Type");
return {};
}

Expand All @@ -732,7 +735,10 @@ static ArrayRef<StructLikeMember> getStructMembers(Type type) {
return structType.getMembers();
if (auto structType = dyn_cast<UnpackedStructType>(type))
return structType.getMembers();
assert(0 && "expected StructType or UnpackedStructType");
if (auto unionType = dyn_cast<UnionType>(type))
return unionType.getMembers();
if (auto unionType = dyn_cast<UnpackedUnionType>(type))
return unionType.getMembers();
return {};
}

Expand Down Expand Up @@ -927,71 +933,54 @@ LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
//===----------------------------------------------------------------------===//

LogicalResult UnionCreateOp::verify() {
/// checks if the types of the input is exactly equal to the union field
auto type = getStructFieldType(getType(), getFieldNameAttr());

/// checks if the type of the input is exactly equal to the union field
/// type
return TypeSwitch<Type, LogicalResult>(getType())
.Case<UnionType, UnpackedUnionType>([this](auto &type) {
auto members = type.getMembers();
auto resultType = getType();
auto fieldName = getFieldName();
for (const auto &member : members)
if (member.name == fieldName && member.type == resultType)
return success();
emitOpError("input type must match the union field type");
return failure();
})
.Default([this](auto &) {
emitOpError("input type must be UnionType or UnpackedUnionType");
return failure();
});

if (!type)
return emitOpError() << "union field " << getFieldNameAttr()
<< " which does not exist in " << getInput().getType();
if (type != getType())
return emitOpError() << "result type " << getType()
<< " must match union field type " << type;
return success();
}

//===----------------------------------------------------------------------===//
// UnionExtractOp
//===----------------------------------------------------------------------===//

LogicalResult UnionExtractOp::verify() {
/// checks if the types of the input is exactly equal to the one of the
/// types of the result union fields
return TypeSwitch<Type, LogicalResult>(getInput().getType())
.Case<UnionType, UnpackedUnionType>([this](auto &type) {
auto members = type.getMembers();
auto fieldName = getFieldName();
auto resultType = getType();
for (const auto &member : members)
if (member.name == fieldName && member.type == resultType)
return success();
emitOpError("result type must match the union field type");
return failure();
})
.Default([this](auto &) {
emitOpError("input type must be UnionType or UnpackedUnionType");
return failure();
});
auto type = getStructFieldType(getInput().getType(), getFieldNameAttr());

/// checks if the type of the input is exactly equal to the type of the result
/// union fields
if (!type)
return emitOpError() << "union field " << getFieldNameAttr()
<< " which does not exist in " << getInput().getType();
if (type != getType())
return emitOpError() << "result type " << getType()
<< " must match union field type " << type;
return success();
}

//===----------------------------------------------------------------------===//
// UnionExtractOp
// UnionExtractRefOp
//===----------------------------------------------------------------------===//

LogicalResult UnionExtractRefOp::verify() {
/// checks if the types of the result is exactly equal to the type of the
/// refe union field
return TypeSwitch<Type, LogicalResult>(getInput().getType().getNestedType())
.Case<UnionType, UnpackedUnionType>([this](auto &type) {
auto members = type.getMembers();
auto fieldName = getFieldName();
auto resultType = getType().getNestedType();
for (const auto &member : members)
if (member.name == fieldName && member.type == resultType)
return success();
emitOpError("result type must match the union field type");
return failure();
})
.Default([this](auto &) {
emitOpError("input type must be UnionType or UnpackedUnionType");
return failure();
});
auto type = getStructFieldType(getInput().getType().getNestedType(),
getFieldNameAttr());
/// checks if the type of the result is exactly equal to the type of the
/// referring union field
if (!type)
return emitOpError() << "union field " << getFieldNameAttr()
<< " which does not exist in " << getInput().getType();
if (type != getType())
return emitOpError() << "result type " << getType()
<< " must match union field type " << type;
return success();
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1092,6 +1081,20 @@ DeletionKind BlockingAssignOp::removeBlockingUses(
return DeletionKind::Delete;
}

LogicalResult BlockingAssignOp::canonicalize(BlockingAssignOp op,
PatternRewriter &rewriter) {
if (auto refOp = op.getDst().getDefiningOp<moore::StructExtractRefOp>()) {
auto input = refOp.getInput();
auto read = rewriter.create<ReadOp>(op->getLoc(), input);
auto value = op.getSrc();
auto inject = rewriter.create<StructInjectOp>(
op->getLoc(), read, refOp.getFieldNameAttr(), value);
op->setOperands({input, inject});
return success();
}
return failure();
}

//===----------------------------------------------------------------------===//
// ReadOp
//===----------------------------------------------------------------------===//
Expand Down
13 changes: 13 additions & 0 deletions test/Dialect/Moore/canonicalizers.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,19 @@ func.func @StructExtractFold2(%arg0: !moore.i17, %arg1: !moore.i42) -> (!moore.i
return %1, %2 : !moore.i17, !moore.i42
}

// CHECK-LABEL: func.func @structExtractRefLower2Inject
func.func @structExtractRefLower2Inject(%arg0: !moore.ref<struct<{a: i32, b: i32}>>) -> (!moore.ref<struct<{a: i32, b: i32}>>) {
// CHECK-NEXT: [[C42:%.+]] = moore.constant 42
// CHECK-NEXT: [[TMP1:%.+]] = moore.read %arg0 : <struct<{a: i32, b: i32}>>
// CHECK-NEXT: [[TMP2:%.+]] = moore.struct_inject [[TMP1]], "a", [[C42]] : struct<{a: i32, b: i32}>, i32
// CHECK-NEXT: moore.blocking_assign %arg0, [[TMP2]] : struct<{a: i32, b: i32}>
// CHECK-NEXT: return %arg0 : !moore.ref<struct<{a: i32, b: i32}>>
%0 = moore.constant 42 : i32
%1 = moore.struct_extract_ref %arg0, "a" : <struct<{a: i32, b: i32}>> -> <i32>
moore.blocking_assign %1, %0 : i32
Comment on lines +202 to +204
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's perfect if you can add some other cases. Such as:

%0 = moore.constant 42 : i32
%1 = moore.struct_extract_ref %arg0, "a" : <struct<{a: i32, b: i32}>> -> <i32>
moore.blocking_assign %1, %0 : i32

%2 = moore.constant 43 : i32 
moore.blocking_assign %1, %2 : i32

And

%0 = moore.constant 42 : i32
%1 = moore.struct_extract_ref %arg0, "a" : <struct<{a: i32, b: i32}>> -> <i32>
moore.blocking_assign %1, %0 : i32

%2 = moore.constant 43 : i32 
%3 = moore.struct_extract_ref %arg0, "b" : <struct<{a: i32, b: i32}>> -> <i32>
moore.blocking_assign %2, %3 : i32

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These tests may be complicated I can print the result of your case here

%0 = moore.constant 43 : i32
    %1 = moore.constant 42 : i32
    %2 = moore.read %arg0 : <struct<{a: i32, b: i32}>>
    %3 = moore.struct_inject %2, "a", %1 : struct<{a: i32, b: i32}>, i32
    moore.blocking_assign %arg0, %3 : struct<{a: i32, b: i32}>
    %4 = moore.read %arg0 : <struct<{a: i32, b: i32}>>
    %5 = moore.struct_inject %4, "a", %0 : struct<{a: i32, b: i32}>, i32
    moore.blocking_assign %arg0, %5 : struct<{a: i32, b: i32}>
    %6 = moore.read %arg0 : <struct<{a: i32, b: i32}>>
    %7 = moore.struct_inject %6, "a", %0 : struct<{a: i32, b: i32}>, i32
    moore.blocking_assign %arg0, %7 : struct<{a: i32, b: i32}>
    return %arg0 : !moore.ref<struct<{a: i32, b: i32}>>

and

%0 = moore.constant 43 : i32
    %1 = moore.constant 42 : i32
    %2 = moore.read %arg0 : <struct<{a: i32, b: i32}>>
    %3 = moore.struct_inject %2, "a", %1 : struct<{a: i32, b: i32}>, i32
    moore.blocking_assign %arg0, %3 : struct<{a: i32, b: i32}>
    %4 = moore.read %arg0 : <struct<{a: i32, b: i32}>>
    %5 = moore.struct_inject %4, "b", %0 : struct<{a: i32, b: i32}>, i32
    moore.blocking_assign %arg0, %5 : struct<{a: i32, b: i32}>
    return %arg0 : !moore.ref<struct<{a: i32, b: i32}>>

return %arg0 : !moore.ref<struct<{a: i32, b: i32}>>
}

// CHECK-LABEL: func.func @StructInjectFold1
func.func @StructInjectFold1(%arg0: !moore.struct<{a: i32, b: i32}>) -> (!moore.struct<{a: i32, b: i32}>) {
// CHECK-NEXT: [[C42:%.+]] = moore.constant 42
Expand Down
Loading