Skip to content

Commit

Permalink
[X86] Adding lowerings for vector ISD::LRINT and ISD::LLRINT (#90065)
Browse files Browse the repository at this point in the history
- [V]CVTP[D,S]2DQ supports `f64/f32` -> `i32` conversions that can be
mapped to `llvm.lrint.vNi32.vNf64/32` since SSE2. AVX and AVX512 added
256-bit and 512-bit support;
- VCVTP[D,S]2QQ supports `f64/f32` -> `i64` conversions that can be
mapped to `llvm.l[l]rint.vNi64.vNf64/32` since AVX512DQ. All 128-bit,
256-bit (require AVX512VL) and 512-bit are supported.
  • Loading branch information
phoebewang committed May 3, 2024
1 parent 1949856 commit fd3e7e3
Show file tree
Hide file tree
Showing 5 changed files with 911 additions and 695 deletions.
46 changes: 46 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1092,6 +1092,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::FABS, MVT::v2f64, Custom);
setOperationAction(ISD::FCOPYSIGN, MVT::v2f64, Custom);

setOperationAction(ISD::LRINT, MVT::v4f32, Custom);

for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
setOperationAction(ISD::SMAX, VT, VT == MVT::v8i16 ? Legal : Custom);
setOperationAction(ISD::SMIN, VT, VT == MVT::v8i16 ? Legal : Custom);
Expand Down Expand Up @@ -1431,6 +1433,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::FMINIMUM, VT, Custom);
}

setOperationAction(ISD::LRINT, MVT::v8f32, Custom);
setOperationAction(ISD::LRINT, MVT::v4f64, Custom);

// (fp_to_int:v8i16 (v8f32 ..)) requires the result type to be promoted
// even though v8i16 is a legal type.
setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i16, MVT::v8i32);
Expand Down Expand Up @@ -1731,6 +1736,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1 })
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
}
if (Subtarget.hasDQI() && Subtarget.hasVLX()) {
for (MVT VT : {MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {
setOperationAction(ISD::LRINT, VT, Legal);
setOperationAction(ISD::LLRINT, VT, Legal);
}
}

// This block controls legalization for 512-bit operations with 8/16/32/64 bit
// elements. 512-bits can be disabled based on prefer-vector-width and
Expand Down Expand Up @@ -1765,6 +1776,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(ISD::STRICT_FMA, VT, Legal);
setOperationAction(ISD::FCOPYSIGN, VT, Custom);
}
setOperationAction(ISD::LRINT, MVT::v16f32,
Subtarget.hasDQI() ? Legal : Custom);
setOperationAction(ISD::LRINT, MVT::v8f64,
Subtarget.hasDQI() ? Legal : Custom);
if (Subtarget.hasDQI())
setOperationAction(ISD::LLRINT, MVT::v8f64, Legal);

for (MVT VT : { MVT::v16i1, MVT::v16i8 }) {
setOperationPromotedToType(ISD::FP_TO_SINT , VT, MVT::v16i32);
Expand Down Expand Up @@ -2488,6 +2505,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
ISD::FMAXNUM,
ISD::SUB,
ISD::LOAD,
ISD::LRINT,
ISD::LLRINT,
ISD::MLOAD,
ISD::STORE,
ISD::MSTORE,
Expand Down Expand Up @@ -21161,8 +21180,12 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
SDValue X86TargetLowering::LowerLRINT_LLRINT(SDValue Op,
SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(0);
EVT DstVT = Op.getSimpleValueType();
MVT SrcVT = Src.getSimpleValueType();

if (SrcVT.isVector())
return DstVT.getScalarType() == MVT::i32 ? Op : SDValue();

if (SrcVT == MVT::f16)
return SDValue();

Expand Down Expand Up @@ -51542,6 +51565,22 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
SDValue Src = N->getOperand(0);
EVT SrcVT = Src.getValueType();
SDLoc DL(N);

if (!Subtarget.hasDQI() || !Subtarget.hasVLX() || VT != MVT::v2i64 ||
SrcVT != MVT::v2f32)
return SDValue();

return DAG.getNode(X86ISD::CVTP2SI, DL, VT,
DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32, Src,
DAG.getUNDEF(SrcVT)));
}

/// Attempt to pre-truncate inputs to arithmetic ops if it will simplify
/// the codegen.
/// e.g. TRUNC( BINOP( X, Y ) ) --> BINOP( TRUNC( X ), TRUNC( Y ) )
Expand Down Expand Up @@ -51888,6 +51927,11 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
return DAG.getNode(X86ISD::MMX_MOVD2W, DL, MVT::i32, BCSrc);
}

