Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Arc] Initial SLP support #7061

Merged
merged 2 commits into from
May 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/circt/Dialect/Arc/ArcPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ createAddTapsPass(const AddTapsOptions &options = {});
std::unique_ptr<mlir::Pass> createAllocateStatePass();
std::unique_ptr<mlir::Pass> createArcCanonicalizerPass();
std::unique_ptr<mlir::Pass> createDedupPass();
std::unique_ptr<mlir::Pass> createFindInitialVectorsPass();
std::unique_ptr<mlir::Pass> createGroupResetsAndEnablesPass();
std::unique_ptr<mlir::Pass>
createInferMemoriesPass(const InferMemoriesOptions &options = {});
Expand Down
6 changes: 6 additions & 0 deletions include/circt/Dialect/Arc/ArcPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ def Dedup : Pass<"arc-dedup", "mlir::ModuleOp"> {
];
}

def FindInitialVectors : Pass<"find-initial-vectors", "mlir::ModuleOp"> {
let summary = "Finds the ops that can be grouped together into a vector";
let constructor = "circt::arc::createFindInitialVectorsPass()";
let dependentDialects = ["arc::ArcDialect"];
}

def GroupResetsAndEnables : Pass<"arc-group-resets-and-enables",
"mlir::ModuleOp"> {
let summary = "Group reset and enable conditions of lowered states";
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/Arc/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_circt_dialect_library(CIRCTArcTransforms
AllocateState.cpp
ArcCanonicalizer.cpp
Dedup.cpp
FindInitialVectors.cpp
GroupResetsAndEnables.cpp
InferMemories.cpp
InferStateProperties.cpp
Expand Down
252 changes: 252 additions & 0 deletions lib/Dialect/Arc/Transforms/FindInitialVectors.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
//===- FindInitialVectors.cpp ---------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements a simple SLP vectorizer for Arc, the pass starts with
// `arc.state` operations as seeds in every new vector, then following the
// dependency graph nodes computes a rank to every operation in the module
// and assigns a rank to each one of them. After that it groups isomorphic
// operations together and put them in a vector.
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/Arc/ArcPasses.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Seq/SeqOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

#define DEBUG_TYPE "find-initial-vectors"

namespace circt {
namespace arc {
#define GEN_PASS_DEF_FINDINITIALVECTORS
#include "circt/Dialect/Arc/ArcPasses.h.inc"
} // namespace arc
} // namespace circt

using namespace circt;
using namespace arc;
using llvm::SmallMapVector;

namespace {
struct TopologicalOrder {
/// An integer rank assigned to each operation.
SmallMapVector<Operation *, unsigned, 32> opRanks;
LogicalResult compute(Block *block);
unsigned get(Operation *op) const {
const auto *it = opRanks.find(op);
assert(it != opRanks.end() && "op has no rank");
return it->second;
}
};
} // namespace

/// Assign each operation in the given block a topological rank. Stateful
/// elements are assigned rank 0. All other operations receive the maximum rank
/// of their users, plus one.
LogicalResult TopologicalOrder::compute(Block *block) {
LLVM_DEBUG(llvm::dbgs() << "- Computing topological order in block " << block
<< "\n");
struct WorklistItem {
WorklistItem(Operation *op) : userIt(op->user_begin()) {}
Operation::user_iterator userIt;
unsigned rank = 0;
};
SmallMapVector<Operation *, WorklistItem, 16> worklist;
for (auto &op : *block) {
if (opRanks.contains(&op))
continue;
worklist.insert({&op, WorklistItem(&op)});
while (!worklist.empty()) {
auto &[op, item] = worklist.back();
if (auto stateOp = dyn_cast<StateOp>(op)) {
if (stateOp.getLatency() > 0)
item.userIt = op->user_end();
} else if (auto writeOp = dyn_cast<MemoryWritePortOp>(op)) {
item.userIt = op->user_end();
}
if (item.userIt == op->user_end()) {
opRanks.insert({op, item.rank});
worklist.pop_back();
continue;
}
if (auto *rankIt = opRanks.find(*item.userIt); rankIt != opRanks.end()) {
item.rank = std::max(item.rank, rankIt->second + 1);
++item.userIt;
continue;
}
if (!worklist.insert({*item.userIt, WorklistItem(*item.userIt)}).second)
return op->emitError("dependency cycle");
}
}
return success();
}

namespace {
using Key = std::tuple<unsigned, StringRef, SmallVector<Type>,
SmallVector<Type>, DictionaryAttr>;

Key computeKey(Operation *op, unsigned rank) {
// The key = concat(op_rank, op_name, op_operands_types, op_result_types,
// op_attrs)
return std::make_tuple(
rank, op->getName().getStringRef(),
SmallVector<Type>(op->operand_type_begin(), op->operand_type_end()),
SmallVector<Type>(op->result_type_begin(), op->result_type_end()),
op->getAttrDictionary());
}

struct Vectorizer {
Vectorizer(Block *block) : block(block) {}
LogicalResult collectSeeds(Block *block) {
if (failed(order.compute(block)))
return failure();

for (auto &[op, rank] : order.opRanks)
candidates[computeKey(op, rank)].push_back(op);

return success();
}

LogicalResult vectorize();
// Store Isomorphic ops together
SmallMapVector<Key, SmallVector<Operation *>, 16> candidates;
TopologicalOrder order;
Block *block;
};
} // namespace

namespace llvm {
template <>
struct DenseMapInfo<Key> {
static inline Key getEmptyKey() {
return Key(0, StringRef(), SmallVector<Type>(), SmallVector<Type>(),
DictionaryAttr());
}

static inline Key getTombstoneKey() {
static StringRef tombStoneKeyOpName =
DenseMapInfo<StringRef>::getTombstoneKey();
return Key(1, tombStoneKeyOpName, SmallVector<Type>(), SmallVector<Type>(),
DictionaryAttr());
}

static unsigned getHashValue(const Key &key) {
return hash_value(std::get<0>(key)) ^ hash_value(std::get<1>(key)) ^
hash_value(std::get<2>(key)) ^ hash_value(std::get<3>(key)) ^
hash_value(std::get<4>(key));
}

static bool isEqual(const Key &lhs, const Key &rhs) { return lhs == rhs; }
};
} // namespace llvm

// When calling this function we assume that we have the candidate groups of
// isomorphic ops so we need to feed them to the `VectorizeOp`
LogicalResult Vectorizer::vectorize() {
if (failed(collectSeeds(block)))
return failure();

// Unachievable?! just in case!
if (candidates.empty())
return success();

// Iterate over every group of isomorphic ops
for (const auto &[key, ops] : candidates) {
// If the group has only one scalar then it doesn't worth vectorizing,
// We skip also ops with more than one result as `arc.vectorize` supports
// only one result in its body region.
if (ops.size() == 1 || ops[0]->getNumResults() > 1)
continue;

// Here, we have a bunch of isomorphic ops, we need to extract the operands
// results and attributes of every op and store them in a vector
// Holds the operands
SmallVector<SmallVector<Value, 4>> vectorOperands;
vectorOperands.resize(ops[0]->getNumOperands());
for (auto *op : ops)
for (auto [into, operand] : llvm::zip(vectorOperands, op->getOperands()))
into.push_back(operand);
SmallVector<ValueRange> operandValueRanges;
operandValueRanges.assign(vectorOperands.begin(), vectorOperands.end());
// Holds the results
SmallVector<Type> resultTypes(ops.size(), ops[0]->getResult(0).getType());

// Now construct the `VectorizeOp`
ImplicitLocOpBuilder builder(ops[0]->getLoc(), ops[0]);
auto vectorizeOp =
builder.create<VectorizeOp>(resultTypes, operandValueRanges);
// Now we have the operands, results and attributes, now we need to get
// the blocks.

// There was no blocks so we need to create one and set the insertion point
// at the first of this region
auto &vectorizeBlock = vectorizeOp.getBody().emplaceBlock();
builder.setInsertionPointToStart(&vectorizeBlock);

// Add the block arguments
// comb.and %x, %y
// comb.and %u, %v
// at this point the operands vector will be {{x, u}, {y, v}}
// we need to create an th block args, so we need the type and the location
// the type is a vector type
IRMapping argMapping;
for (auto [vecOperand, origOpernad] :
llvm::zip(vectorOperands, ops[0]->getOperands())) {
auto arg = vectorizeBlock.addArgument(vecOperand[0].getType(),
origOpernad.getLoc());
argMapping.map(origOpernad, arg);
}

auto *clonedOp = builder.clone(*ops[0], argMapping);
// `VectorizeReturnOp`
builder.create<VectorizeReturnOp>(clonedOp->getResult(0));

// Now replace the original ops with the vectorized ops
for (auto [op, result] : llvm::zip(ops, vectorizeOp->getResults())) {
op->getResult(0).replaceAllUsesWith(result);
op->erase();
}
}
return success();
}

namespace {
struct FindInitialVectorsPass
: public impl::FindInitialVectorsBase<FindInitialVectorsPass> {
void runOnOperation() override;
};
} // namespace

void FindInitialVectorsPass::runOnOperation() {
for (auto moduleOp : getOperation().getOps<hw::HWModuleOp>()) {
auto result = moduleOp.walk([&](Block *block) {
if (!mayHaveSSADominance(*block->getParent()))
if (failed(Vectorizer(block).vectorize()))
return WalkResult::interrupt();
return WalkResult::advance();
});
if (result.wasInterrupted())
return signalPassFailure();
}
}

std::unique_ptr<Pass> arc::createFindInitialVectorsPass() {
return std::make_unique<FindInitialVectorsPass>();
}
65 changes: 65 additions & 0 deletions test/Dialect/Arc/find-initial-vectors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// RUN: circt-opt %s --find-initial-vectors | FileCheck %s

hw.module @Foo(in %clock: !seq.clock, in %en: i1, in %inA: i3, in %inB: i3) {
%4 = arc.state @FooMux(%en, %21753, %4) clock %clock latency 1 : (i1, i3, i3) -> i3
%5 = arc.state @FooMux(%en, %21754, %5) clock %clock latency 1 : (i1, i3, i3) -> i3
%7 = arc.state @FooMux(%en, %21756, %7) clock %clock latency 1 : (i1, i3, i3) -> i3
%12 = arc.state @FooMux(%en, %91, %12) clock %clock latency 1 : (i1, i3, i3) -> i3
%15 = arc.state @FooMux(%en, %93, %15) clock %clock latency 1 : (i1, i3, i3) -> i3
%16 = arc.state @FooMux(%en, %94, %16) clock %clock latency 1 : (i1, i3, i3) -> i3

%21753 = comb.xor %200, %inA : i3
%21754 = comb.xor %201, %inA : i3
%21756 = comb.xor %202, %inA : i3

%91 = comb.add %100, %inB : i3
%93 = comb.add %101, %inB : i3
%94 = comb.add %102, %inB : i3

%100 = comb.mul %12, %inA : i3
%101 = comb.mul %15, %inA : i3
%102 = comb.sub %16, %inA : i3

%200 = comb.and %4, %inB : i3
%201 = comb.and %5, %inB : i3
%202 = comb.and %7, %inB : i3
}

arc.define @FooMux(%arg0: i1, %arg1: i3, %arg2: i3) -> i3 {
%0 = comb.mux bin %arg0, %arg1, %arg2 : i3
arc.output %0 : i3
}

// CHECK-LABEL: hw.module @Foo(in %clock : !seq.clock, in %en : i1, in %inA : i3, in %inB : i3) {
// CHECK-NEXT: [[VEC0:%.+]]:6 = arc.vectorize (%clock, %clock, %clock, %clock, %clock, %clock), (%en, %en, %en, %en, %en, %en), ([[VEC1:%.+]]#0, [[VEC1]]#1, [[VEC1]]#2, [[VEC2:%.+]]#0, [[VEC2]]#1, [[VEC2]]#2), ([[VEC0]]#0, [[VEC0]]#1, [[VEC0]]#2, [[VEC0]]#3, [[VEC0]]#4, [[VEC0]]#5) : (!seq.clock, !seq.clock, !seq.clock, !seq.clock, !seq.clock, !seq.clock, i1, i1, i1, i1, i1, i1, i3, i3, i3, i3, i3, i3, i3, i3, i3, i3, i3, i3) -> (i3, i3, i3, i3, i3, i3) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: !seq.clock, %arg1: i1, %arg2: i3, %arg3: i3):
// CHECK-NEXT: [[ANS:%.+]] = arc.state @FooMux(%arg1, %arg2, %arg3) clock %arg0 latency 1 : (i1, i3, i3) -> i3
// CHECK-NEXT: arc.vectorize.return [[ANS:%.+]] : i3
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC1]]:3 = arc.vectorize ([[VEC4:%.+]]#0, [[VEC4]]#1, [[VEC4]]#2), (%inA, %inA, %inA) : (i3, i3, i3, i3, i3, i3) -> (i3, i3, i3) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i3, %arg1: i3):
// CHECK-NEXT: [[ANS:%.+]] = comb.xor %arg0, %arg1 : i3
// CHECK-NEXT: arc.vectorize.return [[ANS:%.+]] : i3
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC2]]:3 = arc.vectorize ([[VEC3:%.+]]#0, [[VEC3]]#1, [[SCALAR:%.+]]), (%inB, %inB, %inB) : (i3, i3, i3, i3, i3, i3) -> (i3, i3, i3) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i3, %arg1: i3):
// CHECK-NEXT: [[ANS:%.+]] = comb.add %arg0, %arg1 : i3
// CHECK-NEXT: arc.vectorize.return [[ANS:%.+]] : i3
// CHECK-NEXT: }
// CHECK-NEXT: [[VEC3]]:2 = arc.vectorize ([[VEC0]]#3, [[VEC0]]#4), (%inA, %inA) : (i3, i3, i3, i3) -> (i3, i3) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i3, %arg1: i3):
// CHECK-NEXT: [[ANS:%.+]] = comb.mul %arg0, %arg1 : i3
// CHECK-NEXT: arc.vectorize.return [[ANS:%.+]] : i3
// CHECK-NEXT: }
// CHECK-NEXT: [[SCALAR]] = comb.sub [[VEC0]]#5, %inA : i3
// CHECK-NEXT: [[VEC4]]:3 = arc.vectorize ([[VEC0]]#0, [[VEC0]]#1, [[VEC0]]#2), (%inB, %inB, %inB) : (i3, i3, i3, i3, i3, i3) -> (i3, i3, i3) {
// CHECK-NEXT: ^[[BLOCK:[[:alnum:]]+]](%arg0: i3, %arg1: i3):
// CHECK-NEXT: [[ANS:%.+]] = comb.and %arg0, %arg1 : i3
// CHECK-NEXT: arc.vectorize.return [[ANS:%.+]] : i3
// CHECK-NEXT: }
// CHECK-NEXT: hw.output
// CHECK-NEXT: }
// CHECK-NEXT: arc.define @FooMux(%arg0: i1, %arg1: i3, %arg2: i3) -> i3 {
// CHECK-NEXT: [[ANS:%.+]] = comb.mux bin %arg0, %arg1, %arg2 : i3
// CHECK-NEXT: arc.output [[ANS:%.+]] : i3
// CHECK-NEXT: }
Loading