Skip to content

Commit

Permalink
[VectorCombine][X86] foldShuffleOfCastops - fold shuffle(cast(x),cast…
Browse files Browse the repository at this point in the history
…(y)) -> cast(shuffle(x,y)) iff cost efficient

Based off the existing foldShuffleOfBinops fold

Fixes #67803
  • Loading branch information
RKSimon committed Apr 3, 2024
1 parent 4d8a3f5 commit a12447a
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 37 deletions.
55 changes: 55 additions & 0 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ class VectorCombine {
bool foldSingleElementStore(Instruction &I);
bool scalarizeLoadExtract(Instruction &I);
bool foldShuffleOfBinops(Instruction &I);
bool foldShuffleOfCastops(Instruction &I);
bool foldShuffleFromReductions(Instruction &I);
bool foldTruncFromReductions(Instruction &I);
bool foldSelectShuffle(Instruction &I, bool FromReduction = false);
Expand Down Expand Up @@ -1432,6 +1433,59 @@ bool VectorCombine::foldShuffleOfBinops(Instruction &I) {
return true;
}

/// Try to convert "shuffle (castop), (castop)" with a shared castop operand into
/// "castop (shuffle)".
bool VectorCombine::foldShuffleOfCastops(Instruction &I) {
Value *V0, *V1;
ArrayRef<int> Mask;
if (!match(&I, m_Shuffle(m_OneUse(m_Value(V0)), m_OneUse(m_Value(V1)),
m_Mask(Mask))))
return false;

auto *C0 = dyn_cast<CastInst>(V0);
auto *C1 = dyn_cast<CastInst>(V1);
if (!C0 || !C1)
return false;

Instruction::CastOps Opcode = C0->getOpcode();
if (Opcode == Instruction::BitCast || Opcode != C1->getOpcode() ||
C0->getSrcTy() != C1->getSrcTy())
return false;

auto *ShuffleDstTy = dyn_cast<FixedVectorType>(I.getType());
auto *CastDstTy = dyn_cast<FixedVectorType>(C0->getDestTy());
auto *CastSrcTy = dyn_cast<FixedVectorType>(C0->getSrcTy());
if (!ShuffleDstTy || !CastDstTy || !CastSrcTy ||
CastDstTy->getElementCount() != CastSrcTy->getElementCount())
return false;

auto *NewCastSrcTy =
FixedVectorType::get(CastSrcTy->getScalarType(), Mask.size());

// Try to replace a castop with a shuffle if the shuffle is not costly.
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
InstructionCost OldCost =
TTI.getCastInstrCost(Opcode, CastDstTy, CastSrcTy,
TTI::CastContextHint::None, CostKind) +
TTI.getCastInstrCost(Opcode, CastDstTy, CastSrcTy,
TTI::CastContextHint::None, CostKind);
OldCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
CastDstTy, Mask, CostKind);
InstructionCost NewCost = TTI.getCastInstrCost(
Opcode, ShuffleDstTy, NewCastSrcTy, TTI::CastContextHint::None, CostKind);
NewCost += TTI.getShuffleCost(TargetTransformInfo::SK_PermuteTwoSrc,
CastSrcTy, Mask, CostKind);
if (NewCost > OldCost)
return false;

Value *Shuf =
Builder.CreateShuffleVector(C0->getOperand(0), C1->getOperand(0), Mask);
Value *Cast = Builder.CreateCast(Opcode, Shuf, ShuffleDstTy);
// TODO Copy IR flags?
replaceValue(I, *Cast);
return true;
}

