From 7818e2b17b46ede926ae1e5e678356247d006b88 Mon Sep 17 00:00:00 2001 From: Mohamed Atef Date: Sat, 27 Jul 2024 01:12:38 +0300 Subject: [PATCH] [Arc] Make the canonicalizer shuffle the input vector elements before merging --- .../Arc/Transforms/ArcCanonicalizer.cpp | 50 +++++++++++++++-- test/Dialect/Arc/arc-canonicalizer.mlir | 53 ++++++++++++++----- 2 files changed, 84 insertions(+), 19 deletions(-) diff --git a/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp b/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp index b03ef1647fdc..3f64a68bde1f 100644 --- a/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp +++ b/lib/Dialect/Arc/Transforms/ArcCanonicalizer.cpp @@ -619,18 +619,58 @@ MergeVectorizeOps::matchAndRewrite(VectorizeOp vecOp, // Ensure that the input vector matches the output of the `otherVecOp` // Make sure that the results of the otherVecOp have only one use auto otherVecOp = inputVec[0].getDefiningOp(); - if (!otherVecOp || inputVec != otherVecOp.getResults() || - otherVecOp == vecOp || + if (!otherVecOp || otherVecOp == vecOp || !llvm::all_of(otherVecOp.getResults(), - [](auto result) { return result.hasOneUse(); })) { + [](auto result) { return result.hasOneUse(); }) || + !llvm::all_of(inputVec, [&](auto result) { + return result.template getDefiningOp() == otherVecOp; + })) { newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end()); continue; } + + // Here, all elements are from the same `VectorizeOp`. + // If all elements of the input vector come from the same `VectorizeOp` + // sort the vectors by their indices + DenseMap resultIdxMap; + for (auto [resultIdx, result] : llvm::enumerate(otherVecOp.getResults())) + resultIdxMap[result] = resultIdx; + + SmallVector tempVec(inputVec.begin(), inputVec.end()); + llvm::sort(tempVec, [&](Value a, Value b) { + return resultIdxMap[a] < resultIdxMap[b]; + }); + + // Check if inputVec matches the result after sorting. + if (tempVec != SmallVector(otherVecOp.getResults().begin(), + otherVecOp.getResults().end())) { + newOperands.insert(newOperands.end(), inputVec.begin(), inputVec.end()); + continue; + } + + DenseMap fromRealIdxToSortedIdx; + for (auto [inIdx, in] : llvm::enumerate(inputVec)) + fromRealIdxToSortedIdx[inIdx] = resultIdxMap[in]; + // If this flag is set that means we changed the IR so we cannot return // failure canBeMerged = true; - newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(), - otherVecOp.getOperands().end()); + + // If the results got shuffled, then shuffle the operands before merging. + if (inputVec != otherVecOp.getResults()) { + for (auto otherVecOpInputVec : otherVecOp.getInputs()) { + // use the tempVec again instead of creating another one. + tempVec = SmallVector(inputVec.size()); + for (auto [realIdx, opernad] : llvm::enumerate(otherVecOpInputVec)) + tempVec[realIdx] = + otherVecOpInputVec[fromRealIdxToSortedIdx[realIdx]]; + + newOperands.insert(newOperands.end(), tempVec.begin(), tempVec.end()); + } + + } else + newOperands.insert(newOperands.end(), otherVecOp.getOperands().begin(), + otherVecOp.getOperands().end()); auto &otherBlock = otherVecOp.getBody().front(); for (auto &otherArg : otherBlock.getArguments()) { diff --git a/test/Dialect/Arc/arc-canonicalizer.mlir b/test/Dialect/Arc/arc-canonicalizer.mlir index 0da5216ace5e..eae501a4e17d 100644 --- a/test/Dialect/Arc/arc-canonicalizer.mlir +++ b/test/Dialect/Arc/arc-canonicalizer.mlir @@ -453,22 +453,47 @@ in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) { } // CHECK-LABEL: hw.module @Needs_Shuffle(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) { -// CHECK-NEXT: [[VEC0:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { -// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8): -// CHECK-NEXT: [[OUT:%.+]] = comb.add %arg0, %arg1 : i8 -// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8 -// CHECK-NEXT: } -// CHECK-NEXT: [[VEC1:%.+]]:4 = arc.vectorize ([[VEC0]]#1, [[VEC0]]#0, [[VEC0]]#2, [[VEC0]]#3), (%n, %p, %r, %t) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { -// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8): -// CHECK-NEXT: [[OUT:%.+]] = comb.and %arg0, %arg1 : i8 -// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8 +// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%b, %e, %h, %k), (%c, %f, %i, %l), (%p, %n, %r, %t), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8): +// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8 +// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD]], %arg2 : i8 +// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg3) : (i8, i8) -> i8 +// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8 // CHECK-NEXT: } -// CHECK-NEXT: [[VEC2:%.+]]:4 = arc.vectorize ([[VEC1]]#1, [[VEC1]]#0, [[VEC1]]#2, [[VEC1]]#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { -// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8): -// CHECK-NEXT: [[OUT:%.+]] = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8 -// CHECK-NEXT: arc.vectorize.return [[OUT]] : i8 +// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8 +// CHECK-NEXT: hw.output +// CHECK-NEXT: } + +hw.module @Needs_Shuffle_2(in %b: i8, in %e: i8, in %h: i8, in %k: i8, in %c: i8, in %f: i8, +in %i: i8, in %l: i8, in %n: i8, in %p: i8, in %r: i8, in %t: i8, in %en: i1, +in %clock: !seq.clock, in %o: i8, in %v: i8, in %q: i8, in %s: i8) { + %R:4 = arc.vectorize(%b, %e, %h, %k), (%c, %f, %i, %l) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.add %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %L:4 = arc.vectorize(%R#3, %R#2, %R#1, %R#0), (%n, %p, %r, %t): (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0: i8, %arg1: i8): + %ret = comb.and %arg0, %arg1: i8 + arc.vectorize.return %ret: i8 + } + %C:4 = arc.vectorize(%L#1, %L#0, %L#2, %L#3), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { + ^bb0(%arg0 : i8, %arg1: i8): + %1692 = arc.call @Just_A_Dummy_Func(%arg0, %arg1) : (i8, i8) -> i8 + arc.vectorize.return %1692 : i8 + } + %4 = arc.state @FooMux(%en, %C#0, %4) clock %clock latency 1 : (i1, i8, i8) -> i8 +} + +// CHECK-LABEL: hw.module @Needs_Shuffle_2(in %b : i8, in %e : i8, in %h : i8, in %k : i8, in %c : i8, in %f : i8, in %i : i8, in %l : i8, in %n : i8, in %p : i8, in %r : i8, in %t : i8, in %en : i1, in %clock : !seq.clock, in %o : i8, in %v : i8, in %q : i8, in %s : i8) { +// CHECK-NEXT: [[VEC:%.+]]:4 = arc.vectorize (%h, %k, %e, %b), (%i, %l, %f, %c), (%p, %n, %r, %t), (%o, %v, %q, %s) : (i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8, i8) -> (i8, i8, i8, i8) { +// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i8, %arg1: i8, %arg2: i8, %arg3: i8): +// CHECK-NEXT: [[ADD:%.+]] = comb.add %arg0, %arg1 : i8 +// CHECK-NEXT: [[AND:%.+]] = comb.and [[ADD]], %arg2 : i8 +// CHECK-NEXT: [[CALL:%.+]] = arc.call @Just_A_Dummy_Func([[AND]], %arg3) : (i8, i8) -> i8 +// CHECK-NEXT: arc.vectorize.return [[CALL]] : i8 // CHECK-NEXT: } -// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC2]]#0, [[STATE:%.+]]) clock %clock latency 1 : (i1, i8, i8) -> i8 +// CHECK-NEXT: [[STATE:%.+]] = arc.state @FooMux(%en, [[VEC]]#0, [[STATE]]) clock %clock latency 1 : (i1, i8, i8) -> i8 // CHECK-NEXT: hw.output // CHECK-NEXT: }