diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 3b7e24414c490c..0f76ad6c5e9288 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1063,6 +1063,45 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, } }; + // Sets common actions for f16 and bf16 for when there's only + // zvfhmin/zvfbfmin and we need to promote to f32 for most operations. + const auto SetCommonPromoteToF32Actions = [&](MVT VT) { + setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom); + setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT, + Custom); + setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom); + setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT, + Custom); + setOperationAction(ISD::SELECT_CC, VT, Expand); + setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, ISD::VP_SINT_TO_FP, + ISD::VP_UINT_TO_FP}, + VT, Custom); + setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, + ISD::EXTRACT_SUBVECTOR, ISD::VECTOR_INTERLEAVE, + ISD::VECTOR_DEINTERLEAVE}, + VT, Custom); + MVT EltVT = VT.getVectorElementType(); + if (isTypeLegal(EltVT)) + setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); + else + setOperationAction(ISD::SPLAT_VECTOR, EltVT, Custom); + setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom); + + setOperationAction(ISD::FNEG, VT, Expand); + setOperationAction(ISD::FABS, VT, Expand); + setOperationAction(ISD::FCOPYSIGN, VT, Expand); + + // Custom split nxv32[b]f16 since nxv32[b]f32 is not legal. + if (getLMUL(VT) == RISCVII::VLMUL::LMUL_8) { + setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom); + setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom); + } else { + MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); + setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT); + setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT); + } + }; + if (Subtarget.hasVInstructionsF16()) { for (MVT VT : F16VecVTs) { if (!isTypeLegal(VT)) @@ -1073,83 +1112,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM, for (MVT VT : F16VecVTs) { if (!isTypeLegal(VT)) continue; - setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom); - setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT, - Custom); - setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom); - setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT, - Custom); - setOperationAction(ISD::SELECT_CC, VT, Expand); - setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, - ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP}, - VT, Custom); - setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, - ISD::EXTRACT_SUBVECTOR, ISD::VECTOR_INTERLEAVE, - ISD::VECTOR_DEINTERLEAVE}, - VT, Custom); - if (Subtarget.hasStdExtZfhmin()) - setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); - else - setOperationAction(ISD::SPLAT_VECTOR, MVT::f16, Custom); - // load/store - setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom); - - setOperationAction(ISD::FNEG, VT, Expand); - setOperationAction(ISD::FABS, VT, Expand); - setOperationAction(ISD::FCOPYSIGN, VT, Expand); - - // Custom split nxv32f16 since nxv32f32 is not legal. - if (VT == MVT::nxv32f16) { - setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom); - setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom); - continue; - } - // Add more promote ops. - MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); - setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT); - setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT); + SetCommonPromoteToF32Actions(VT); } } - // TODO: merge with zvfhmin if (Subtarget.hasVInstructionsBF16Minimal()) { for (MVT VT : BF16VecVTs) { if (!isTypeLegal(VT)) continue; - setOperationAction({ISD::FP_ROUND, ISD::FP_EXTEND}, VT, Custom); - setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT, - Custom); - setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom); - setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT, - Custom); - setOperationAction(ISD::SELECT_CC, VT, Expand); - setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP, - ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP}, - VT, Custom); - setOperationAction({ISD::CONCAT_VECTORS, ISD::INSERT_SUBVECTOR, - ISD::EXTRACT_SUBVECTOR, ISD::VECTOR_INTERLEAVE, - ISD::VECTOR_DEINTERLEAVE}, - VT, Custom); - if (Subtarget.hasStdExtZfbfmin()) - setOperationAction(ISD::SPLAT_VECTOR, VT, Custom); - else - setOperationAction(ISD::SPLAT_VECTOR, MVT::bf16, Custom); - setOperationAction({ISD::LOAD, ISD::STORE}, VT, Custom); - - setOperationAction(ISD::FNEG, VT, Expand); - setOperationAction(ISD::FABS, VT, Expand); - setOperationAction(ISD::FCOPYSIGN, VT, Expand); - - // Custom split nxv32f16 since nxv32f32 is not legal. - if (VT == MVT::nxv32bf16) { - setOperationAction(ZvfhminZvfbfminPromoteOps, VT, Custom); - setOperationAction(ZvfhminZvfbfminPromoteVPOps, VT, Custom); - continue; - } - // Add more promote ops. - MVT F32VecVT = MVT::getVectorVT(MVT::f32, VT.getVectorElementCount()); - setOperationPromotedToType(ZvfhminZvfbfminPromoteOps, VT, F32VecVT); - setOperationPromotedToType(ZvfhminZvfbfminPromoteVPOps, VT, F32VecVT); + SetCommonPromoteToF32Actions(VT); } }