/// Given a commutative reduction, the order of the input lanes does not alter
/// the results. We can use this to remove certain shuffles feeding the
/// reduction, removing the need to shuffle at all.
Expand Down Expand Up @@ -1986,6 +2040,7 @@ bool VectorCombine::run() {
break;
case Instruction::ShuffleVector:
MadeChange |= foldShuffleOfBinops(I);
MadeChange |= foldShuffleOfCastops(I);
MadeChange |= foldSelectShuffle(I);
break;
case Instruction::BitCast:
Expand Down
6 changes: 1 addition & 5 deletions llvm/test/Transforms/PhaseOrdering/X86/pr67803.ll
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@ define <4 x i64> @PR67803(<4 x i64> %x, <4 x i64> %y, <4 x i64> %a, <4 x i64> %b
; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i64> [[X:%.*]] to <8 x i32>
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i64> [[Y:%.*]] to <8 x i32>
; CHECK-NEXT: [[TMP2:%.*]] = icmp sgt <8 x i32> [[TMP0]], [[TMP1]]
; CHECK-NEXT: [[CMP_I21:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> poison, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
; CHECK-NEXT: [[SEXT_I22:%.*]] = sext <4 x i1> [[CMP_I21]] to <4 x i32>
; CHECK-NEXT: [[CMP_I:%.*]] = shufflevector <8 x i1> [[TMP2]], <8 x i1> poison, <4 x i32> <i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[SEXT_I:%.*]] = sext <4 x i1> [[CMP_I]] to <4 x i32>
; CHECK-NEXT: [[TMP3:%.*]] = shufflevector <4 x i32> [[SEXT_I22]], <4 x i32> [[SEXT_I]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i32>
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <4 x i64> [[A:%.*]] to <32 x i8>
; CHECK-NEXT: [[TMP6:%.*]] = shufflevector <32 x i8> [[TMP5]], <32 x i8> poison, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP7:%.*]] = bitcast <4 x i64> [[B:%.*]] to <32 x i8>
Expand Down
58 changes: 26 additions & 32 deletions llvm/test/Transforms/VectorCombine/X86/shuffle-of-casts.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

define <16 x i32> @concat_zext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
; CHECK-LABEL: @concat_zext_v8i16_v16i32(
; CHECK-NEXT: [[X0:%.*]] = zext <8 x i16> [[A0:%.*]] to <8 x i32>
; CHECK-NEXT: [[X1:%.*]] = zext <8 x i16> [[A1:%.*]] to <8 x i32>
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i32> [[X0]], <8 x i32> [[X1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[R:%.*]] = zext <16 x i16> [[TMP1]] to <16 x i32>
; CHECK-NEXT: ret <16 x i32> [[R]]
;
%x0 = zext <8 x i16> %a0 to <8 x i32>
Expand All @@ -19,9 +18,8 @@ define <16 x i32> @concat_zext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {

define <16 x i32> @concat_sext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
; CHECK-LABEL: @concat_sext_v8i16_v16i32(
; CHECK-NEXT: [[X0:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
; CHECK-NEXT: [[X1:%.*]] = sext <8 x i16> [[A1:%.*]] to <8 x i32>
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i32> [[X0]], <8 x i32> [[X1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[R:%.*]] = sext <16 x i16> [[TMP1]] to <16 x i32>
; CHECK-NEXT: ret <16 x i32> [[R]]
;
%x0 = sext <8 x i16> %a0 to <8 x i32>
Expand All @@ -32,9 +30,8 @@ define <16 x i32> @concat_sext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {

define <8 x i32> @concat_sext_v4i1_v8i32(<4 x i1> %a0, <4 x i1> %a1) {
; CHECK-LABEL: @concat_sext_v4i1_v8i32(
; CHECK-NEXT: [[X0:%.*]] = sext <4 x i1> [[A0:%.*]] to <4 x i32>
; CHECK-NEXT: [[X1:%.*]] = sext <4 x i1> [[A1:%.*]] to <4 x i32>
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x i32> [[X0]], <4 x i32> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i1> [[A0:%.*]], <4 x i1> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[R:%.*]] = sext <8 x i1> [[TMP1]] to <8 x i32>
; CHECK-NEXT: ret <8 x i32> [[R]]
;
%x0 = sext <4 x i1> %a0 to <4 x i32>
Expand All @@ -45,9 +42,8 @@ define <8 x i32> @concat_sext_v4i1_v8i32(<4 x i1> %a0, <4 x i1> %a1) {

define <8 x i16> @concat_trunc_v4i32_v8i16(<4 x i32> %a0, <4 x i32> %a1) {
; CHECK-LABEL: @concat_trunc_v4i32_v8i16(
; CHECK-NEXT: [[X0:%.*]] = trunc <4 x i32> [[A0:%.*]] to <4 x i16>
; CHECK-NEXT: [[X1:%.*]] = trunc <4 x i32> [[A1:%.*]] to <4 x i16>
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x i16> [[X0]], <4 x i16> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[A0:%.*]], <4 x i32> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[R:%.*]] = trunc <8 x i32> [[TMP1]] to <8 x i16>
; CHECK-NEXT: ret <8 x i16> [[R]]
;
%x0 = trunc <4 x i32> %a0 to <4 x i16>
Expand All @@ -58,9 +54,8 @@ define <8 x i16> @concat_trunc_v4i32_v8i16(<4 x i32> %a0, <4 x i32> %a1) {

define <8 x ptr> @concat_inttoptr_v4i32_v8iptr(<4 x i32> %a0, <4 x i32> %a1) {
; CHECK-LABEL: @concat_inttoptr_v4i32_v8iptr(
; CHECK-NEXT: [[X0:%.*]] = inttoptr <4 x i32> [[A0:%.*]] to <4 x ptr>
; CHECK-NEXT: [[X1:%.*]] = inttoptr <4 x i32> [[A1:%.*]] to <4 x ptr>
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x ptr> [[X0]], <4 x ptr> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x i32> [[A0:%.*]], <4 x i32> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[R:%.*]] = inttoptr <8 x i32> [[TMP1]] to <8 x ptr>
; CHECK-NEXT: ret <8 x ptr> [[R]]
;
%x0 = inttoptr <4 x i32> %a0 to <4 x ptr>
Expand All @@ -71,9 +66,8 @@ define <8 x ptr> @concat_inttoptr_v4i32_v8iptr(<4 x i32> %a0, <4 x i32> %a1) {

define <16 x i64> @concat_ptrtoint_v8i16_v16i32(<8 x ptr> %a0, <8 x ptr> %a1) {
; CHECK-LABEL: @concat_ptrtoint_v8i16_v16i32(
; CHECK-NEXT: [[X0:%.*]] = ptrtoint <8 x ptr> [[A0:%.*]] to <8 x i64>
; CHECK-NEXT: [[X1:%.*]] = ptrtoint <8 x ptr> [[A1:%.*]] to <8 x i64>
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i64> [[X0]], <8 x i64> [[X1]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x ptr> [[A0:%.*]], <8 x ptr> [[A1:%.*]], <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
; CHECK-NEXT: [[R:%.*]] = ptrtoint <16 x ptr> [[TMP1]] to <16 x i64>
; CHECK-NEXT: ret <16 x i64> [[R]]
;
%x0 = ptrtoint <8 x ptr> %a0 to <8 x i64>
Expand All @@ -83,11 +77,16 @@ define <16 x i64> @concat_ptrtoint_v8i16_v16i32(<8 x ptr> %a0, <8 x ptr> %a1) {
}

define <8 x double> @concat_fpext_v4f32_v8f64(<4 x float> %a0, <4 x float> %a1) {
; CHECK-LABEL: @concat_fpext_v4f32_v8f64(
; CHECK-NEXT: [[X0:%.*]] = fpext <4 x float> [[A0:%.*]] to <4 x double>
; CHECK-NEXT: [[X1:%.*]] = fpext <4 x float> [[A1:%.*]] to <4 x double>
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x double> [[X0]], <4 x double> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: ret <8 x double> [[R]]
; SSE-LABEL: @concat_fpext_v4f32_v8f64(
; SSE-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; SSE-NEXT: [[R:%.*]] = fpext <8 x float> [[TMP1]] to <8 x double>
; SSE-NEXT: ret <8 x double> [[R]]
;
; AVX-LABEL: @concat_fpext_v4f32_v8f64(
; AVX-NEXT: [[X0:%.*]] = fpext <4 x float> [[A0:%.*]] to <4 x double>
; AVX-NEXT: [[X1:%.*]] = fpext <4 x float> [[A1:%.*]] to <4 x double>
; AVX-NEXT: [[R:%.*]] = shufflevector <4 x double> [[X0]], <4 x double> [[X1]], <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; AVX-NEXT: ret <8 x double> [[R]]
;
%x0 = fpext <4 x float> %a0 to <4 x double>
%x1 = fpext <4 x float> %a1 to <4 x double>
Expand All @@ -112,9 +111,8 @@ define <16 x float> @concat_fptrunc_v8f64_v16f32(<8 x double> %a0, <8 x double>

define <16 x i32> @rconcat_sext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
; CHECK-LABEL: @rconcat_sext_v8i16_v16i32(
; CHECK-NEXT: [[X0:%.*]] = sext <8 x i16> [[A0:%.*]] to <8 x i32>
; CHECK-NEXT: [[X1:%.*]] = sext <8 x i16> [[A1:%.*]] to <8 x i32>
; CHECK-NEXT: [[R:%.*]] = shufflevector <8 x i32> [[X0]], <8 x i32> [[X1]], <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <8 x i16> [[A0:%.*]], <8 x i16> [[A1:%.*]], <16 x i32> <i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15, i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>
; CHECK-NEXT: [[R:%.*]] = sext <16 x i16> [[TMP1]] to <16 x i32>
; CHECK-NEXT: ret <16 x i32> [[R]]
;
%x0 = sext <8 x i16> %a0 to <8 x i32>
Expand All @@ -127,9 +125,8 @@ define <16 x i32> @rconcat_sext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {

define <8 x double> @interleave_fpext_v4f32_v8f64(<4 x float> %a0, <4 x float> %a1) {
; CHECK-LABEL: @interleave_fpext_v4f32_v8f64(
; CHECK-NEXT: [[X0:%.*]] = fpext <4 x float> [[A0:%.*]] to <4 x double>
; CHECK-NEXT: [[X1:%.*]] = fpext <4 x float> [[A1:%.*]] to <4 x double>
; CHECK-NEXT: [[R:%.*]] = shufflevector <4 x double> [[X0]], <4 x double> [[X1]], <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x float> [[A0:%.*]], <4 x float> [[A1:%.*]], <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
; CHECK-NEXT: [[R:%.*]] = fpext <8 x float> [[TMP1]] to <8 x double>
; CHECK-NEXT: ret <8 x double> [[R]]
;
%x0 = fpext <4 x float> %a0 to <4 x double>
Expand Down Expand Up @@ -184,6 +181,3 @@ define <16 x i32> @concat_sext_zext_v8i16_v16i32(<8 x i16> %a0, <8 x i16> %a1) {
%r = shufflevector <8 x i32> %x0, <8 x i32> %x1, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, i32 9, i32 10, i32 11, i32 12, i32 13, i32 14, i32 15>
ret <16 x i32> %r
}
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
; AVX: {{.*}}
; SSE: {{.*}}

0 comments on commit a12447a

Please sign in to comment.