Skip to content

Commit

Permalink
[RISCV] Promote bf16 ops to f32 with zvfbfmin (llvm#108937)
Browse files Browse the repository at this point in the history
For f16 with zvfhmin, we promote most ops and VP ops to f32. This does
the same for bf16 with zvfbfmin, so the two fp types should now be in
sync.

There are a few places in the custom lowering where we need to check for
a LMUL 8 f16/bf16 vector that can't be promoted and must be split, this
extracts that out into isPromotedOpNeedingSplit.

In a follow up NFC we can deduplicate the code that sets up the
promotions.
  • Loading branch information
lukel97 authored and tmsri committed Sep 19, 2024
1 parent a693c12 commit 0c0d915
Show file tree
Hide file tree
Showing 44 changed files with 17,683 additions and 1,582 deletions.
145 changes: 70 additions & 75 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
};

// TODO: support more ops.
static const unsigned ZvfhminPromoteOps[] = {
static const unsigned ZvfhminZvfbfminPromoteOps[] = {
ISD::FMINNUM, ISD::FMAXNUM, ISD::FADD, ISD::FSUB,
ISD::FMUL, ISD::FMA, ISD::FDIV, ISD::FSQRT,
ISD::FCEIL, ISD::FTRUNC, ISD::FFLOOR, ISD::FROUND,
Expand All @@ -951,30 +951,31 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::STRICT_FMA};

// TODO: support more vp ops.
static const unsigned ZvfhminPromoteVPOps[] = {ISD::VP_FADD,
ISD::VP_FSUB,
ISD::VP_FMUL,
ISD::VP_FDIV,
ISD::VP_FMA,
ISD::VP_REDUCE_FADD,
ISD::VP_REDUCE_SEQ_FADD,
ISD::VP_REDUCE_FMIN,
ISD::VP_REDUCE_FMAX,
ISD::VP_SQRT,
ISD::VP_FMINNUM,
ISD::VP_FMAXNUM,
ISD::VP_FCEIL,
ISD::VP_FFLOOR,
ISD::VP_FROUND,
ISD::VP_FROUNDEVEN,
ISD::VP_FROUNDTOZERO,
ISD::VP_FRINT,
ISD::VP_FNEARBYINT,
ISD::VP_SETCC,
ISD::VP_FMINIMUM,
ISD::VP_FMAXIMUM,
ISD::VP_REDUCE_FMINIMUM,
ISD::VP_REDUCE_FMAXIMUM};
static const unsigned ZvfhminZvfbfminPromoteVPOps[] = {
ISD::VP_FADD,
ISD::VP_FSUB,
ISD::VP_FMUL,
ISD::VP_FDIV,
ISD::VP_FMA,
ISD::VP_REDUCE_FADD,
ISD::VP_REDUCE_SEQ_FADD,
ISD::VP_REDUCE_FMIN,
ISD::VP_REDUCE_FMAX,
ISD::VP_SQRT,
ISD::VP_FMINNUM,
ISD::VP_FMAXNUM,
ISD::VP_FCEIL,
ISD::VP_FFLOOR,
ISD::VP_FROUND,
ISD::VP_FROUNDEVEN,
ISD::VP_FROUNDTOZERO,
ISD::VP_FRINT,
ISD::VP_FNEARBYINT,
ISD::VP_SETCC,
ISD::VP_FMINIMUM,
ISD::VP_FMAXIMUM,
ISD::VP_REDUCE_FMINIMUM,
ISD::VP_REDUCE_FMAXIMUM};

// Sets common operation actions on RVV floating-point vector types.
const auto SetCommonVFPActions = [&](MVT VT) {
Expand Down Expand Up @@ -1097,20 +1098,20 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FABS, VT, Expand);
setOperationAction(ISD::FCOPYSIGN, VT, Expand);

// Custom split nxv32f16 since nxv32f32 if not legal.
// Custom split nxv32f16 since nxv32f32 is not legal.
if (VT == MVT::nxv32f16) {
setOperationAction(ZvfhminPromoteOps, VT, Custom);
setOperationAction(ZvfhminPromoteVPOps, VT, Custom);
setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom);
setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom);
continue;
}
// Add more promote ops.
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
setOperationPromotedToType(ZvfhminPromoteOps, VT, F32VecVT);
setOperationPromotedToType(ZvfhminPromoteVPOps, VT, F32VecVT);
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
}
}

// TODO: Could we merge some code with zvfhmin?
// TODO: merge with zvfhmin
if (Subtarget.hasVInstructionsBF16Minimal()) {
for (MVT VT : BF16VecVTs) {
if (!isTypeLegal(VT))
Expand Down Expand Up @@ -1139,7 +1140,16 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FABS, VT, Expand);
setOperationAction(ISD::FCOPYSIGN, VT, Expand);

// TODO: Promote to fp32.
// Custom split nxv32f16 since nxv32f32 is not legal.
if (VT == MVT::nxv32bf16) {
setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom);
setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom);
continue;
}
// Add more promote ops.
MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount());
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
}
}

Expand Down Expand Up @@ -1375,8 +1385,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// TODO: could split the f16 vector into two vectors and do promotion.
if (!isTypeLegal(F32VecVT))
continue;
setOperationPromotedToType(ZvfhminPromoteOps, VT, F32VecVT);
setOperationPromotedToType(ZvfhminPromoteVPOps, VT, F32VecVT);
setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT);
setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT);
continue;
}

Expand Down Expand Up @@ -6333,6 +6343,17 @@ static bool hasMaskOp(unsigned Opcode) {
return false;
}

