Skip to content

Commit

Permalink
[RISCV] Handle FP riscv_masked_strided_load with 0 stride. (#84576)
Browse files Browse the repository at this point in the history
Previously, we tried to create an integer extending load. We need to create a
non-extending FP load instead.

Fixes #84541.
  • Loading branch information
topperc authored Mar 11, 2024
1 parent 3f6bc1a commit d8d2dea
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
11 changes: 8 additions & 3 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9080,15 +9080,20 @@ SDValue RISCVTargetLowering::LowerINTRINSIC_W_CHAIN(SDValue Op,
SDValue Result, Chain;

// TODO: We restrict this to unmasked loads currently in consideration of
// the complexity of hanlding all falses masks.
if (IsUnmasked && isNullConstant(Stride)) {
MVT ScalarVT = ContainerVT.getVectorElementType();
// the complexity of handling all falses masks.
MVT ScalarVT = ContainerVT.getVectorElementType();
if (IsUnmasked && isNullConstant(Stride) && ContainerVT.isInteger()) {
SDValue ScalarLoad =
DAG.getExtLoad(ISD::ZEXTLOAD, DL, XLenVT, Load->getChain(), Ptr,
ScalarVT, Load->getMemOperand());
Chain = ScalarLoad.getValue(1);
Result = lowerScalarSplat(SDValue(), ScalarLoad, VL, ContainerVT, DL, DAG,
Subtarget);
} else if (IsUnmasked && isNullConstant(Stride) && isTypeLegal(ScalarVT)) {
SDValue ScalarLoad = DAG.getLoad(ScalarVT, DL, Load->getChain(), Ptr,
Load->getMemOperand());
Chain = ScalarLoad.getValue(1);
Result = DAG.getSplat(ContainerVT, DL, ScalarLoad);
} else {
SDValue IntID = DAG.getTargetConstant(
IsUnmasked ? Intrinsic::riscv_vlse : Intrinsic::riscv_vlse_mask, DL,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -915,3 +915,42 @@ bb4: ; preds = %bb4, %bb2
bb16: ; preds = %bb4, %bb
ret void
}

; Regression test for a floating-point gather whose lane addresses are all
; equal. The gather below reads <8 x float> through an all-true mask, and its
; address vector is built from %vec.ind, which starts as zeroinitializer and
; has the same constant (32) added to every lane each iteration, so all eight
; lanes always point at the same element of %B.
; The expected assembly loads that element once with a scalar `flw` and feeds
; it to `vfadd.vf`; the checked loop body contains no strided or indexed
; vector load.
define void @gather_zero_stride_fp(ptr noalias nocapture %A, ptr noalias nocapture readonly %B) {
; CHECK-LABEL: gather_zero_stride_fp:
; CHECK: # %bb.0: # %entry
; CHECK-NEXT: lui a2, 1
; CHECK-NEXT: add a2, a0, a2
; CHECK-NEXT: vsetivli zero, 8, e32, m1, ta, ma
; CHECK-NEXT: .LBB15_1: # %vector.body
; CHECK-NEXT: # =>This Inner Loop Header: Depth=1
; CHECK-NEXT: flw fa5, 0(a1)
; CHECK-NEXT: vle32.v v8, (a0)
; CHECK-NEXT: vfadd.vf v8, v8, fa5
; CHECK-NEXT: vse32.v v8, (a0)
; CHECK-NEXT: addi a0, a0, 128
; CHECK-NEXT: addi a1, a1, 640
; CHECK-NEXT: bne a0, a2, .LBB15_1
; CHECK-NEXT: # %bb.2: # %for.cond.cleanup
; CHECK-NEXT: ret
entry:
br label %vector.body

vector.body: ; preds = %vector.body, %entry
%index = phi i64 [ 0, %entry ], [ %index.next, %vector.body ]
%vec.ind = phi <8 x i64> [ zeroinitializer, %entry ], [ %vec.ind.next, %vector.body ]
; Per-lane float index into %B is 5 * %vec.ind; every lane holds the same value.
%i = mul nuw nsw <8 x i64> %vec.ind, <i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5, i64 5>
%i1 = getelementptr inbounds float, ptr %B, <8 x i64> %i
; Gather with an all-true mask and an undef passthru operand.
%wide.masked.gather = call <8 x float> @llvm.masked.gather.v8f32.v32p0(<8 x ptr> %i1, i32 4, <8 x i1> <i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true, i1 true>, <8 x float> undef)
; Read-modify-write of eight floats at %A + %index: add the gathered value and store back.
%i2 = getelementptr inbounds float, ptr %A, i64 %index
%wide.load = load <8 x float>, ptr %i2, align 4
%i4 = fadd <8 x float> %wide.load, %wide.masked.gather
store <8 x float> %i4, ptr %i2, align 4
; %index and every lane of %vec.ind advance by 32 per iteration
; (32 floats = 128 bytes for %A; 32 * 5 floats = 640 bytes for %B, matching
; the two `addi` increments expected above). The loop exits when %index.next
; reaches 1024.
%index.next = add nuw i64 %index, 32
%vec.ind.next = add <8 x i64> %vec.ind, <i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32, i64 32>
%i6 = icmp eq i64 %index.next, 1024
br i1 %i6, label %for.cond.cleanup, label %vector.body

for.cond.cleanup: ; preds = %vector.body
ret void
}

0 comments on commit d8d2dea

Please sign in to comment.