Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AArch64] Stack probing for function prologues #66524

Merged
merged 5 commits into from
Nov 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
336 changes: 307 additions & 29 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.cpp

Large diffs are not rendered by default.

17 changes: 15 additions & 2 deletions llvm/lib/Target/AArch64/AArch64FrameLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,26 @@ class AArch64FrameLowering : public TargetFrameLowering {
MachineBasicBlock::iterator MBBI) const;
void allocateStackSpace(MachineBasicBlock &MBB,
MachineBasicBlock::iterator MBBI,
bool NeedsRealignment, StackOffset AllocSize,
int64_t RealignmentPadding, StackOffset AllocSize,
bool NeedsWinCFI, bool *HasWinCFI, bool EmitCFI,
StackOffset InitialOffset) const;
StackOffset InitialOffset, bool FollowupAllocs) const;

/// Emit target zero call-used regs.
void emitZeroCallUsedRegs(BitVector RegsToZero,
MachineBasicBlock &MBB) const override;

/// Replace a StackProbe stub (if any) with the actual probe code inline
void inlineStackProbe(MachineFunction &MF,
MachineBasicBlock &PrologueMBB) const override;

void inlineStackProbeFixed(MachineBasicBlock::iterator MBBI,
Register ScratchReg, int64_t FrameSize,
StackOffset CFAOffset) const;

MachineBasicBlock::iterator
inlineStackProbeLoopExactMultiple(MachineBasicBlock::iterator MBBI,
int64_t NegProbeSize,
Register TargetReg) const;
};

} // End llvm namespace
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26815,3 +26815,9 @@ unsigned AArch64TargetLowering::getVectorTypeBreakdownForCallingConv(

return NumRegs;
}

bool AArch64TargetLowering::hasInlineStackProbe(
const MachineFunction &MF) const {
return !Subtarget->isTargetWindows() &&
MF.getInfo<AArch64FunctionInfo>()->hasStackProbing();
}
10 changes: 10 additions & 0 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,13 @@ const unsigned RoundingBitsPos = 22;
ArrayRef<MCPhysReg> getGPRArgRegs();
ArrayRef<MCPhysReg> getFPRArgRegs();

/// Maximum allowed number of unprobed bytes above SP at an ABI
/// boundary.
const unsigned StackProbeMaxUnprobedStack = 1024;

/// Maximum number of iterations to unroll for a constant size probing loop.
const unsigned StackProbeMaxLoopUnroll = 4;

} // namespace AArch64

class AArch64Subtarget;
Expand Down Expand Up @@ -966,6 +973,9 @@ class AArch64TargetLowering : public TargetLowering {
unsigned &NumIntermediates,
MVT &RegisterVT) const override;

/// True if stack clash protection is enabled for this functions.
bool hasInlineStackProbe(const MachineFunction &MF) const override;

private:
/// Keep a pointer to the AArch64Subtarget around so that we can
/// make the right decision when generating code for different targets.
Expand Down
91 changes: 90 additions & 1 deletion llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
//
//===----------------------------------------------------------------------===//

#include "AArch64ExpandImm.h"
#include "AArch64InstrInfo.h"
#include "AArch64ExpandImm.h"
#include "AArch64FrameLowering.h"
#include "AArch64MachineFunctionInfo.h"
#include "AArch64PointerAuth.h"
Expand All @@ -21,6 +21,7 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/CodeGen/LivePhysRegs.h"
#include "llvm/CodeGen/MachineBasicBlock.h"
#include "llvm/CodeGen/MachineCombinerPattern.h"
#include "llvm/CodeGen/MachineFrameInfo.h"
Expand Down Expand Up @@ -9461,6 +9462,94 @@ bool AArch64InstrInfo::isReallyTriviallyReMaterializable(
return TargetInstrInfo::isReallyTriviallyReMaterializable(MI);
}

