Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arc] Make the canonicalizer shuffle the input vector elements before merging #7394

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 45 additions & 5 deletions lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -619,18 +619,58 @@ MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp,
// Ensure that the input vector matches the output of the `otherVecOp`
// Make sure that the results of the otherVecOp have only one use
auto otherVecOp = inputVec[0].getDefiningOp<VectorizeOp>();
if (!otherVecOp || inputVec != otherVecOp.getResults() ||
otherVecOp == vecOp ||
if (!otherVecOp || otherVecOp == vecOp ||
!llvm::all_of(otherVecOp.getResults(),
[](auto result) { return result.hasOneUse(); })) {
[](auto result) { return result.hasOneUse(); }) ||
!llvm::all_of(inputVec, [&](auto result) {
return result.template getDefiningOp<VectorizeOp>() == otherVecOp;
})) {
newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
continue;
}

// Here, all elements are from the same `VectorizeOp`.
// If all elements of the input vector come from the same `VectorizeOp`
// sort the vectors by their indices
DenseMap<Value, size_t> resultIdxMap;
for (auto [resultIdx, result] : llvm::enumerate(otherVecOp.getResults()))
resultIdxMap[result] = resultIdx;

SmallVector<Value> tempVec(inputVec.begin(), inputVec.end());
llvm::sort(tempVec, [&](Value a, Value b) {
return resultIdxMap[a] < resultIdxMap[b];
});

// Check if inputVec matches the result after sorting.
if (tempVec != SmallVector<Value>(otherVecOp.getResults().begin(),
otherVecOp.getResults().end())) {
Comment on lines +645 to +646
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does tempVec != otherVecOp.getResults() not work here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will try. I think getResults() returns a result_range, which cannot be converted into a SmallVector<Value>, but it is worth trying!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be an implicit conversion from SmallVector<Value> to an iterator range, which should theoretically be able to compare against a result range. Worth trying, don't worry if it doesn't work 😃

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It didn't work!

 error: invalid operands to binary expression ('SmallVector<mlir::Value>' and '::mlir::Operation::result_range' (aka 'mlir::ResultRange'))
    if (tempVec != otherVecOp.getResults()) 

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for checking 😃

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@fabianschuiki, should I land this now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's 12:17 AM here, so please land it if it's okay.

newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end());
continue;
}

DenseMap<size_t, size_t> fromRealIdxToSortedIdx;
for (auto [inIdx, in] : llvm::enumerate(inputVec))
fromRealIdxToSortedIdx[inIdx] = resultIdxMap[in];

// If this flag is set that means we changed the IR so we cannot return
// failure
canBeMerged = true;
newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
otherVecOp.getOperands().end());

// If the results got shuffled, then shuffle the operands before merging.
if (inputVec != otherVecOp.getResults()) {
for (auto otherVecOpInputVec : otherVecOp.getInputs()) {
// use the tempVec again instead of creating another one.
tempVec = SmallVector<Value>(inputVec.size());
for (auto [realIdx, opernad] : llvm::enumerate(otherVecOpInputVec))
tempVec[realIdx] =
otherVecOpInputVec[fromRealIdxToSortedIdx[realIdx]];

newOperands.insert(newOperands.end(), tempVec.begin(), tempVec.end());
}

} else
newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(),
otherVecOp.getOperands().end());

auto &otherBlock = otherVecOp.getBody().front();
for (auto &otherArg : otherBlock.getArguments()) {
Expand Down
53 changes: 39 additions & 14 deletions test/Dialect/Arc/arc-canonicalizer.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -453,22 +453,47 @@ in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {
}

// CHECK-LABEL: hw.module @Needs_Shuffle(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) {
// CHECK-NEXT: [[VEC0:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC1:%.+]]:4 = arc.vectorize ([[VEC0]]#1, [[VEC0]]#0, [[VEC0]]#2, [[VEC0]]#3), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = comb.and %arg0, %arg1 : i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%p, %n, %r, %t), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8):
// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD]], %arg2 : i8
// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg3) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC2:%.+]]:4 = arc.vectorize ([[VEC1]]#1, [[VEC1]]#0, [[VEC1]]#2, [[VEC1]]#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8):
// CHECK-NEXT: [[OUT:%.+]] = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }

// Test input: three chained `arc.vectorize` ops in which the second and third
// ops consume the previous op's results in a permuted order.
hw.module @Needs_Shuffle_2(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8,
in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1,
in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) {
// First op: the body adds its two block arguments.
%R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.add %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
// Second op: its first input vector is all four results of %R in reversed
// order (#3, #2, #1, #0); the body ANDs its two block arguments.
%L:4 = arc.vectorize(%R#3, %R#2, %R#1, %R#0), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0: i8, %arg1: i8):
%ret = comb.and %arg0, %arg1: i8
arc.vectorize.return %ret: i8
}
// Third op: its first input vector is all four results of %L with the first
// two swapped (#1, #0, #2, #3); the body calls @Just_A_Dummy_Func.
%C:4 = arc.vectorize(%L#1, %L#0, %L#2, %L#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
^bb0(%arg0 : i8, %arg1: i8):
%1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8
arc.vectorize.return %1692 : i8
}
// Only result #0 of %C is consumed here; the state op also takes its own
// result %4 as an operand.
%4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8
}

// CHECK-LABEL: hw.module @Needs_Shuffle_2(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) {
// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%h, %k, %e, %b), (%i, %l, %f, %c), (%p, %n, %r, %t), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8):
// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8
// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD]], %arg2 : i8
// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg3) : (i8, i8) -> i8
// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8
// CHECK-NEXT: }
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC2]]#0, [[STATE:%.+]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8
// CHECK-NEXT: hw.output
// CHECK-NEXT: }

Expand Down
Loading