// Try to combine (trunc (vNi64 (lrint x))) to (vNi32 (lrint x)).
if (Src.getOpcode() == ISD::LRINT && VT.getScalarType() == MVT::i32 &&
Src.hasOneUse())
return DAG.getNode(ISD::LRINT, DL, VT, Src.getOperand(0));

return SDValue();
}

Expand Down Expand Up @@ -56834,6 +56878,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::UINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
return combineUIntToFP(N, DAG, Subtarget);
case ISD::LRINT:
case ISD::LLRINT: return combineLRINT_LLRINT(N, DAG, Subtarget);
case ISD::FADD:
case ISD::FSUB: return combineFaddFsub(N, DAG, Subtarget);
case X86ISD::VFCMULC:
Expand Down
35 changes: 35 additions & 0 deletions llvm/lib/Target/X86/X86InstrAVX512.td
Original file line number Diff line number Diff line change
Expand Up @@ -8811,7 +8811,18 @@ let Predicates = [HasVLX] in {
def : Pat<(X86mcvttp2ui (v2f64 (X86VBroadcastld64 addr:$src)),
v4i32x_info.ImmAllZerosV, VK2WM:$mask),
(VCVTTPD2UDQZ128rmbkz VK2WM:$mask, addr:$src)>;

def : Pat<(v4i32 (lrint VR128X:$src)), (VCVTPS2DQZ128rr VR128X:$src)>;
def : Pat<(v4i32 (lrint (loadv4f32 addr:$src))), (VCVTPS2DQZ128rm addr:$src)>;
def : Pat<(v8i32 (lrint VR256X:$src)), (VCVTPS2DQZ256rr VR256X:$src)>;
def : Pat<(v8i32 (lrint (loadv8f32 addr:$src))), (VCVTPS2DQZ256rm addr:$src)>;
def : Pat<(v4i32 (lrint VR256X:$src)), (VCVTPD2DQZ256rr VR256X:$src)>;
def : Pat<(v4i32 (lrint (loadv4f64 addr:$src))), (VCVTPD2DQZ256rm addr:$src)>;
}
def : Pat<(v16i32 (lrint VR512:$src)), (VCVTPS2DQZrr VR512:$src)>;
def : Pat<(v16i32 (lrint (loadv16f32 addr:$src))), (VCVTPS2DQZrm addr:$src)>;
def : Pat<(v8i32 (lrint VR512:$src)), (VCVTPD2DQZrr VR512:$src)>;
def : Pat<(v8i32 (lrint (loadv8f64 addr:$src))), (VCVTPD2DQZrm addr:$src)>;

let Predicates = [HasDQI, HasVLX] in {
def : Pat<(v2i64 (X86cvtp2Int (bc_v4f32 (v2f64 (X86vzload64 addr:$src))))),
Expand Down Expand Up @@ -8857,6 +8868,30 @@ let Predicates = [HasDQI, HasVLX] in {
(X86cvttp2ui (bc_v4f32 (v2f64 (X86vzload64 addr:$src)))),
v2i64x_info.ImmAllZerosV)),
(VCVTTPS2UQQZ128rmkz VK2WM:$mask, addr:$src)>;

def : Pat<(v4i64 (lrint VR128X:$src)), (VCVTPS2QQZ256rr VR128X:$src)>;
def : Pat<(v4i64 (lrint (loadv4f32 addr:$src))), (VCVTPS2QQZ256rm addr:$src)>;
def : Pat<(v4i64 (llrint VR128X:$src)), (VCVTPS2QQZ256rr VR128X:$src)>;
def : Pat<(v4i64 (llrint (loadv4f32 addr:$src))), (VCVTPS2QQZ256rm addr:$src)>;
def : Pat<(v2i64 (lrint VR128X:$src)), (VCVTPD2QQZ128rr VR128X:$src)>;
def : Pat<(v2i64 (lrint (loadv2f64 addr:$src))), (VCVTPD2QQZ128rm addr:$src)>;
def : Pat<(v4i64 (lrint VR256X:$src)), (VCVTPD2QQZ256rr VR256X:$src)>;
def : Pat<(v4i64 (lrint (loadv4f64 addr:$src))), (VCVTPD2QQZ256rm addr:$src)>;
def : Pat<(v2i64 (llrint VR128X:$src)), (VCVTPD2QQZ128rr VR128X:$src)>;
def : Pat<(v2i64 (llrint (loadv2f64 addr:$src))), (VCVTPD2QQZ128rm addr:$src)>;
def : Pat<(v4i64 (llrint VR256X:$src)), (VCVTPD2QQZ256rr VR256X:$src)>;
def : Pat<(v4i64 (llrint (loadv4f64 addr:$src))), (VCVTPD2QQZ256rm addr:$src)>;
}

