diff --git a/include/circt/Dialect/Arc/ArcCostModel.h b/include/circt/Dialect/Arc/ArcCostModel.h new file mode 100644 index 000000000000..09be420e3195 --- /dev/null +++ b/include/circt/Dialect/Arc/ArcCostModel.h @@ -0,0 +1,60 @@ +//===- ArcCostModel.h -----------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_ARC_COSTMODEL_H +#define CIRCT_ARC_COSTMODEL_H + +#include "circt/Dialect/Arc/ArcOps.h" +#include "mlir/IR/Operation.h" +#include "mlir/Pass/AnalysisManager.h" + +using namespace mlir; + +namespace circt { +namespace arc { + +// FIXME: May be refined and we have more accurate operation costs +enum class OperationCost : size_t { + NOCOST, + NORMALCOST, + PACKCOST = 2, + EXTRACTCOST = 3, + CONCATCOST = 3, + SAMEVECTORNOSHUFFLE = 0, + SAMEVECTORSHUFFLECOST = 2, + DIFFERENTVECTORNOSHUFFLE = 2, + DIFFERENTVECTORSHUFFLECOST = 3 +}; + +class ArcCostModel { +public: + ArcCostModel(Operation *op); + size_t getCost(Operation *op); + +private: + size_t computeOperationCost(Operation *op); + bool areElementsInOrder(Value operand, Operation *definingOp); + size_t countBodyOps(const MutableArrayRef ®ions); + + // gets the cost to pack the vectors we have some cases we need to consider: + // 1: the input is scalar so we can give it a cost of 1 + // 2: the input is a result of another vector but with no shuffling so the + // is 0 + // 3: the input is a result of another vector but with some shuffling so + // the cost is the (number of out of order elements) * 2 + // 4: the input is a mix of some vectors: + // a) same order we multiply by 2 + // b) shuffling we multiply by 3 + size_t getInputVectorsCost(VectorizeOp vecOp); + size_t getShufflingCost(const ValueRange &inputVec, bool isSame = false); +}; + +} // namespace arc +} // namespace circt + +#endif // CIRCT_ARC_COSTMODEL_H diff --git a/lib/Dialect/Arc/ArcCostModel.cpp b/lib/Dialect/Arc/ArcCostModel.cpp new file mode 100644 index 000000000000..48726a404979 --- /dev/null +++ b/lib/Dialect/Arc/ArcCostModel.cpp @@ -0,0 +1,106 @@ +//===- ArcCostModel.cpp ---------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Dialect/Arc/ArcCostModel.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include + +using namespace llvm; +using namespace circt; +using namespace arc; +using namespace std; + +size_t ArcCostModel::getCost(Operation *op) { return computeOperationCost(op); } + +size_t ArcCostModel::computeOperationCost(Operation *op) { + if (isa(op)) + return size_t(OperationCost::CONCATCOST); + if (isa(op)) + return size_t(OperationCost::EXTRACTCOST); + // We have some other functions that need to be handled in a different way + // arc::StateOp, arc::CallOp, mlir::func::CallOp and arc::VectorizeOp, each of + // these functions have bodies so the cost of the op equals the cost of its + // body. + if (isa(op) || isa(op) || + isa(op)) + return countBodyOps(op->getOperand(0).getDefiningOp()->getRegions()); + + if (isa(op)) + return countBodyOps(op->getRegions()) + + getInputVectorsCost(op->getOperand(0).getDefiningOp()); + + return size_t(OperationCost::NORMALCOST); +} + +size_t ArcCostModel::countBodyOps(const MutableArrayRef ®ions) { + size_t counter = 0; + for (auto ®ion : regions) + for (auto &block : region) + for (auto &op : block) + counter += computeOperationCost(&op); + + return counter; +} + +size_t ArcCostModel::getInputVectorsCost(VectorizeOp vecOp) { + size_t totalCost = 0; + for (auto inputVec : vecOp.getInputs()) { + if (auto otherVecOp = inputVec[0].getDefiningOp(); + all_of(inputVec.begin(), inputVec.end(), [&](auto element) { + return element.template getDefiningOp() == otherVecOp; + })) { + // This means that they came from the same vector or + // VectorizeOp == null so they are all scalars + + // Check if they all scalars we multiply by 2 (SHL/R + OR) + if (!otherVecOp) + totalCost += inputVec.size() * size_t(OperationCost::PACKCOST); + else + totalCost += inputVec == otherVecOp.getResults() + ? size_t(OperationCost::SAMEVECTORNOSHUFFLE) + : getShufflingCost(inputVec, true); + } else + // inputVector consists of elements from different vectotrize ops and + // may have scalars as well. + totalCost += getShufflingCost(inputVec); + } + return totalCost; +} + +size_t ArcCostModel::getShufflingCost(const ValueRange &inputVec, bool isSame) { + size_t totalCost = 0; + if (isSame) { + auto vecOp = inputVec[0].getDefiningOp(); + for (auto [elem, orig] : llvm::zip(inputVec, vecOp.getResults())) + if (elem != orig) + ++totalCost; + + return totalCost * size_t(OperationCost::SAMEVECTORSHUFFLECOST); + } + + for (size_t i = 0; i < inputVec.size(); ++i) { + auto otherVecOp = inputVec[i].getDefiningOp(); + // If the element is not a result of a vector operation then it's a result + // of a scalar operation, then it just needs to be packed into the vector. + if (!otherVecOp) + totalCost += size_t(OperationCost::PACKCOST); + else { + // If it's a result of a vector operation, then we have two cases: + // (1) Its order in `inputVec` is the same as its order in the result of + // the defining op. + // (2) the order is different. + size_t idx = find(otherVecOp.getResults().begin(), + otherVecOp.getResults().end(), inputVec[i]) - + otherVecOp.getResults().begin(); + totalCost += i == idx ? size_t(OperationCost::DIFFERENTVECTORNOSHUFFLE) + : size_t(OperationCost::DIFFERENTVECTORSHUFFLECOST); + } + } + return totalCost; +} diff --git a/lib/Dialect/Arc/CMakeLists.txt b/lib/Dialect/Arc/CMakeLists.txt index 00a58eb30d41..6c8aa365dd3f 100644 --- a/lib/Dialect/Arc/CMakeLists.txt +++ b/lib/Dialect/Arc/CMakeLists.txt @@ -3,6 +3,7 @@ set(CIRCT_Arc_Sources ArcFolds.cpp ArcOps.cpp ArcTypes.cpp + ArcCostModel.cpp ModelInfo.cpp )