Skip to content

Commit

Permalink
[NVPTX] Fix bugs involving maximum/minimum and bf16
Browse files Browse the repository at this point in the history
We would crash on sufficiently old NVIDIA hardware (Volta or so) because
certain operations were incorrectly marked as legal.
  • Loading branch information
majnemer committed Aug 20, 2024
1 parent ea1f05e commit a9ce181
Show file tree
Hide file tree
Showing 2 changed files with 1,382 additions and 253 deletions.
76 changes: 47 additions & 29 deletions llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,29 +429,50 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,

// Registers the legalization action for operation `Op` on the FP16 type `VT`.
//   Op          - opcode passed through to setOperationAction.
//   VT          - value type the action is registered for.
//   Action      - action used when the operation is considered supported.
//   NoF16Action - fallback action used otherwise.
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// unconditional setOperationAction call immediately below appears to be the
// pre-change line that the IsOpSupported logic replaces — confirm against the
// commit.
auto setFP16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
LegalizeAction NoF16Action) {
setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
// Baseline: the operation is supported only if FP16 math is allowed at all.
bool IsOpSupported = STI.allowFP16Math();
switch (Op) {
// The min/max family additionally requires getSmVersion() >= 80 and
// getPTXVersion() >= 70, i.e. sm_80 and newer rather than sm_80 alone.
case ISD::FMINNUM:
case ISD::FMAXNUM:
case ISD::FMAXNUM_IEEE:
case ISD::FMINNUM_IEEE:
case ISD::FMAXIMUM:
case ISD::FMINIMUM:
// `&=` keeps the allowFP16Math() requirement in force on top of the
// version check.
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
break;
}
// All other opcodes fall through the switch with the baseline value.
setOperationAction(Op, VT, IsOpSupported ? Action : NoF16Action);
};

auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
LegalizeAction NoBF16Action) {
bool IsOpSupported = STI.hasBF16Math();
// Few instructions are available on sm_90 only
switch(Op) {
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
case ISD::SELECT:
case ISD::SELECT_CC:
case ISD::SETCC:
case ISD::FEXP2:
case ISD::FCEIL:
case ISD::FFLOOR:
case ISD::FNEARBYINT:
case ISD::FRINT:
case ISD::FROUNDEVEN:
case ISD::FTRUNC:
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
break;
switch (Op) {
// Several BF16 instructions are available on sm_90 only.
case ISD::FADD:
case ISD::FMUL:
case ISD::FSUB:
case ISD::SELECT:
case ISD::SELECT_CC:
case ISD::SETCC:
case ISD::FEXP2:
case ISD::FCEIL:
case ISD::FFLOOR:
case ISD::FNEARBYINT:
case ISD::FRINT:
case ISD::FROUNDEVEN:
case ISD::FTRUNC:
IsOpSupported = STI.getSmVersion() >= 90 && STI.getPTXVersion() >= 78;
break;
// Several BF16 instructions are available on sm_80 only.
case ISD::FMINNUM:
case ISD::FMAXNUM:
case ISD::FMAXNUM_IEEE:
case ISD::FMINNUM_IEEE:
case ISD::FMAXIMUM:
case ISD::FMINIMUM:
IsOpSupported &= STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
break;
}
setOperationAction(
Op, VT, IsOpSupported ? Action : NoBF16Action);
Expand Down Expand Up @@ -838,26 +859,23 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}

// max.f16, max.f16x2 and max.NaN are supported on sm_80+.
// Picks the action for a min/max operation: Legal when getSmVersion() >= 80
// and getPTXVersion() >= 70, otherwise the caller-supplied `NotSm80Action`.
// NOTE(review): in this rendered diff the helper appears to be on the removed
// side (the version check now lives inside the set*OperationAction helpers) —
// confirm against the commit.
auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
bool IsAtLeastSm80 = STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
return IsAtLeastSm80 ? Legal : NotSm80Action;
};
// Registers legalization actions for FMINNUM and FMAXNUM across the scalar
// and 2-element vector floating-point types.
// NOTE(review): this listing is a rendered diff without +/- markers; the
// registrations that go through GetMinMaxAction appear to be the removed lines
// and the plain `Legal` ones their replacements — confirm against the commit.
for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Promote), Promote);
// f32 and f64 are registered as Legal unconditionally.
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
// Scalar half types fall back to Promote, vector types to Expand, when the
// helper decides the operation is unsupported.
setFP16OperationAction(Op, MVT::f16, Legal, Promote);
setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
// If bf16 ended up as Promote, record f32 as the type it is promoted to.
if (getOperationAction(Op, MVT::bf16) == Promote)
AddPromotedToType(Op, MVT::bf16, MVT::f32);
}
// Registers legalization actions for FMINIMUM and FMAXIMUM.
// f32 is Legal only when getSmVersion() >= 80 and getPTXVersion() >= 70;
// every type falls back to Expand when unsupported.
// NOTE(review): this listing is a rendered diff without +/- markers; the
// registrations that go through GetMinMaxAction appear to be the removed lines
// — confirm against the commit.
bool SupportsF32MinMaxNaN =
STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70;
for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
setFP16OperationAction(Op, MVT::f16, GetMinMaxAction(Expand), Expand);
// NOTE(review): the next line routes MVT::bf16 through the FP16 helper, so
// its legality is gated on allowFP16Math() rather than hasBF16Math(); the
// setBF16OperationAction call for MVT::bf16 further down appears to be its
// replacement — confirm.
setFP16OperationAction(Op, MVT::bf16, Legal, Expand);
setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
setOperationAction(Op, MVT::f32, SupportsF32MinMaxNaN ? Legal : Expand);
setFP16OperationAction(Op, MVT::f16, Legal, Expand);
setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
setBF16OperationAction(Op, MVT::bf16, Legal, Expand);
setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
}

Expand Down
Loading

0 comments on commit a9ce181

Please sign in to comment.