let Predicates = [HasDQI] in {
def : Pat<(v8i64 (lrint VR256X:$src)), (VCVTPS2QQZrr VR256X:$src)>;
def : Pat<(v8i64 (lrint (loadv8f32 addr:$src))), (VCVTPS2QQZrm addr:$src)>;
def : Pat<(v8i64 (llrint VR256X:$src)), (VCVTPS2QQZrr VR256X:$src)>;
def : Pat<(v8i64 (llrint (loadv8f32 addr:$src))), (VCVTPS2QQZrm addr:$src)>;
def : Pat<(v8i64 (lrint VR512:$src)), (VCVTPD2QQZrr VR512:$src)>;
def : Pat<(v8i64 (lrint (loadv8f64 addr:$src))), (VCVTPD2QQZrm addr:$src)>;
def : Pat<(v8i64 (llrint VR512:$src)), (VCVTPD2QQZrr VR512:$src)>;
def : Pat<(v8i64 (llrint (loadv8f64 addr:$src))), (VCVTPD2QQZrm addr:$src)>;
}

let Predicates = [HasVLX] in {
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/Target/X86/X86InstrSSE.td
Original file line number Diff line number Diff line change
Expand Up @@ -1554,7 +1554,6 @@ def CVTPS2DQrm : PDI<0x5B, MRMSrcMem, (outs VR128:$dst), (ins f128mem:$src),
(v4i32 (X86cvtp2Int (memopv4f32 addr:$src))))]>,
Sched<[WriteCvtPS2ILd]>, SIMD_EXC;


// Convert Packed Double FP to Packed DW Integers
let Predicates = [HasAVX, NoVLX], Uses = [MXCSR], mayRaiseFPException = 1 in {
// The assembler can recognize rr 256-bit instructions by seeing a ymm
Expand Down Expand Up @@ -1586,6 +1585,20 @@ def VCVTPD2DQYrm : SDI<0xE6, MRMSrcMem, (outs VR128:$dst), (ins f256mem:$src),
VEX, VEX_L, Sched<[WriteCvtPD2IYLd]>, WIG;
}

let Predicates = [HasAVX] in {
def : Pat<(v4i32 (lrint VR128:$src)), (VCVTPS2DQrr VR128:$src)>;
def : Pat<(v4i32 (lrint (loadv4f32 addr:$src))), (VCVTPS2DQrm addr:$src)>;
def : Pat<(v8i32 (lrint VR256:$src)), (VCVTPS2DQYrr VR256:$src)>;
def : Pat<(v8i32 (lrint (loadv8f32 addr:$src))), (VCVTPS2DQYrm addr:$src)>;
def : Pat<(v4i32 (lrint VR256:$src)), (VCVTPD2DQYrr VR256:$src)>;
def : Pat<(v4i32 (lrint (loadv4f64 addr:$src))), (VCVTPD2DQYrm addr:$src)>;
}

let Predicates = [UseSSE2] in {
def : Pat<(v4i32 (lrint VR128:$src)), (CVTPS2DQrr VR128:$src)>;
def : Pat<(v4i32 (lrint (loadv4f32 addr:$src))), (CVTPS2DQrm addr:$src)>;
}

def : InstAlias<"vcvtpd2dqx\t{$src, $dst|$dst, $src}",
(VCVTPD2DQrr VR128:$dst, VR128:$src), 0, "att">;
def : InstAlias<"vcvtpd2dqy\t{$src, $dst|$dst, $src}",
Expand Down
Loading

0 comments on commit fd3e7e3

Please sign in to comment.