Skip to content

Commit

Permalink
DAG: Move scalarizeExtractedVectorLoad to TargetLowering (llvm#122670)
Browse files Browse the repository at this point in the history
SimplifyDemandedVectorElts should be able to use this on loads
  • Loading branch information
arsenm authored Feb 4, 2025
1 parent 8fdd982 commit c55a765
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 89 deletions.
12 changes: 12 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -5622,6 +5622,18 @@ class TargetLowering : public TargetLoweringBase {
// joining their results. SDValue() is returned when expansion did not happen.
SDValue expandVectorNaryOpBySplitting(SDNode *Node, SelectionDAG &DAG) const;

/// Replace an extraction of a load with a narrowed load.
///
/// \param ResultVT type of the result extraction.
/// \param InVecVT type of the input vector to with bitcasts resolved.
/// \param EltNo index of the vector element to load.
/// \param OriginalLoad vector load that to be replaced.
/// \returns \p ResultVT Load on success SDValue() on failure.
SDValue scalarizeExtractedVectorLoad(EVT ResultVT, const SDLoc &DL,
EVT InVecVT, SDValue EltNo,
LoadSDNode *OriginalLoad,
SelectionDAG &DAG) const;

private:
SDValue foldSetCCWithAnd(EVT VT, SDValue N0, SDValue N1, ISD::CondCode Cond,
const SDLoc &DL, DAGCombinerInfo &DCI) const;
Expand Down
103 changes: 14 additions & 89 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,17 +385,6 @@ namespace {
bool getTruncatedStoreValue(StoreSDNode *ST, SDValue &Val);
bool extendLoadedValueToExtension(LoadSDNode *LD, SDValue &Val);

/// Replace an ISD::EXTRACT_VECTOR_ELT of a load with a narrowed
/// load.
///
/// \param EVE ISD::EXTRACT_VECTOR_ELT to be replaced.
/// \param InVecVT type of the input vector to EVE with bitcasts resolved.
/// \param EltNo index of the vector element to load.
/// \param OriginalLoad load that EVE came from to be replaced.
/// \returns EVE on success SDValue() on failure.
SDValue scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
SDValue EltNo,
LoadSDNode *OriginalLoad);
void ReplaceLoadWithPromotedLoad(SDNode *Load, SDNode *ExtLoad);
SDValue PromoteOperand(SDValue Op, EVT PVT, bool &Replace);
SDValue SExtPromoteOperand(SDValue Op, EVT PVT);
Expand Down Expand Up @@ -22719,81 +22708,6 @@ SDValue DAGCombiner::visitINSERT_VECTOR_ELT(SDNode *N) {
return SDValue();
}

SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
SDValue EltNo,
LoadSDNode *OriginalLoad) {
assert(OriginalLoad->isSimple());

EVT ResultVT = EVE->getValueType(0);
EVT VecEltVT = InVecVT.getVectorElementType();

// If the vector element type is not a multiple of a byte then we are unable
// to correctly compute an address to load only the extracted element as a
// scalar.
if (!VecEltVT.isByteSized())
return SDValue();

ISD::LoadExtType ExtTy =
ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
if (!TLI.isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
!TLI.shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
return SDValue();

Align Alignment = OriginalLoad->getAlign();
MachinePointerInfo MPI;
SDLoc DL(EVE);
if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
int Elt = ConstEltNo->getZExtValue();
unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
Alignment = commonAlignment(Alignment, PtrOff);
} else {
// Discard the pointer info except the address space because the memory
// operand can't represent this new access since the offset is variable.
MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
}

unsigned IsFast = 0;
if (!TLI.allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
OriginalLoad->getAddressSpace(), Alignment,
OriginalLoad->getMemOperand()->getFlags(),
&IsFast) ||
!IsFast)
return SDValue();

SDValue NewPtr = TLI.getVectorElementPointer(DAG, OriginalLoad->getBasePtr(),
InVecVT, EltNo);

// We are replacing a vector load with a scalar load. The new load must have
// identical memory op ordering to the original.
SDValue Load;
if (ResultVT.bitsGT(VecEltVT)) {
// If the result type of vextract is wider than the load, then issue an
// extending load instead.
ISD::LoadExtType ExtType =
TLI.isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT) ? ISD::ZEXTLOAD
: ISD::EXTLOAD;
Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
NewPtr, MPI, VecEltVT, Alignment,
OriginalLoad->getMemOperand()->getFlags(),
OriginalLoad->getAAInfo());
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
} else {
// The result type is narrower or the same width as the vector element
Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
Alignment, OriginalLoad->getMemOperand()->getFlags(),
OriginalLoad->getAAInfo());
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
if (ResultVT.bitsLT(VecEltVT))
Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
else
Load = DAG.getBitcast(ResultVT, Load);
}
++OpsNarrowed;
return Load;
}

