Skip to content

Commit

Permalink
[NVPTX] extend type support for nvvm.{min,max,mulhi,sad} (llvm#78385)
Browse files Browse the repository at this point in the history
Ensure intrinsics and auto-upgrades support i16, i32, and i64 for for
`nvvm.{min,max,mulhi,sad}`

- `nvvm.min` and `nvvm.max`: These are auto-upgraded to `select`
instructions but it is still nice to support the 16 bit variants just in
case any generators of IR are still trying to use these intrinsics.
- `nvvm.sad` added both the 16 and 64 bit variants, also marked this
instruction as speculateble. These directly correspond to the PTX
`sad.{u16,s16,u64,s64}` instructions.
- `nvvm.mulhi` added the 16 bit variants. These directly correspond to
the PTX `mul.hi.{s,u}16` instructions.
  • Loading branch information
AlexMaclean authored and ampandey-1995 committed Jan 19, 2024
1 parent c772754 commit b44fed8
Show file tree
Hide file tree
Showing 6 changed files with 312 additions and 42 deletions.
26 changes: 24 additions & 2 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,13 @@ let TargetPrefix = "nvvm" in {
// Multiplication
//

def int_nvvm_mulhi_s : ClangBuiltin<"__nvvm_mulhi_s">,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
[IntrNoMem, IntrSpeculatable, Commutative]>;
def int_nvvm_mulhi_us : ClangBuiltin<"__nvvm_mulhi_us">,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
[IntrNoMem, IntrSpeculatable, Commutative]>;

def int_nvvm_mulhi_i : ClangBuiltin<"__nvvm_mulhi_i">,
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
[IntrNoMem, IntrSpeculatable, Commutative]>;
Expand Down Expand Up @@ -730,12 +737,27 @@ let TargetPrefix = "nvvm" in {
// Sad
//

def int_nvvm_sad_s : ClangBuiltin<"__nvvm_sad_s">,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
[IntrNoMem, Commutative, IntrSpeculatable]>;
def int_nvvm_sad_us : ClangBuiltin<"__nvvm_sad_us">,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
[IntrNoMem, Commutative, IntrSpeculatable]>;

def int_nvvm_sad_i : ClangBuiltin<"__nvvm_sad_i">,
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
[IntrNoMem, Commutative]>;
[IntrNoMem, Commutative, IntrSpeculatable]>;
def int_nvvm_sad_ui : ClangBuiltin<"__nvvm_sad_ui">,
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
[IntrNoMem, Commutative]>;
[IntrNoMem, Commutative, IntrSpeculatable]>;

def int_nvvm_sad_ll : ClangBuiltin<"__nvvm_sad_ll">,
DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i64_ty, llvm_i64_ty],
[IntrNoMem, Commutative, IntrSpeculatable]>;
def int_nvvm_sad_ull : ClangBuiltin<"__nvvm_sad_ull">,
DefaultAttrsIntrinsic<[llvm_i64_ty], [llvm_i64_ty, llvm_i64_ty, llvm_i64_ty],
[IntrNoMem, Commutative, IntrSpeculatable]>;


//
// Floor Ceil
Expand Down
17 changes: 10 additions & 7 deletions llvm/lib/IR/AutoUpgrade.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,8 @@ static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) {
Expand = true;
else if (Name.consume_front("max.") || Name.consume_front("min."))
// nvvm.{min,max}.{i,ii,ui,ull}
Expand = Name == "i" || Name == "ll" || Name == "ui" || Name == "ull";
Expand = Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
Name == "ui" || Name == "ull";
else if (Name.consume_front("atomic.load.add."))
// nvvm.atomic.load.add.{f32.p,f64.p}
Expand = Name.starts_with("f32.p") || Name.starts_with("f64.p");
Expand Down Expand Up @@ -4132,19 +4133,21 @@ void llvm::UpgradeIntrinsicCall(CallBase *CI, Function *NewFn) {
Value *Val = CI->getArgOperand(1);
Rep = Builder.CreateAtomicRMW(AtomicRMWInst::FAdd, Ptr, Val, MaybeAlign(),
AtomicOrdering::SequentiallyConsistent);
} else if (IsNVVM && (Name == "max.i" || Name == "max.ll" ||
Name == "max.ui" || Name == "max.ull")) {
} else if (IsNVVM && Name.consume_front("max.") &&
(Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
Name == "ui" || Name == "ull")) {
Value *Arg0 = CI->getArgOperand(0);
Value *Arg1 = CI->getArgOperand(1);
Value *Cmp = Name.ends_with(".ui") || Name.ends_with(".ull")
Value *Cmp = Name.starts_with("u")
? Builder.CreateICmpUGE(Arg0, Arg1, "max.cond")
: Builder.CreateICmpSGE(Arg0, Arg1, "max.cond");
Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "max");
} else if (IsNVVM && (Name == "min.i" || Name == "min.ll" ||
Name == "min.ui" || Name == "min.ull")) {
} else if (IsNVVM && Name.consume_front("min.") &&
(Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
Name == "ui" || Name == "ull")) {
Value *Arg0 = CI->getArgOperand(0);
Value *Arg1 = CI->getArgOperand(1);
Value *Cmp = Name.ends_with(".ui") || Name.ends_with(".ull")
Value *Cmp = Name.starts_with("u")
? Builder.CreateICmpULE(Arg0, Arg1, "min.cond")
: Builder.CreateICmpSLE(Arg0, Arg1, "min.cond");
Rep = Builder.CreateSelect(Cmp, Arg0, Arg1, "min");
Expand Down
13 changes: 12 additions & 1 deletion llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -770,11 +770,14 @@ defm INT_NVVM_FMAN : MIN_MAX<"max">;
// Multiplication
//

def INT_NVVM_MULHI_S : F_MATH_2<"mul.hi.s16 \t$dst, $src0, $src1;", Int16Regs,
Int16Regs, Int16Regs, int_nvvm_mulhi_s>;
def INT_NVVM_MULHI_US : F_MATH_2<"mul.hi.u16 \t$dst, $src0, $src1;", Int16Regs,
Int16Regs, Int16Regs, int_nvvm_mulhi_us>;
def INT_NVVM_MULHI_I : F_MATH_2<"mul.hi.s32 \t$dst, $src0, $src1;", Int32Regs,
Int32Regs, Int32Regs, int_nvvm_mulhi_i>;
def INT_NVVM_MULHI_UI : F_MATH_2<"mul.hi.u32 \t$dst, $src0, $src1;", Int32Regs,
Int32Regs, Int32Regs, int_nvvm_mulhi_ui>;

def INT_NVVM_MULHI_LL : F_MATH_2<"mul.hi.s64 \t$dst, $src0, $src1;", Int64Regs,
Int64Regs, Int64Regs, int_nvvm_mulhi_ll>;
def INT_NVVM_MULHI_ULL : F_MATH_2<"mul.hi.u64 \t$dst, $src0, $src1;", Int64Regs,
Expand Down Expand Up @@ -851,10 +854,18 @@ def INT_NVVM_DIV_RP_D : F_MATH_2<"div.rp.f64 \t$dst, $src0, $src1;",
// Sad
//

def INT_NVVM_SAD_S : F_MATH_3<"sad.s16 \t$dst, $src0, $src1, $src2;",
Int16Regs, Int16Regs, Int16Regs, Int16Regs, int_nvvm_sad_s>;
def INT_NVVM_SAD_US : F_MATH_3<"sad.u16 \t$dst, $src0, $src1, $src2;",
Int16Regs, Int16Regs, Int16Regs, Int16Regs, int_nvvm_sad_us>;
def INT_NVVM_SAD_I : F_MATH_3<"sad.s32 \t$dst, $src0, $src1, $src2;",
Int32Regs, Int32Regs, Int32Regs, Int32Regs, int_nvvm_sad_i>;
def INT_NVVM_SAD_UI : F_MATH_3<"sad.u32 \t$dst, $src0, $src1, $src2;",
Int32Regs, Int32Regs, Int32Regs, Int32Regs, int_nvvm_sad_ui>;
def INT_NVVM_SAD_LL : F_MATH_3<"sad.s64 \t$dst, $src0, $src1, $src2;",
Int64Regs, Int64Regs, Int64Regs, Int64Regs, int_nvvm_sad_ll>;
def INT_NVVM_SAD_ULL : F_MATH_3<"sad.u64 \t$dst, $src0, $src1, $src2;",
Int64Regs, Int64Regs, Int64Regs, Int64Regs, int_nvvm_sad_ull>;

//
// Floor Ceil
Expand Down
84 changes: 52 additions & 32 deletions llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@ declare float @llvm.nvvm.h2f(i16)
declare i32 @llvm.nvvm.abs.i(i32)
declare i64 @llvm.nvvm.abs.ll(i64)

declare i16 @llvm.nvvm.max.s(i16, i16)
declare i32 @llvm.nvvm.max.i(i32, i32)
declare i64 @llvm.nvvm.max.ll(i64, i64)
declare i16 @llvm.nvvm.max.us(i16, i16)
declare i32 @llvm.nvvm.max.ui(i32, i32)
declare i64 @llvm.nvvm.max.ull(i64, i64)
declare i16 @llvm.nvvm.min.s(i16, i16)
declare i32 @llvm.nvvm.min.i(i32, i32)
declare i64 @llvm.nvvm.min.ll(i64, i64)
declare i16 @llvm.nvvm.min.us(i16, i16)
declare i32 @llvm.nvvm.min.ui(i32, i32)
declare i64 @llvm.nvvm.min.ull(i64, i64)

Expand Down Expand Up @@ -65,38 +69,54 @@ define void @abs(i32 %a, i64 %b) {
}

; CHECK-LABEL: @min_max
define void @min_max(i32 %a1, i32 %a2, i64 %b1, i64 %b2) {
; CHECK: [[maxi:%[a-zA-Z0-9.]+]] = icmp sge i32 %a1, %a2
; CHECK: select i1 [[maxi]], i32 %a1, i32 %a2
%r1 = call i32 @llvm.nvvm.max.i(i32 %a1, i32 %a2)

; CHECK: [[maxll:%[a-zA-Z0-9.]+]] = icmp sge i64 %b1, %b2
; CHECK: select i1 [[maxll]], i64 %b1, i64 %b2
%r2 = call i64 @llvm.nvvm.max.ll(i64 %b1, i64 %b2)

; CHECK: [[maxui:%[a-zA-Z0-9.]+]] = icmp uge i32 %a1, %a2
; CHECK: select i1 [[maxui]], i32 %a1, i32 %a2
%r3 = call i32 @llvm.nvvm.max.ui(i32 %a1, i32 %a2)

; CHECK: [[maxull:%[a-zA-Z0-9.]+]] = icmp uge i64 %b1, %b2
; CHECK: select i1 [[maxull]], i64 %b1, i64 %b2
%r4 = call i64 @llvm.nvvm.max.ull(i64 %b1, i64 %b2)

; CHECK: [[mini:%[a-zA-Z0-9.]+]] = icmp sle i32 %a1, %a2
; CHECK: select i1 [[mini]], i32 %a1, i32 %a2
%r5 = call i32 @llvm.nvvm.min.i(i32 %a1, i32 %a2)

; CHECK: [[minll:%[a-zA-Z0-9.]+]] = icmp sle i64 %b1, %b2
; CHECK: select i1 [[minll]], i64 %b1, i64 %b2
%r6 = call i64 @llvm.nvvm.min.ll(i64 %b1, i64 %b2)

; CHECK: [[minui:%[a-zA-Z0-9.]+]] = icmp ule i32 %a1, %a2
; CHECK: select i1 [[minui]], i32 %a1, i32 %a2
%r7 = call i32 @llvm.nvvm.min.ui(i32 %a1, i32 %a2)

; CHECK: [[minull:%[a-zA-Z0-9.]+]] = icmp ule i64 %b1, %b2
; CHECK: select i1 [[minull]], i64 %b1, i64 %b2
%r8 = call i64 @llvm.nvvm.min.ull(i64 %b1, i64 %b2)
define void @min_max(i16 %a1, i16 %a2, i32 %b1, i32 %b2, i64 %c1, i64 %c2) {
; CHECK: [[maxs:%[a-zA-Z0-9.]+]] = icmp sge i16 %a1, %a2
; CHECK: select i1 [[maxs]], i16 %a1, i16 %a2
%r1 = call i16 @llvm.nvvm.max.s(i16 %a1, i16 %a2)

; CHECK: [[maxi:%[a-zA-Z0-9.]+]] = icmp sge i32 %b1, %b2
; CHECK: select i1 [[maxi]], i32 %b1, i32 %b2
%r2 = call i32 @llvm.nvvm.max.i(i32 %b1, i32 %b2)

; CHECK: [[maxll:%[a-zA-Z0-9.]+]] = icmp sge i64 %c1, %c2
; CHECK: select i1 [[maxll]], i64 %c1, i64 %c2
%r3 = call i64 @llvm.nvvm.max.ll(i64 %c1, i64 %c2)

; CHECK: [[maxus:%[a-zA-Z0-9.]+]] = icmp uge i16 %a1, %a2
; CHECK: select i1 [[maxus]], i16 %a1, i16 %a2
%r4 = call i16 @llvm.nvvm.max.us(i16 %a1, i16 %a2)

; CHECK: [[maxui:%[a-zA-Z0-9.]+]] = icmp uge i32 %b1, %b2
; CHECK: select i1 [[maxui]], i32 %b1, i32 %b2
%r5 = call i32 @llvm.nvvm.max.ui(i32 %b1, i32 %b2)

; CHECK: [[maxull:%[a-zA-Z0-9.]+]] = icmp uge i64 %c1, %c2
; CHECK: select i1 [[maxull]], i64 %c1, i64 %c2
%r6 = call i64 @llvm.nvvm.max.ull(i64 %c1, i64 %c2)

; CHECK: [[mins:%[a-zA-Z0-9.]+]] = icmp sle i16 %a1, %a2
; CHECK: select i1 [[mins]], i16 %a1, i16 %a2
%r7 = call i16 @llvm.nvvm.min.s(i16 %a1, i16 %a2)

; CHECK: [[mini:%[a-zA-Z0-9.]+]] = icmp sle i32 %b1, %b2
; CHECK: select i1 [[mini]], i32 %b1, i32 %b2
%r8 = call i32 @llvm.nvvm.min.i(i32 %b1, i32 %b2)

; CHECK: [[minll:%[a-zA-Z0-9.]+]] = icmp sle i64 %c1, %c2
; CHECK: select i1 [[minll]], i64 %c1, i64 %c2
%r9 = call i64 @llvm.nvvm.min.ll(i64 %c1, i64 %c2)

; CHECK: [[minus:%[a-zA-Z0-9.]+]] = icmp ule i16 %a1, %a2
; CHECK: select i1 [[minus]], i16 %a1, i16 %a2
%r10 = call i16 @llvm.nvvm.min.us(i16 %a1, i16 %a2)

; CHECK: [[minui:%[a-zA-Z0-9.]+]] = icmp ule i32 %b1, %b2
; CHECK: select i1 [[minui]], i32 %b1, i32 %b2
%r11 = call i32 @llvm.nvvm.min.ui(i32 %b1, i32 %b2)

; CHECK: [[minull:%[a-zA-Z0-9.]+]] = icmp ule i64 %c1, %c2
; CHECK: select i1 [[minull]], i64 %c1, i64 %c2
%r12 = call i64 @llvm.nvvm.min.ull(i64 %c1, i64 %c2)

ret void
}
104 changes: 104 additions & 0 deletions llvm/test/CodeGen/NVPTX/mulhi-intrins.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_50 | FileCheck %s
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_50 | %ptxas-verify %}

define i16 @test_mulhi_i16(i16 %x, i16 %y) {
; CHECK-LABEL: test_mulhi_i16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<4>;
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u16 %rs1, [test_mulhi_i16_param_0];
; CHECK-NEXT: ld.param.u16 %rs2, [test_mulhi_i16_param_1];
; CHECK-NEXT: mul.hi.s16 %rs3, %rs1, %rs2;
; CHECK-NEXT: cvt.u32.u16 %r1, %rs3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
; CHECK-NEXT: ret;
%1 = call i16 @llvm.nvvm.mulhi.s(i16 %x, i16 %y)
ret i16 %1
}

define i16 @test_mulhi_u16(i16 %x, i16 %y) {
; CHECK-LABEL: test_mulhi_u16(
; CHECK: {
; CHECK-NEXT: .reg .b16 %rs<4>;
; CHECK-NEXT: .reg .b32 %r<2>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u16 %rs1, [test_mulhi_u16_param_0];
; CHECK-NEXT: ld.param.u16 %rs2, [test_mulhi_u16_param_1];
; CHECK-NEXT: mul.hi.u16 %rs3, %rs1, %rs2;
; CHECK-NEXT: cvt.u32.u16 %r1, %rs3;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r1;
; CHECK-NEXT: ret;
%1 = call i16 @llvm.nvvm.mulhi.us(i16 %x, i16 %y)
ret i16 %1
}

define i32 @test_mulhi_i32(i32 %x, i32 %y) {
; CHECK-LABEL: test_mulhi_i32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_mulhi_i32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_mulhi_i32_param_1];
; CHECK-NEXT: mul.hi.s32 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
; CHECK-NEXT: ret;
%1 = call i32 @llvm.nvvm.mulhi.i(i32 %x, i32 %y)
ret i32 %1
}

define i32 @test_mulhi_u32(i32 %x, i32 %y) {
; CHECK-LABEL: test_mulhi_u32(
; CHECK: {
; CHECK-NEXT: .reg .b32 %r<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u32 %r1, [test_mulhi_u32_param_0];
; CHECK-NEXT: ld.param.u32 %r2, [test_mulhi_u32_param_1];
; CHECK-NEXT: mul.hi.u32 %r3, %r1, %r2;
; CHECK-NEXT: st.param.b32 [func_retval0+0], %r3;
; CHECK-NEXT: ret;
%1 = call i32 @llvm.nvvm.mulhi.ui(i32 %x, i32 %y)
ret i32 %1
}

define i64 @test_mulhi_i64(i64 %x, i64 %y) {
; CHECK-LABEL: test_mulhi_i64(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [test_mulhi_i64_param_0];
; CHECK-NEXT: ld.param.u64 %rd2, [test_mulhi_i64_param_1];
; CHECK-NEXT: mul.hi.s64 %rd3, %rd1, %rd2;
; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd3;
; CHECK-NEXT: ret;
%1 = call i64 @llvm.nvvm.mulhi.ll(i64 %x, i64 %y)
ret i64 %1
}

define i64 @test_mulhi_u64(i64 %x, i64 %y) {
; CHECK-LABEL: test_mulhi_u64(
; CHECK: {
; CHECK-NEXT: .reg .b64 %rd<4>;
; CHECK-EMPTY:
; CHECK-NEXT: // %bb.0:
; CHECK-NEXT: ld.param.u64 %rd1, [test_mulhi_u64_param_0];
; CHECK-NEXT: ld.param.u64 %rd2, [test_mulhi_u64_param_1];
; CHECK-NEXT: mul.hi.u64 %rd3, %rd1, %rd2;
; CHECK-NEXT: st.param.b64 [func_retval0+0], %rd3;
; CHECK-NEXT: ret;
%1 = call i64 @llvm.nvvm.mulhi.ull(i64 %x, i64 %y)
ret i64 %1
}

declare i16 @llvm.nvvm.mulhi.s(i16, i16)
declare i16 @llvm.nvvm.mulhi.us(i16, i16)
declare i32 @llvm.nvvm.mulhi.i(i32, i32)
declare i32 @llvm.nvvm.mulhi.ui(i32, i32)
declare i64 @llvm.nvvm.mulhi.ll(i64, i64)
declare i64 @llvm.nvvm.mulhi.ull(i64, i64)
Loading

0 comments on commit b44fed8

Please sign in to comment.