[DAGCombine] Fix multi-use miscompile in load combine

The load combine replaces a number of original loads with one
new load and also replaces the output chains of the original loads
with the output chain of the new load. This is only correct if
the old loads actually get removed; otherwise they may get
incorrectly reordered.
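
A minimal sketch of the basic transform on a little-endian target
(illustrative IR, not taken from the tests below): two adjacent i8
loads combined via shl/or become a single i16 load, and users of the
two original output chains are re-chained to the new load's output
chain.

  define i16 @basic_load_combine_sketch(ptr %p) {
    %p1  = getelementptr i8, ptr %p, i64 1
    %lo  = load i8, ptr %p, align 2
    %hi  = load i8, ptr %p1, align 1
    %zlo = zext i8 %lo to i16
    %zhi = zext i8 %hi to i16
    %shl = shl i16 %zhi, 8
    %res = or i16 %shl, %zlo      ; folds to: %res = load i16, ptr %p, align 2
    ret i16 %res
  }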

The code did enforce that all involved operations are one-use
(which also guarantees that the loads will be removed), with one
exception: for vector loads, multi-use was allowed to support
multiple extract elements from one load.
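
The allowed shape looks roughly like this (illustrative IR,
little-endian): the <4 x i8> load has two uses, but both are
extractelements feeding the combined value, so the load disappears
once the bytes are folded back into a single scalar load.

  define i16 @allowed_multiuse_sketch(ptr %p) {
    %v  = load <4 x i8>, ptr %p, align 4
    %b0 = extractelement <4 x i8> %v, i32 0
    %b1 = extractelement <4 x i8> %v, i32 1
    %z0 = zext i8 %b0 to i16
    %z1 = zext i8 %b1 to i16
    %s1 = shl i16 %z1, 8
    %r  = or i16 %s1, %z0
    ret i16 %r
  }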

This patch collects these extract elements, and then validates
that the loads are only used inside them.
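
For contrast, a sketch in the spirit of the pr80911 test below
(hypothetical IR, names invented): the vector load has an extra use
outside the extracts, so it survives the combine, yet its chain users
used to be transferred to the new load, allowing the surviving load
to be reordered across the clobbering store. With this check the
combine now bails out on such shapes.

  define i16 @rejected_multiuse_sketch(ptr %p, ptr %clobber, ptr %out) {
    %v  = load <4 x i8>, ptr %p, align 4
    store i8 0, ptr %clobber     ; may alias %p; ordered via the load's chain
    store <4 x i8> %v, ptr %out  ; extra use keeps the vector load alive
    %b0 = extractelement <4 x i8> %v, i32 0
    %b1 = extractelement <4 x i8> %v, i32 1
    %z0 = zext i8 %b0 to i16
    %z1 = zext i8 %b1 to i16
    %s1 = shl i16 %z1, 8
    %r  = or i16 %s1, %z0
    ret i16 %r
  }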

I think an alternative fix would be to replace the uses of the old
output chains with TokenFactors that include both the old output
chains and the new output chain. However, I think the proposed
patch is preferable, as the profitability of the transform in the
general multi-use case is unclear: it may increase the overall
number of loads.

Fixes #80911.
nikic committed Feb 12, 2024
1 parent 69ddf1e commit f122182
Showing 4 changed files with 51 additions and 25 deletions.
33 changes: 24 additions & 9 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -8668,6 +8668,7 @@ using SDByteProvider = ByteProvider<SDNode *>;
static std::optional<SDByteProvider>
calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
std::optional<uint64_t> VectorIndex,
SmallPtrSetImpl<SDNode *> &ExtractElements,
unsigned StartingIndex = 0) {

// Typical i64 by i8 pattern requires recursion up to 8 calls depth
@@ -8694,12 +8695,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,

switch (Op.getOpcode()) {
case ISD::OR: {
auto LHS =
calculateByteProvider(Op->getOperand(0), Index, Depth + 1, VectorIndex);
auto LHS = calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
VectorIndex, ExtractElements);
if (!LHS)
return std::nullopt;
auto RHS =
calculateByteProvider(Op->getOperand(1), Index, Depth + 1, VectorIndex);
auto RHS = calculateByteProvider(Op->getOperand(1), Index, Depth + 1,
VectorIndex, ExtractElements);
if (!RHS)
return std::nullopt;

@@ -8726,7 +8727,8 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
return Index < ByteShift
? SDByteProvider::getConstantZero()
: calculateByteProvider(Op->getOperand(0), Index - ByteShift,
Depth + 1, VectorIndex, Index);
Depth + 1, VectorIndex, ExtractElements,
Index);
}
case ISD::ANY_EXTEND:
case ISD::SIGN_EXTEND:
@@ -8743,11 +8745,12 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
SDByteProvider::getConstantZero())
: std::nullopt;
return calculateByteProvider(NarrowOp, Index, Depth + 1, VectorIndex,
StartingIndex);
ExtractElements, StartingIndex);
}
case ISD::BSWAP:
return calculateByteProvider(Op->getOperand(0), ByteWidth - Index - 1,
Depth + 1, VectorIndex, StartingIndex);
Depth + 1, VectorIndex, ExtractElements,
StartingIndex);
case ISD::EXTRACT_VECTOR_ELT: {
auto OffsetOp = dyn_cast<ConstantSDNode>(Op->getOperand(1));
if (!OffsetOp)
@@ -8772,8 +8775,9 @@ calculateByteProvider(SDValue Op, unsigned Index, unsigned Depth,
if ((*VectorIndex + 1) * NarrowByteWidth <= StartingIndex)
return std::nullopt;

ExtractElements.insert(Op.getNode());
return calculateByteProvider(Op->getOperand(0), Index, Depth + 1,
VectorIndex, StartingIndex);
VectorIndex, ExtractElements, StartingIndex);
}
case ISD::LOAD: {
auto L = cast<LoadSDNode>(Op.getNode());
@@ -9110,6 +9114,7 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
SDValue Chain;

SmallPtrSet<LoadSDNode *, 8> Loads;
SmallPtrSet<SDNode *, 8> ExtractElements;
std::optional<SDByteProvider> FirstByteProvider;
int64_t FirstOffset = INT64_MAX;

@@ -9119,7 +9124,9 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
unsigned ZeroExtendedBytes = 0;
for (int i = ByteWidth - 1; i >= 0; --i) {
auto P =
calculateByteProvider(SDValue(N, 0), i, 0, /*VectorIndex*/ std::nullopt,
calculateByteProvider(SDValue(N, 0), i, 0,
/*VectorIndex*/ std::nullopt, ExtractElements,

/*StartingIndex*/ i);
if (!P)
return SDValue();
@@ -9245,6 +9252,14 @@ SDValue DAGCombiner::MatchLoadCombine(SDNode *N) {
if (!Allowed || !Fast)
return SDValue();

// calculateByteProvider() allows multi-use for vector loads. Ensure that
// all uses are in vector element extracts that are part of the pattern.
for (LoadSDNode *L : Loads)
if (L->getMemoryVT().isVector())
for (auto It = L->use_begin(); It != L->use_end(); ++It)
if (It.getUse().getResNo() == 0 && !ExtractElements.contains(*It))
return SDValue();

SDValue NewLoad =
DAG.getExtLoad(NeedsZext ? ISD::ZEXTLOAD : ISD::NON_EXTLOAD, SDLoc(N), VT,
Chain, FirstLoad->getBasePtr(),
8 changes: 5 additions & 3 deletions llvm/test/CodeGen/AArch64/load-combine.ll
@@ -606,10 +606,12 @@ define void @short_vector_to_i32_unused_high_i8(ptr %in, ptr %out, ptr %p) {
; CHECK-LABEL: short_vector_to_i32_unused_high_i8:
; CHECK: // %bb.0:
; CHECK-NEXT: ldr s0, [x0]
; CHECK-NEXT: ldrh w9, [x0]
; CHECK-NEXT: ushll v0.8h, v0.8b, #0
; CHECK-NEXT: umov w8, v0.h[2]
; CHECK-NEXT: orr w8, w9, w8, lsl #16
; CHECK-NEXT: umov w8, v0.h[1]
; CHECK-NEXT: umov w9, v0.h[0]
; CHECK-NEXT: umov w10, v0.h[2]
; CHECK-NEXT: bfi w9, w8, #8, #8
; CHECK-NEXT: orr w8, w9, w10, lsl #16
; CHECK-NEXT: str w8, [x1]
; CHECK-NEXT: ret
%ld = load <4 x i8>, ptr %in, align 4
10 changes: 6 additions & 4 deletions llvm/test/CodeGen/AMDGPU/combine-vload-extract.ll
@@ -205,12 +205,14 @@ define i64 @load_3xi16_combine(ptr addrspace(1) %p) #0 {
; GCN-LABEL: load_3xi16_combine:
; GCN: ; %bb.0:
; GCN-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GCN-NEXT: global_load_dword v2, v[0:1], off
; GCN-NEXT: global_load_ushort v3, v[0:1], off offset:4
; GCN-NEXT: global_load_dword v3, v[0:1], off
; GCN-NEXT: global_load_ushort v2, v[0:1], off offset:4
; GCN-NEXT: s_mov_b32 s4, 0xffff
; GCN-NEXT: s_waitcnt vmcnt(1)
; GCN-NEXT: v_mov_b32_e32 v0, v2
; GCN-NEXT: v_and_b32_e32 v0, 0xffff0000, v3
; GCN-NEXT: v_and_or_b32 v0, v3, s4, v0
; GCN-NEXT: s_waitcnt vmcnt(0)
; GCN-NEXT: v_mov_b32_e32 v1, v3
; GCN-NEXT: v_mov_b32_e32 v1, v2
; GCN-NEXT: s_setpc_b64 s[30:31]
%gep.p = getelementptr i16, ptr addrspace(1) %p, i32 1
%gep.2p = getelementptr i16, ptr addrspace(1) %p, i32 2
25 changes: 16 additions & 9 deletions llvm/test/CodeGen/X86/load-combine.ll
@@ -1283,26 +1283,33 @@ define i32 @zext_load_i32_by_i8_bswap_shl_16(ptr %arg) {
ret i32 %tmp8
}

; FIXME: This is a miscompile.
define i32 @pr80911_vector_load_multiuse(ptr %ptr, ptr %clobber) nounwind {
; CHECK-LABEL: pr80911_vector_load_multiuse:
; CHECK: # %bb.0:
; CHECK-NEXT: pushl %edi
; CHECK-NEXT: pushl %esi
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %ecx
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %edx
; CHECK-NEXT: movl (%edx), %esi
; CHECK-NEXT: movzwl (%edx), %eax
; CHECK-NEXT: movl $0, (%ecx)
; CHECK-NEXT: movl %esi, (%edx)
; CHECK-NEXT: movl {{[0-9]+}}(%esp), %esi
; CHECK-NEXT: movzbl (%esi), %ecx
; CHECK-NEXT: movzbl 1(%esi), %eax
; CHECK-NEXT: movzwl 2(%esi), %edi
; CHECK-NEXT: movl $0, (%edx)
; CHECK-NEXT: movw %di, 2(%esi)
; CHECK-NEXT: movb %al, 1(%esi)
; CHECK-NEXT: movb %cl, (%esi)
; CHECK-NEXT: shll $8, %eax
; CHECK-NEXT: orl %ecx, %eax
; CHECK-NEXT: popl %esi
; CHECK-NEXT: popl %edi
; CHECK-NEXT: retl
;
; CHECK64-LABEL: pr80911_vector_load_multiuse:
; CHECK64: # %bb.0:
; CHECK64-NEXT: movzwl (%rdi), %eax
; CHECK64-NEXT: movaps (%rdi), %xmm0
; CHECK64-NEXT: movl $0, (%rsi)
; CHECK64-NEXT: movl (%rdi), %ecx
; CHECK64-NEXT: movl %ecx, (%rdi)
; CHECK64-NEXT: movss %xmm0, (%rdi)
; CHECK64-NEXT: movaps %xmm0, -{{[0-9]+}}(%rsp)
; CHECK64-NEXT: movzwl -{{[0-9]+}}(%rsp), %eax
; CHECK64-NEXT: retq
%load = load <4 x i8>, ptr %ptr, align 16
store i32 0, ptr %clobber