MachineBasicBlock::iterator
AArch64InstrInfo::probedStackAlloc(MachineBasicBlock::iterator MBBI,
Register TargetReg, bool FrameSetup) const {
assert(TargetReg != AArch64::SP && "New top of stack cannot aleady be in SP");

MachineBasicBlock &MBB = *MBBI->getParent();
MachineFunction &MF = *MBB.getParent();
const AArch64InstrInfo *TII =
MF.getSubtarget<AArch64Subtarget>().getInstrInfo();
int64_t ProbeSize = MF.getInfo<AArch64FunctionInfo>()->getStackProbeSize();
DebugLoc DL = MBB.findDebugLoc(MBBI);

MachineFunction::iterator MBBInsertPoint = std::next(MBB.getIterator());
MachineBasicBlock *LoopTestMBB =
MF.CreateMachineBasicBlock(MBB.getBasicBlock());
MF.insert(MBBInsertPoint, LoopTestMBB);
MachineBasicBlock *LoopBodyMBB =
MF.CreateMachineBasicBlock(MBB.getBasicBlock());
MF.insert(MBBInsertPoint, LoopBodyMBB);
MachineBasicBlock *ExitMBB = MF.CreateMachineBasicBlock(MBB.getBasicBlock());
MF.insert(MBBInsertPoint, ExitMBB);
MachineInstr::MIFlag Flags =
FrameSetup ? MachineInstr::FrameSetup : MachineInstr::NoFlags;

// LoopTest:
// SUB SP, SP, #ProbeSize
emitFrameOffset(*LoopTestMBB, LoopTestMBB->end(), DL, AArch64::SP,
AArch64::SP, StackOffset::getFixed(-ProbeSize), TII, Flags);

// CMP SP, TargetReg
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(AArch64::SUBSXrx64),
AArch64::XZR)
.addReg(AArch64::SP)
.addReg(TargetReg)
.addImm(AArch64_AM::getArithExtendImm(AArch64_AM::UXTX, 0))
.setMIFlags(Flags);

// B.<Cond> LoopExit
BuildMI(*LoopTestMBB, LoopTestMBB->end(), DL, TII->get(AArch64::Bcc))
.addImm(AArch64CC::LE)
.addMBB(ExitMBB)
.setMIFlags(Flags);

// STR XZR, [SP]
BuildMI(*LoopBodyMBB, LoopBodyMBB->end(), DL, TII->get(AArch64::STRXui))
.addReg(AArch64::XZR)
.addReg(AArch64::SP)
.addImm(0)
.setMIFlags(Flags);

// B loop
BuildMI(*LoopBodyMBB, LoopBodyMBB->end(), DL, TII->get(AArch64::B))
.addMBB(LoopTestMBB)
.setMIFlags(Flags);

// LoopExit:
// MOV SP, TargetReg
BuildMI(*ExitMBB, ExitMBB->end(), DL, TII->get(AArch64::ADDXri), AArch64::SP)
.addReg(TargetReg)
.addImm(0)
.addImm(AArch64_AM::getShifterImm(AArch64_AM::LSL, 0))
.setMIFlags(Flags);

// STR XZR, [SP]
BuildMI(*ExitMBB, ExitMBB->end(), DL, TII->get(AArch64::STRXui))
.addReg(AArch64::XZR)
.addReg(AArch64::SP)
.addImm(0)
.setMIFlags(Flags);
Comment on lines +9528 to +9533
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @momchil-velikov , I'm encountering an issue where the stack probe is overwriting stack memory, similar to the issue I faced before. I apologize for not addressing this earlier, but backporting these changes introduced significant overhead.

The issue seems to be caused by this final stack probe instruction that is overwriting a stack value. I'm considering removing this probe instruction as a potential solution.

I don't anticipate this change will compromise security. Given our current stack probing strategy, we're always within one page of the most recent probe. Therefore, any subsequent instructions accessing memory at [sp] or above will either be valid or trigger a guard page fault.

I can't definitively confirm whether this issue is a result of backporting or an inherent problem until I upgrade to LLVM 18 which is some ways away.