static bool isPromotedOpNeedingSplit(SDValue Op,
const RISCVSubtarget &Subtarget) {
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
return true;
if (Op.getValueType() == MVT::nxv32bf16)
return true;
return false;
}

static SDValue SplitVectorOp(SDValue Op, SelectionDAG &DAG) {
auto [LoVT, HiVT] = DAG.GetSplitDestVTs(Op.getValueType());
SDLoc DL(Op);
Expand Down Expand Up @@ -6670,9 +6691,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
}
case ISD::FMAXIMUM:
case ISD::FMINIMUM:
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVectorOp(Op, DAG);
return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
case ISD::FP_EXTEND:
Expand All @@ -6688,8 +6707,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16())) ||
Op.getValueType().getScalarType() == MVT::bf16)) {
if (Op.getValueType() == MVT::nxv32f16 ||
Op.getValueType() == MVT::nxv32bf16)
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVectorOp(Op, DAG);
// int -> f32
SDLoc DL(Op);
Expand All @@ -6709,8 +6727,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16())) ||
Op1.getValueType().getScalarType() == MVT::bf16)) {
if (Op1.getValueType() == MVT::nxv32f16 ||
Op1.getValueType() == MVT::nxv32bf16)
if (isPromotedOpNeedingSplit(Op1, Subtarget))
return SplitVectorOp(Op, DAG);
// [b]f16 -> f32
SDLoc DL(Op);
Expand Down Expand Up @@ -6942,9 +6959,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FRINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVectorOp(Op, DAG);
return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
case ISD::LRINT:
Expand Down Expand Up @@ -7002,9 +7017,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::VP_REDUCE_FMAX:
case ISD::VP_REDUCE_FMINIMUM:
case ISD::VP_REDUCE_FMAXIMUM:
if (Op.getOperand(1).getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op.getOperand(1), Subtarget))
return SplitVectorReductionOp(Op, DAG);
return lowerVPREDUCE(Op, DAG);
case ISD::VP_REDUCE_AND:
Expand Down Expand Up @@ -7251,9 +7264,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
return DAG.getSetCC(DL, VT, RHS, LHS, CCVal);
}

if (Op.getOperand(0).getSimpleValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op.getOperand(0), Subtarget))
return SplitVectorOp(Op, DAG);

return lowerFixedLengthVectorSetccToRVV(Op, DAG);
Expand Down Expand Up @@ -7295,9 +7306,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FMA:
case ISD::FMINNUM:
case ISD::FMAXNUM:
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVectorOp(Op, DAG);
[[fallthrough]];
case ISD::AVGFLOORS:
Expand Down Expand Up @@ -7345,9 +7354,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FCOPYSIGN:
if (Op.getValueType() == MVT::f16 || Op.getValueType() == MVT::bf16)
return lowerFCOPYSIGN(Op, DAG, Subtarget);
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVectorOp(Op, DAG);
return lowerFixedLengthVectorFCOPYSIGNToRVV(Op, DAG);
case ISD::STRICT_FADD:
Expand All @@ -7356,9 +7363,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::STRICT_FDIV:
case ISD::STRICT_FSQRT:
case ISD::STRICT_FMA:
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitStrictFPVectorOp(Op, DAG);
return lowerToScalableOp(Op, DAG);
case ISD::STRICT_FSETCC:
Expand Down Expand Up @@ -7415,9 +7420,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::VP_FMINNUM:
case ISD::VP_FMAXNUM:
case ISD::VP_FCOPYSIGN:
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVPOp(Op, DAG);
[[fallthrough]];
case ISD::VP_SRA:
Expand All @@ -7443,8 +7446,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16())) ||
Op.getValueType().getScalarType() == MVT::bf16)) {
if (Op.getValueType() == MVT::nxv32f16 ||
Op.getValueType() == MVT::nxv32bf16)
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVectorOp(Op, DAG);
// int -> f32
SDLoc DL(Op);
Expand All @@ -7464,8 +7466,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16())) ||
Op1.getValueType().getScalarType() == MVT::bf16)) {
if (Op1.getValueType() == MVT::nxv32f16 ||
Op1.getValueType() == MVT::nxv32bf16)
if (isPromotedOpNeedingSplit(Op1, Subtarget))
return SplitVectorOp(Op, DAG);
// [b]f16 -> f32
SDLoc DL(Op);
Expand All @@ -7478,9 +7479,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
}
return lowerVPFPIntConvOp(Op, DAG);
case ISD::VP_SETCC:
if (Op.getOperand(0).getSimpleValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op.getOperand(0), Subtarget))
return SplitVPOp(Op, DAG);
if (Op.getOperand(0).getSimpleValueType().getVectorElementType() == MVT::i1)
return lowerVPSetCCMaskOp(Op, DAG);
Expand Down Expand Up @@ -7515,16 +7514,12 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::VP_FROUND:
case ISD::VP_FROUNDEVEN:
case ISD::VP_FROUNDTOZERO:
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVPOp(Op, DAG);
return lowerVectorFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
case ISD::VP_FMAXIMUM:
case ISD::VP_FMINIMUM:
if (Op.getValueType() == MVT::nxv32f16 &&
(Subtarget.hasVInstructionsF16Minimal() &&
!Subtarget.hasVInstructionsF16()))
if (isPromotedOpNeedingSplit(Op, Subtarget))
return SplitVPOp(Op, DAG);
return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
case ISD::EXPERIMENTAL_VP_SPLICE:
Expand Down
Loading

0 comments on commit 0c0d915

Please sign in to comment.