/// Transform a vector binary operation into a scalar binary operation by moving
/// the math/logic after an extract element of a vector.
static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
Expand Down Expand Up @@ -23272,8 +23186,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
ISD::isNormalLoad(VecOp.getNode()) &&
!Index->hasPredecessor(VecOp.getNode())) {
auto *VecLoad = dyn_cast<LoadSDNode>(VecOp);
if (VecLoad && VecLoad->isSimple())
return scalarizeExtractedVectorLoad(N, VecVT, Index, VecLoad);
if (VecLoad && VecLoad->isSimple()) {
if (SDValue Scalarized = TLI.scalarizeExtractedVectorLoad(
ExtVT, SDLoc(N), VecVT, Index, VecLoad, DAG)) {
++OpsNarrowed;
return Scalarized;
}
}
}

// Perform only after legalization to ensure build_vector / vector_shuffle
Expand Down Expand Up @@ -23361,7 +23280,13 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
if (Elt == -1)
return DAG.getUNDEF(LVT);

return scalarizeExtractedVectorLoad(N, VecVT, Index, LN0);
if (SDValue Scalarized =
TLI.scalarizeExtractedVectorLoad(LVT, DL, VecVT, Index, LN0, DAG)) {
++OpsNarrowed;
return Scalarized;
}

return SDValue();
}

// Simplify (build_vec (ext )) to (bitcast (build_vec ))
Expand Down
74 changes: 74 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12114,3 +12114,77 @@ SDValue TargetLowering::expandVectorNaryOpBySplitting(SDNode *Node,
SDValue SplitOpHi = DAG.getNode(Opcode, DL, HiVT, HiOps);
return DAG.getNode(ISD::CONCAT_VECTORS, DL, VT, SplitOpLo, SplitOpHi);
}

SDValue TargetLowering::scalarizeExtractedVectorLoad(EVT ResultVT,
const SDLoc &DL,
EVT InVecVT, SDValue EltNo,
LoadSDNode *OriginalLoad,
SelectionDAG &DAG) const {
assert(OriginalLoad->isSimple());

EVT VecEltVT = InVecVT.getVectorElementType();

// If the vector element type is not a multiple of a byte then we are unable
// to correctly compute an address to load only the extracted element as a
// scalar.
if (!VecEltVT.isByteSized())
return SDValue();

ISD::LoadExtType ExtTy =
ResultVT.bitsGT(VecEltVT) ? ISD::EXTLOAD : ISD::NON_EXTLOAD;
if (!isOperationLegalOrCustom(ISD::LOAD, VecEltVT) ||
!shouldReduceLoadWidth(OriginalLoad, ExtTy, VecEltVT))
return SDValue();

Align Alignment = OriginalLoad->getAlign();
MachinePointerInfo MPI;
if (auto *ConstEltNo = dyn_cast<ConstantSDNode>(EltNo)) {
int Elt = ConstEltNo->getZExtValue();
unsigned PtrOff = VecEltVT.getSizeInBits() * Elt / 8;
MPI = OriginalLoad->getPointerInfo().getWithOffset(PtrOff);
Alignment = commonAlignment(Alignment, PtrOff);
} else {
// Discard the pointer info except the address space because the memory
// operand can't represent this new access since the offset is variable.
MPI = MachinePointerInfo(OriginalLoad->getPointerInfo().getAddrSpace());
Alignment = commonAlignment(Alignment, VecEltVT.getSizeInBits() / 8);
}

unsigned IsFast = 0;
if (!allowsMemoryAccess(*DAG.getContext(), DAG.getDataLayout(), VecEltVT,
OriginalLoad->getAddressSpace(), Alignment,
OriginalLoad->getMemOperand()->getFlags(), &IsFast) ||
!IsFast)
return SDValue();

SDValue NewPtr =
getVectorElementPointer(DAG, OriginalLoad->getBasePtr(), InVecVT, EltNo);

// We are replacing a vector load with a scalar load. The new load must have
// identical memory op ordering to the original.
SDValue Load;
if (ResultVT.bitsGT(VecEltVT)) {
// If the result type of vextract is wider than the load, then issue an
// extending load instead.
ISD::LoadExtType ExtType = isLoadExtLegal(ISD::ZEXTLOAD, ResultVT, VecEltVT)
? ISD::ZEXTLOAD
: ISD::EXTLOAD;
Load = DAG.getExtLoad(ExtType, DL, ResultVT, OriginalLoad->getChain(),
NewPtr, MPI, VecEltVT, Alignment,
OriginalLoad->getMemOperand()->getFlags(),
OriginalLoad->getAAInfo());
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
} else {
// The result type is narrower or the same width as the vector element
Load = DAG.getLoad(VecEltVT, DL, OriginalLoad->getChain(), NewPtr, MPI,
Alignment, OriginalLoad->getMemOperand()->getFlags(),
OriginalLoad->getAAInfo());
DAG.makeEquivalentMemoryOrdering(OriginalLoad, Load);
if (ResultVT.bitsLT(VecEltVT))
Load = DAG.getNode(ISD::TRUNCATE, DL, ResultVT, Load);
else
Load = DAG.getBitcast(ResultVT, Load);
}

return Load;
}

0 comments on commit c55a765

Please sign in to comment.