Please let me know if my understanding is incorrect. I appreciate your work on this patchset and thank you for your assistance! :)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming TargetReg is lower than SP on entry at "LoopTest", the final store must be to an address lower than SP on entry to the sequence. (And since SP is always 16-byte aligned, it's impossible to have an issue with partial overlap.) Since that's freshly allocated memory, nothing should care what it contains.

I guess maybe weird things could happen if something tries to allocate 0 bytes of memory? Probably something that needs to be fixed, but it's unlikely you'd run into it from C code.

Also, maybe worth checking your code isn't depending on a value of an uninitialized variable.

Removing the final store completely breaks the guarantees the code is supposed to provide: if the allocation is less than the page size, the only store is the final store. So if the compiled code, for example, calls "alloca(1024)" in a loop, you can skip over an arbitrary number of pages.

Copy link
Contributor

@oskarwirga oskarwirga Dec 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for that explanation, yes I see it leaves a large gap. Theres like a million optimizations and mitigations slapped on this code so it gets difficult understanding the root cause.

I guess maybe weird things could happen if something tries to allocate 0 bytes of memory? Probably something that needs to be fixed, but it's unlikely you'd run into it from C code.

Thanks for this, it made me look into it a bit more and we're not allocating 0 bytes, but we are allocating 0x10 bytes which might cause some issues. I appreciate your quick reply thank you again!

EDIT: We are allocating 0 bytes lol I read it wrong. OK so thats whats the issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed this in #74806


ExitMBB->splice(ExitMBB->end(), &MBB, std::next(MBBI), MBB.end());
ExitMBB->transferSuccessorsAndUpdatePHIs(&MBB);

LoopTestMBB->addSuccessor(ExitMBB);
LoopTestMBB->addSuccessor(LoopBodyMBB);
LoopBodyMBB->addSuccessor(LoopTestMBB);
MBB.addSuccessor(LoopTestMBB);

// Update liveins.
if (MF.getRegInfo().reservedRegsFrozen()) {
recomputeLiveIns(*LoopTestMBB);
recomputeLiveIns(*LoopBodyMBB);
recomputeLiveIns(*ExitMBB);
}

return ExitMBB->begin();
}

#define GET_INSTRINFO_HELPERS
#define GET_INSTRMAP_INFO
#include "AArch64GenInstrInfo.inc"
7 changes: 7 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,13 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
bool isLegalAddressingMode(unsigned NumBytes, int64_t Offset,
unsigned Scale) const;

// Decrement the SP, issuing probes along the way. `TargetReg` is the new top
// of the stack. `FrameSetup` is passed as true, if the allocation is a part
// of constructing the activation frame of a function.
MachineBasicBlock::iterator probedStackAlloc(MachineBasicBlock::iterator MBBI,
Register TargetReg,
bool FrameSetup) const;

#define GET_INSTRINFO_HELPER_DECLS
#include "AArch64GenInstrInfo.inc"

Expand Down
25 changes: 23 additions & 2 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -936,7 +936,8 @@ include "SMEInstrFormats.td"
// Miscellaneous instructions.
//===----------------------------------------------------------------------===//

let Defs = [SP], Uses = [SP], hasSideEffects = 1, isCodeGenOnly = 1 in {
let hasSideEffects = 1, isCodeGenOnly = 1 in {
let Defs = [SP], Uses = [SP] in {
// We set Sched to empty list because we expect these instructions to simply get
// removed in most cases.
def ADJCALLSTACKDOWN : Pseudo<(outs), (ins i32imm:$amt1, i32imm:$amt2),
Expand All @@ -945,7 +946,27 @@ def ADJCALLSTACKDOWN : Pseudo<(outs), (ins i32imm:$amt1, i32imm:$amt2),
def ADJCALLSTACKUP : Pseudo<(outs), (ins i32imm:$amt1, i32imm:$amt2),
[(AArch64callseq_end timm:$amt1, timm:$amt2)]>,
Sched<[]>;
} // Defs = [SP], Uses = [SP], hasSideEffects = 1, isCodeGenOnly = 1

}

let Defs = [SP, NZCV], Uses = [SP] in {
// Probed stack allocation of a constant size, used in function prologues when
// stack-clash protection is enabled.
def PROBED_STACKALLOC : Pseudo<(outs GPR64:$scratch),
(ins i64imm:$stacksize, i64imm:$fixed_offset,
i64imm:$scalable_offset),
[]>,
Sched<[]>;

// Probed stack allocation of a variable size, used in function prologues when
// stack-clash protection is enabled.
def PROBED_STACKALLOC_VAR : Pseudo<(outs),
(ins GPR64sp:$target),
[]>,
Sched<[]>;

} // Defs = [SP, NZCV], Uses = [SP] in
} // hasSideEffects = 1, isCodeGenOnly = 1

