diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 68a62638b9d5888..6a61ef63c2a0549 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2476,6 +2476,13 @@ class VPReductionRecipe : public VPSingleDefRecipe {
   /// Generate the reduction in the loop
   void execute(VPTransformState &State) override;
 
+  /// Return the cost of this VPReductionRecipe for vectorization factor \p VF.
+  /// The cost is the cost of the reduction's binary operation plus the cost
+  /// of the vector reduction, both queried from \p Ctx.TTI. Any-of and in-loop
+  /// reductions are not implemented yet.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
   /// Print the recipe.
   void print(raw_ostream &O, const Twine &Indent,
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 2948ecc580edc05..368d6e58a5578ec 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -2071,6 +2071,42 @@ void VPReductionEVLRecipe::execute(VPTransformState &State) {
   State.set(this, NewRed, /*IsScalar*/ true);
 }
 
+InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
+                                               VPCostContext &Ctx) const {
+  RecurKind RdxKind = RdxDesc.getRecurrenceKind();
+  Type *ElementTy = Ctx.Types.inferScalarType(this);
+  auto *VectorTy = cast<VectorType>(ToVectorTy(ElementTy, VF));
+  TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
+  unsigned Opcode = RdxDesc.getOpcode();
+
+  // TODO: Support any-of and in-loop reductions. Until then, both kinds are
+  // only accepted when ForceTargetInstructionCost has at least one occurrence.
+  assert(
+      (!RecurrenceDescriptor::isAnyOfRecurrenceKind(RdxKind) ||
+       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
+      "Any-of reduction not implemented in VPlan-based cost model currently.");
+  assert(
+      (!cast<VPReductionPHIRecipe>(getOperand(0))->isInLoop() ||
+       ForceTargetInstructionCost.getNumOccurrences() > 0) &&
+      "In-loop reduction not implemented in VPlan-based cost model currently.");
+
+  // The inferred scalar type must have the recurrence type's type ID.
+  assert(ElementTy->getTypeID() == RdxDesc.getRecurrenceType()->getTypeID() &&
+         "Inferred type and recurrence type mismatch.");
+
+  // Cost = Reduction cost + BinOp cost
+  InstructionCost Cost =
+      Ctx.TTI.getArithmeticInstrCost(Opcode, ElementTy, CostKind);
+  if (RecurrenceDescriptor::isMinMaxRecurrenceKind(RdxKind)) {
+    Intrinsic::ID Id = getMinMaxReductionIntrinsicOp(RdxKind);
+    return Cost + Ctx.TTI.getMinMaxReductionCost(
+                      Id, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+  }
+
+  return Cost + Ctx.TTI.getArithmeticReductionCost(
+                    Opcode, VectorTy, RdxDesc.getFastMathFlags(), CostKind);
+}
+
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
 void VPReductionRecipe::print(raw_ostream &O, const Twine &Indent,
                               VPSlotTracker &SlotTracker) const {