Skip to content

Commit

Permalink
Improve spinloop detection (hernanponcedeleon#786)
Browse files Browse the repository at this point in the history
* Improve DynamicSpinLoopDetection

* Update verdict of VMMLocksTest
  • Loading branch information
ThomasHaas authored and Tianrui Zheng committed Nov 26, 2024
1 parent 46308d9 commit 9de5e16
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public Expression makeFloatCast(Expression operand, FloatType targetType, boolea
// -----------------------------------------------------------------------------------------------------------------
// Aggregates

public Expression makeConstruct(Type type, List<Expression> arguments) {
public Expression makeConstruct(Type type, List<? extends Expression> arguments) {
return new ConstructExpr(type, arguments);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

public final class ConstructExpr extends NaryExpressionBase<Type, ExpressionKind.Other> {

public ConstructExpr(Type type, List<Expression> arguments) {
public ConstructExpr(Type type, List<? extends Expression> arguments) {
super(type, ExpressionKind.Other.CONSTRUCT, List.copyOf(arguments));
checkArgument(type instanceof AggregateType || type instanceof ArrayType,
"Non-constructible type %s.", type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@

import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.expression.type.AggregateType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.program.Function;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Register;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.analysis.DominatorAnalysis;
import com.dat3m.dartagnan.program.analysis.LiveRegistersAnalysis;
import com.dat3m.dartagnan.program.analysis.LoopAnalysis;
import com.dat3m.dartagnan.program.event.*;
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.RegWriter;
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.functions.FunctionCall;
import com.dat3m.dartagnan.program.event.lang.svcomp.SpinStart;
import com.dat3m.dartagnan.utils.DominatorTree;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.sosy_lab.common.configuration.Configuration;
Expand All @@ -26,13 +28,31 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/*
This pass instruments loops that do not cause a side effect in an iteration to terminate, i.e., to avoid spinning.
In other words, only the last loop iteration is allowed to be side-effect free.
Instrumentation:
loop_header:
__localLiveSnapshot <- ( <list of writable live registers> ) // To track local side effects
__globalSideEffect <- false // To track global side effects
// ---------- Loop body ----------
...
__globalSideEffect <- true // Store is a guaranteed global effect
store(...)
...
// ---------- Backjump ----------
// Local side effect if value of any live register changed.
__localSideEffect <- __localLiveSnapshot != ( <list of writable live registers> )
if (!(__localSideEffect || __globalSideEffect)) goto END_OF_T ### SPINLOOP
goto loop_header
NOTE: "<list of writable live registers>" refers to those registers
(1) that are live inside or after the loop
and (2) that are potentially written to inside the loop (i.e. are not invariant).
NOTE: This pass is required to detect liveness violations.
*/
public class DynamicSpinLoopDetection implements ProgramProcessor {
Expand All @@ -49,19 +69,17 @@ public void run(Program program) {
"DynamicSpinLoopDetection cannot be run on already unrolled programs.");

final LoopAnalysis loopAnalysis = LoopAnalysis.newInstance(program);
AnalysisStats stats = new AnalysisStats(0, 0);
AnalysisStats stats = new AnalysisStats(0);
for (Function func : Iterables.concat(program.getFunctions(), program.getThreads())) {
final List<LoopData> loops = computeLoopData(func, loopAnalysis);
loops.forEach(this::collectSideEffects);
loops.forEach(this::reduceToDominatingSideEffects);
final LiveRegistersAnalysis liveRegsAna = LiveRegistersAnalysis.forFunction(func);
loops.forEach(loop -> this.collectSideEffects(loop, liveRegsAna));
loops.forEach(this::instrumentLoop);
stats = stats.add(collectStats(loops));
}
IdReassignment.newInstance().run(program); // Reassign ids for the instrumented code.

// NOTE: We log "potential spin loops" as only those that are not also "static".
logger.info("Found {} static spin loops and {} potential spin loops.",
stats.numStaticSpinLoops, (stats.numPotentialSpinLoops - stats.numStaticSpinLoops));
logger.info("Found {} static spin loops.", stats.numStaticSpinLoops());
}

// ==================================================================================================
Expand All @@ -72,167 +90,108 @@ private List<LoopData> computeLoopData(Function func, LoopAnalysis loopAnalysis)
return loops.stream().map(LoopData::new).toList();
}

private void collectSideEffects(LoopData loop) {
// Loop shape is: loopStart -> Calling intrinsic __VERIFIER_spin_start -> #__VERIFIER_spin_start
if (loop.getStart().getSuccessor().getSuccessor() instanceof SpinStart) {
// A user-placed annotation guarantees absence of side effects.

// This checks assumes the following implementation of await_while
// #define await_while(cond) \
// for (int tmp = (__VERIFIER_loop_begin(), 0); __VERIFIER_spin_start(), \
// tmp = cond, __VERIFIER_spin_end(!tmp), tmp;)
return;
}

// FIXME: The reasoning about safe/unsafe registers is not correct because
// we do not traverse the control-flow but naively go top-down through the loop.
// We need to use proper dominator reasoning!

// Unsafe means the loop reads from the registers before writing to them.
Set<Register> unsafeRegisters = new HashSet<>();
// Safe means the loop wrote to these register before using them
Set<Register> safeRegisters = new HashSet<>();

private void collectSideEffects(LoopData loop, LiveRegistersAnalysis liveRegsAna) {
final Set<Register> writtenRegisters = new HashSet<>();
Event cur = loop.getStart();
do {
if (cur instanceof RegWriter writer) {
writtenRegisters.add(writer.getResultRegister());
}
if (cur.hasTag(Tag.WRITE) || (cur instanceof FunctionCall call &&
(!call.isDirectCall()
|| !call.getCalledFunction().isIntrinsic()
|| call.getCalledFunction().getIntrinsicInfo().writesMemory()))) {
// We assume side effects for all writes, writing intrinsics, and non-intrinsic function calls.
loop.sideEffects.add(cur);
continue;
}

if (cur instanceof RegReader regReader) {
final Set<Register> dataRegs = regReader.getRegisterReads().stream()
.map(Register.Read::register).collect(Collectors.toSet());
unsafeRegisters.addAll(Sets.difference(dataRegs, safeRegisters));
}

if (cur instanceof RegWriter writer) {
if (unsafeRegisters.contains(writer.getResultRegister())) {
// The loop writes to a register it previously read from.
// This means the next loop iteration will observe the newly written value,
// hence the loop is not side effect free.
loop.sideEffects.add(cur);
} else {
safeRegisters.add(writer.getResultRegister());
}
loop.globalSideEffects.add(cur);
}
} while ((cur = cur.getSuccessor()) != loop.getEnd().getSuccessor());

}

private void reduceToDominatingSideEffects(LoopData loop) {
if (loop.sideEffects.isEmpty()) {
return;
}

final List<Event> sideEffects = loop.sideEffects;
final Event start = loop.getStart();
final Event end = loop.getEnd();
final Predicate<Event> isAlwaysSideEffectful = (e -> e.cfImpliesExec() && sideEffects.contains(e));

final DominatorTree<Event> preDominatorTree = DominatorAnalysis.computePreDominatorTree(start, end);
final DominatorTree<Event> postDominatorTree = DominatorAnalysis.computePostDominatorTree(start, end);

// (1) Delete all side effects that are on exit paths (those have no dominator in the post-dominator tree
// because they are no predecessor of the end of the loop body).
final Predicate<Event> isOnExitPath = (e -> postDominatorTree.getImmediateDominator(e) == null);
sideEffects.removeIf(isOnExitPath);

// (2) Check if always side-effect-full at the end of an iteration directly before entering the next one.
// This is an approximation: If the end of the iteration is predominated by some side effect
// then we always observe side effects.
loop.isAlwaysSideEffectful = Streams.stream(preDominatorTree.getDominators(end)).anyMatch(isAlwaysSideEffectful);
if (loop.isAlwaysSideEffectful) {
return;
}

// (3) Delete all side effects that are dominated by another one
// NOTE: This can be implemented more efficiently, but maybe we don't need to.
for (int i = sideEffects.size() - 1; i >= 0; i--) {
final Event sideEffect = sideEffects.get(i);
final Iterable<Event> dominators = Iterables.concat(
preDominatorTree.getStrictDominators(sideEffect),
postDominatorTree.getStrictDominators(sideEffect)
);
final boolean isDominated = Iterables.tryFind(dominators, isAlwaysSideEffectful::test).isPresent();
if (isDominated) {
sideEffects.remove(i);
}
}
// Every live register that is written to is a potential local side effect.
loop.writtenLiveRegisters.addAll(Sets.intersection(
writtenRegisters,
liveRegsAna.getLiveRegistersAt(loop.getStart())
));
}

private void instrumentLoop(LoopData loop) {
if (loop.isAlwaysSideEffectful) {
return;
}

final TypeFactory types = TypeFactory.getInstance();
final ExpressionFactory expressions = ExpressionFactory.getInstance();

final Function func = loop.loopInfo.function();
final int loopNum = loop.loopInfo.loopNumber();

final AggregateType liveRegsType = types.getAggregateType(Lists.transform(loop.writtenLiveRegisters, Register::getType));
final Expression liveRegistersVector = expressions.makeConstruct(liveRegsType, loop.writtenLiveRegisters);
final Register entryLiveStateRegister = func.newRegister("__localLiveSnapshot#" + loopNum, liveRegsType);
final Register tempReg = func.newRegister("__possiblySideEffectless#" + loopNum, types.getBooleanType());
final Register trackingReg = func.newRegister("__sideEffect#" + loopNum, types.getBooleanType());
final Register globalSideEffectReg = func.newRegister("__globalSideEffect#" + loopNum, types.getBooleanType());

final Event init = EventFactory.newLocal(trackingReg, expressions.makeFalse());
loop.getStart().insertAfter(init);
for (Event sideEffect : loop.sideEffects) {
// ---------------- Instrumentation ----------------
// Init tracking registers
loop.getStart().insertAfter(List.of(
EventFactory.newLocal(entryLiveStateRegister, liveRegistersVector),
EventFactory.newLocal(globalSideEffectReg, expressions.makeFalse())
));

// Track global side effects
for (Event sideEffect : loop.globalSideEffects) {
final List<Event> updateSideEffect = new ArrayList<>();
if (sideEffect.cfImpliesExec()) {
updateSideEffect.add(EventFactory.newLocal(trackingReg, expressions.makeTrue()));
updateSideEffect.add(EventFactory.newLocal(globalSideEffectReg, expressions.makeTrue()));
} else {
updateSideEffect.addAll(List.of(
EventFactory.newExecutionStatus(tempReg, sideEffect),
EventFactory.newLocal(trackingReg, expressions.makeOr(trackingReg, expressions.makeNot(tempReg)))
EventFactory.newLocal(globalSideEffectReg, expressions.makeOr(globalSideEffectReg, expressions.makeNot(tempReg)))
));

}
sideEffect.getPredecessor().insertAfter(updateSideEffect);
}

final Event assumeSideEffect = newSpinTerminator(expressions.makeNot(trackingReg), func);
assumeSideEffect.copyAllMetadataFrom(loop.getStart());
loop.getEnd().getPredecessor().insertAfter(assumeSideEffect);
// Check if any local or global side effects occurred. If not, spin!
final Register localSideEffectReg = func.newRegister("__localSideEffect#" + loopNum, types.getBooleanType());
final Expression hasSideEffect = expressions.makeOr(localSideEffectReg, globalSideEffectReg);

final Event assignLocalSideEffectReg = EventFactory.newLocal(localSideEffectReg, expressions.makeNEQ(entryLiveStateRegister, liveRegistersVector));
final Event assumeSideEffect = newSpinTerminator(expressions.makeNot(hasSideEffect), loop);
loop.getEnd().getPredecessor().insertAfter(List.of(
assignLocalSideEffectReg,
assumeSideEffect
));

// Special case: If the loop is fully side-effect-free, we can set its unrolling bound to 1.
if (loop.sideEffects.isEmpty()) {
if (loop.isSideEffectFree()) {
final Event loopBound = EventFactory.Svcomp.newLoopBound(expressions.makeValue(1, types.getArchType()));
loop.getStart().getPredecessor().insertAfter(loopBound);
}
}

private Event newSpinTerminator(Expression guard, Function func) {
private Event newSpinTerminator(Expression guard, LoopData loop) {
final Function func = loop.getStart().getFunction();
final Event terminator = func instanceof Thread thread ?
EventFactory.newJump(guard, (Label) thread.getExit())
: EventFactory.newAbortIf(guard);
terminator.addTags(Tag.SPINLOOP, Tag.NONTERMINATION, Tag.NOOPT);
terminator.addTags(Tag.SPINLOOP, Tag.NONTERMINATION);
terminator.copyAllMetadataFrom(loop.getStart());
return terminator;
}

private AnalysisStats collectStats(List<LoopData> loops) {
int numPotentialSpinLoops = 0;
int numStaticSpinLoops = 0;
for (LoopData loop : loops) {
if (!loop.isAlwaysSideEffectful) {
numPotentialSpinLoops++;
if (loop.sideEffects.isEmpty()) {
numStaticSpinLoops++;
}
}
}
return new AnalysisStats(numPotentialSpinLoops, numStaticSpinLoops);
int numStaticSpinLoops = Math.toIntExact(loops.stream().filter(LoopData::isSideEffectFree).count());
return new AnalysisStats(numStaticSpinLoops);
}

// ==================================================================================================
// Inner data structures

private static class LoopData {
private final LoopAnalysis.LoopInfo loopInfo;
private final List<Event> sideEffects = new ArrayList<>();
private boolean isAlwaysSideEffectful = false;
private final List<Event> globalSideEffects = new ArrayList<>();
private final List<Register> writtenLiveRegisters = new ArrayList<>();

public boolean isSideEffectFree() {
return writtenLiveRegisters.isEmpty() && globalSideEffects.isEmpty();
}

private LoopData(LoopAnalysis.LoopInfo loopInfo) {
this.loopInfo = loopInfo;
Expand All @@ -248,11 +207,10 @@ public String toString() {
}
}

private record AnalysisStats(int numPotentialSpinLoops, int numStaticSpinLoops) {
private record AnalysisStats(int numStaticSpinLoops) {

private AnalysisStats add(AnalysisStats stats) {
return new AnalysisStats(this.numPotentialSpinLoops + stats.numPotentialSpinLoops,
this.numStaticSpinLoops + stats.numStaticSpinLoops);
return new AnalysisStats(this.numStaticSpinLoops + stats.numStaticSpinLoops);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import java.util.EnumSet;

import static com.dat3m.dartagnan.configuration.Arch.C11;
import static com.dat3m.dartagnan.configuration.Property.*;
import static com.dat3m.dartagnan.utils.ResourceHelper.getTestResourcePath;
import static com.dat3m.dartagnan.utils.Result.*;
import static org.junit.Assert.assertEquals;
import static com.dat3m.dartagnan.configuration.Property.*;

@RunWith(Parameterized.class)
public class VMMLocksTest extends AbstractCTest {
Expand Down Expand Up @@ -82,7 +82,7 @@ public static Iterable<Object[]> data() throws IOException {
{"mutex_musl-rel2rx_futex", C11, FAIL},
{"mutex_musl-rel2rx_unlock", C11, FAIL},
{"seqlock", C11, PASS},
{"clh_mutex", C11, UNKNOWN},
{"clh_mutex", C11, PASS},
{"clh_mutex-acq2rx", C11, FAIL},
{"ticket_awnsb_mutex", C11, PASS},
{"ticket_awnsb_mutex-acq2rx", C11, FAIL},
Expand Down

0 comments on commit 9de5e16

Please sign in to comment.