let isReMaterializable = 1, isCodeGenOnly = 1 in {
// FIXME: The following pseudo instructions are only needed because remat
Expand Down
43 changes: 37 additions & 6 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,45 @@ AArch64FunctionInfo::AArch64FunctionInfo(const Function &F,
if (const auto *BTE = mdconst::extract_or_null<ConstantInt>(
F.getParent()->getModuleFlag("branch-target-enforcement")))
BranchTargetEnforcement = BTE->getZExtValue();
return;
} else {
const StringRef BTIEnable =
F.getFnAttribute("branch-target-enforcement").getValueAsString();
assert(BTIEnable.equals_insensitive("true") ||
BTIEnable.equals_insensitive("false"));
momchil-velikov marked this conversation as resolved.
Show resolved Hide resolved
BranchTargetEnforcement = BTIEnable.equals_insensitive("true");
}

const StringRef BTIEnable =
F.getFnAttribute("branch-target-enforcement").getValueAsString();
assert(BTIEnable.equals_insensitive("true") ||
BTIEnable.equals_insensitive("false"));
BranchTargetEnforcement = BTIEnable.equals_insensitive("true");
// The default stack probe size is 4096 if the function has no
// stack-probe-size attribute. This is a safe default because it is the
// smallest possible guard page size.
uint64_t ProbeSize = 4096;
if (F.hasFnAttribute("stack-probe-size"))
ProbeSize = F.getFnAttributeAsParsedInteger("stack-probe-size");
else if (const auto *PS = mdconst::extract_or_null<ConstantInt>(
F.getParent()->getModuleFlag("stack-probe-size")))
ProbeSize = PS->getZExtValue();
assert(int64_t(ProbeSize) > 0 && "Invalid stack probe size");

if (STI->isTargetWindows()) {
if (!F.hasFnAttribute("no-stack-arg-probe"))
StackProbeSize = ProbeSize;
} else {
// Round down to the stack alignment.
uint64_t StackAlign =
STI->getFrameLowering()->getTransientStackAlign().value();
ProbeSize = std::max(StackAlign, ProbeSize & ~(StackAlign - 1U));
StringRef ProbeKind;
if (F.hasFnAttribute("probe-stack"))
ProbeKind = F.getFnAttribute("probe-stack").getValueAsString();
else if (const auto *PS = dyn_cast_or_null<MDString>(
F.getParent()->getModuleFlag("probe-stack")))
ProbeKind = PS->getString();
if (ProbeKind.size()) {
if (ProbeKind != "inline-asm")
report_fatal_error("Unsupported stack probing method");
StackProbeSize = ProbeSize;
}
}
}

MachineFunctionInfo *AArch64FunctionInfo::clone(
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/AArch64/AArch64MachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
/// True if the function need asynchronous unwind information.
mutable std::optional<bool> NeedsAsyncDwarfUnwindInfo;

int64_t StackProbeSize = 0;

public:
AArch64FunctionInfo(const Function &F, const AArch64Subtarget *STI);

Expand Down Expand Up @@ -456,6 +458,10 @@ class AArch64FunctionInfo final : public MachineFunctionInfo {
HasStreamingModeChanges = HasChanges;
}

bool hasStackProbing() const { return StackProbeSize != 0; }

int64_t getStackProbeSize() const { return StackProbeSize; }

private:
// Hold the lists of LOHs.
MILOHContainer LOHContainerSet;
Expand Down
Loading
Loading