Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve spinloop detection #786

Merged
merged 3 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ public Expression makeFloatCast(Expression operand, FloatType targetType, boolea
// -----------------------------------------------------------------------------------------------------------------
// Aggregates

public Expression makeConstruct(Type type, List<Expression> arguments) {
public Expression makeConstruct(Type type, List<? extends Expression> arguments) {
return new ConstructExpr(type, arguments);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

public final class ConstructExpr extends NaryExpressionBase<Type, ExpressionKind.Other> {

public ConstructExpr(Type type, List<Expression> arguments) {
public ConstructExpr(Type type, List<? extends Expression> arguments) {
super(type, ExpressionKind.Other.CONSTRUCT, List.copyOf(arguments));
checkArgument(type instanceof AggregateType || type instanceof ArrayType,
"Non-constructible type %s.", type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@

import com.dat3m.dartagnan.expression.Expression;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.expression.type.AggregateType;
import com.dat3m.dartagnan.expression.type.TypeFactory;
import com.dat3m.dartagnan.program.Function;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Register;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.analysis.DominatorAnalysis;
import com.dat3m.dartagnan.program.analysis.LiveRegistersAnalysis;
import com.dat3m.dartagnan.program.analysis.LoopAnalysis;
import com.dat3m.dartagnan.program.event.*;
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.RegWriter;
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.functions.FunctionCall;
import com.dat3m.dartagnan.program.event.lang.svcomp.SpinStart;
import com.dat3m.dartagnan.utils.DominatorTree;
import com.google.common.base.Preconditions;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.sosy_lab.common.configuration.Configuration;
Expand All @@ -26,13 +28,31 @@
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/*
This pass instruments loops that do not cause a side effect in an iteration to terminate, i.e., to avoid spinning.
In other words, only the last loop iteration is allowed to be side-effect free.
Instrumentation:
loop_header:
__localLiveSnapshot <- ( <list of writable live registers> ) // To track local side effects
__globalSideEffect <- false // To track global side effects
// ---------- Loop body ----------
...
__globalSideEffect <- true // Store is a guaranteed global effect
store(...)
...
// ---------- Backjump ----------
// Local side effect if value of any live register changed.
__localSideEffect <- __localLiveSnapshot != ( <list of writable live registers> )
if (!(__localSideEffect || __globalSideEffect)) goto END_OF_T ### SPINLOOP
goto loop_header
NOTE: "<list of writable live registers>" refers to those registers
(1) that are live inside or after the loop
and (2) that are potentially written to inside the loop (i.e. are not invariant).
NOTE: This pass is required to detect liveness violations.
*/
public class DynamicSpinLoopDetection implements ProgramProcessor {
Expand All @@ -49,19 +69,17 @@ public void run(Program program) {
"DynamicSpinLoopDetection cannot be run on already unrolled programs.");

final LoopAnalysis loopAnalysis = LoopAnalysis.newInstance(program);
AnalysisStats stats = new AnalysisStats(0, 0);
AnalysisStats stats = new AnalysisStats(0);
for (Function func : Iterables.concat(program.getFunctions(), program.getThreads())) {
final List<LoopData> loops = computeLoopData(func, loopAnalysis);
loops.forEach(this::collectSideEffects);
loops.forEach(this::reduceToDominatingSideEffects);
final LiveRegistersAnalysis liveRegsAna = LiveRegistersAnalysis.forFunction(func);
loops.forEach(loop -> this.collectSideEffects(loop, liveRegsAna));
loops.forEach(this::instrumentLoop);
stats = stats.add(collectStats(loops));
}
IdReassignment.newInstance().run(program); // Reassign ids for the instrumented code.

// NOTE: We log "potential spin loops" as only those that are not also "static".
logger.info("Found {} static spin loops and {} potential spin loops.",
stats.numStaticSpinLoops, (stats.numPotentialSpinLoops - stats.numStaticSpinLoops));
logger.info("Found {} static spin loops.", stats.numStaticSpinLoops());
}

// ==================================================================================================
Expand All @@ -72,167 +90,108 @@ private List<LoopData> computeLoopData(Function func, LoopAnalysis loopAnalysis)
return loops.stream().map(LoopData::new).toList();
}

private void collectSideEffects(LoopData loop) {
// Loop shape is: loopStart -> Calling intrinsic __VERIFIER_spin_start -> #__VERIFIER_spin_start
if (loop.getStart().getSuccessor().getSuccessor() instanceof SpinStart) {
// A user-placed annotation guarantees absence of side effects.

// This checks assumes the following implementation of await_while
// #define await_while(cond) \
// for (int tmp = (__VERIFIER_loop_begin(), 0); __VERIFIER_spin_start(), \
// tmp = cond, __VERIFIER_spin_end(!tmp), tmp;)
return;
}

// FIXME: The reasoning about safe/unsafe registers is not correct because
// we do not traverse the control-flow but naively go top-down through the loop.
// We need to use proper dominator reasoning!

// Unsafe means the loop reads from the registers before writing to them.
Set<Register> unsafeRegisters = new HashSet<>();
// Safe means the loop wrote to these register before using them
Set<Register> safeRegisters = new HashSet<>();

private void collectSideEffects(LoopData loop, LiveRegistersAnalysis liveRegsAna) {
final Set<Register> writtenRegisters = new HashSet<>();
Event cur = loop.getStart();
do {
if (cur instanceof RegWriter writer) {
writtenRegisters.add(writer.getResultRegister());
}
if (cur.hasTag(Tag.WRITE) || (cur instanceof FunctionCall call &&
(!call.isDirectCall()
|| !call.getCalledFunction().isIntrinsic()
|| call.getCalledFunction().getIntrinsicInfo().writesMemory()))) {
// We assume side effects for all writes, writing intrinsics, and non-intrinsic function calls.
loop.sideEffects.add(cur);
continue;
}

if (cur instanceof RegReader regReader) {
final Set<Register> dataRegs = regReader.getRegisterReads().stream()
.map(Register.Read::register).collect(Collectors.toSet());
unsafeRegisters.addAll(Sets.difference(dataRegs, safeRegisters));
}

if (cur instanceof RegWriter writer) {
if (unsafeRegisters.contains(writer.getResultRegister())) {
// The loop writes to a register it previously read from.
// This means the next loop iteration will observe the newly written value,
// hence the loop is not side effect free.
loop.sideEffects.add(cur);
} else {
safeRegisters.add(writer.getResultRegister());
}
loop.globalSideEffects.add(cur);
}
} while ((cur = cur.getSuccessor()) != loop.getEnd().getSuccessor());

}

private void reduceToDominatingSideEffects(LoopData loop) {
if (loop.sideEffects.isEmpty()) {
return;
}

final List<Event> sideEffects = loop.sideEffects;
final Event start = loop.getStart();
final Event end = loop.getEnd();
final Predicate<Event> isAlwaysSideEffectful = (e -> e.cfImpliesExec() && sideEffects.contains(e));

final DominatorTree<Event> preDominatorTree = DominatorAnalysis.computePreDominatorTree(start, end);
final DominatorTree<Event> postDominatorTree = DominatorAnalysis.computePostDominatorTree(start, end);

// (1) Delete all side effects that are on exit paths (those have no dominator in the post-dominator tree
// because they are no predecessor of the end of the loop body).
final Predicate<Event> isOnExitPath = (e -> postDominatorTree.getImmediateDominator(e) == null);
sideEffects.removeIf(isOnExitPath);

// (2) Check if always side-effect-full at the end of an iteration directly before entering the next one.
// This is an approximation: If the end of the iteration is predominated by some side effect
// then we always observe side effects.
loop.isAlwaysSideEffectful = Streams.stream(preDominatorTree.getDominators(end)).anyMatch(isAlwaysSideEffectful);
if (loop.isAlwaysSideEffectful) {
return;
}

// (3) Delete all side effects that are dominated by another one
// NOTE: This can be implemented more efficiently, but maybe we don't need to.
for (int i = sideEffects.size() - 1; i >= 0; i--) {
final Event sideEffect = sideEffects.get(i);
final Iterable<Event> dominators = Iterables.concat(
preDominatorTree.getStrictDominators(sideEffect),
postDominatorTree.getStrictDominators(sideEffect)
);
final boolean isDominated = Iterables.tryFind(dominators, isAlwaysSideEffectful::test).isPresent();
if (isDominated) {
sideEffects.remove(i);
}
}
// Every live register that is written to is a potential local side effect.
loop.writtenLiveRegisters.addAll(Sets.intersection(
writtenRegisters,
liveRegsAna.getLiveRegistersAt(loop.getStart())
));
}

private void instrumentLoop(LoopData loop) {
hernanponcedeleon marked this conversation as resolved.
Show resolved Hide resolved
if (loop.isAlwaysSideEffectful) {
return;
}

final TypeFactory types = TypeFactory.getInstance();
final ExpressionFactory expressions = ExpressionFactory.getInstance();

final Function func = loop.loopInfo.function();
final int loopNum = loop.loopInfo.loopNumber();

final AggregateType liveRegsType = types.getAggregateType(Lists.transform(loop.writtenLiveRegisters, Register::getType));
final Expression liveRegistersVector = expressions.makeConstruct(liveRegsType, loop.writtenLiveRegisters);
final Register entryLiveStateRegister = func.newRegister("__localLiveSnapshot#" + loopNum, liveRegsType);
final Register tempReg = func.newRegister("__possiblySideEffectless#" + loopNum, types.getBooleanType());
final Register trackingReg = func.newRegister("__sideEffect#" + loopNum, types.getBooleanType());
final Register globalSideEffectReg = func.newRegister("__globalSideEffect#" + loopNum, types.getBooleanType());

final Event init = EventFactory.newLocal(trackingReg, expressions.makeFalse());
loop.getStart().insertAfter(init);
for (Event sideEffect : loop.sideEffects) {
// ---------------- Instrumentation ----------------
// Init tracking registers
loop.getStart().insertAfter(List.of(
EventFactory.newLocal(entryLiveStateRegister, liveRegistersVector),
EventFactory.newLocal(globalSideEffectReg, expressions.makeFalse())
));

// Track global side effects
for (Event sideEffect : loop.globalSideEffects) {
final List<Event> updateSideEffect = new ArrayList<>();
if (sideEffect.cfImpliesExec()) {
updateSideEffect.add(EventFactory.newLocal(trackingReg, expressions.makeTrue()));
updateSideEffect.add(EventFactory.newLocal(globalSideEffectReg, expressions.makeTrue()));
} else {
updateSideEffect.addAll(List.of(
EventFactory.newExecutionStatus(tempReg, sideEffect),
EventFactory.newLocal(trackingReg, expressions.makeOr(trackingReg, expressions.makeNot(tempReg)))
EventFactory.newLocal(globalSideEffectReg, expressions.makeOr(globalSideEffectReg, expressions.makeNot(tempReg)))
));

}
sideEffect.getPredecessor().insertAfter(updateSideEffect);
}

final Event assumeSideEffect = newSpinTerminator(expressions.makeNot(trackingReg), func);
assumeSideEffect.copyAllMetadataFrom(loop.getStart());
loop.getEnd().getPredecessor().insertAfter(assumeSideEffect);
// Check if any local or global side effects occurred. If not, spin!
final Register localSideEffectReg = func.newRegister("__localSideEffect#" + loopNum, types.getBooleanType());
final Expression hasSideEffect = expressions.makeOr(localSideEffectReg, globalSideEffectReg);

final Event assignLocalSideEffectReg = EventFactory.newLocal(localSideEffectReg, expressions.makeNEQ(entryLiveStateRegister, liveRegistersVector));
final Event assumeSideEffect = newSpinTerminator(expressions.makeNot(hasSideEffect), loop);
loop.getEnd().getPredecessor().insertAfter(List.of(
assignLocalSideEffectReg,
assumeSideEffect
));

// Special case: If the loop is fully side-effect-free, we can set its unrolling bound to 1.
if (loop.sideEffects.isEmpty()) {
if (loop.isSideEffectFree()) {
final Event loopBound = EventFactory.Svcomp.newLoopBound(expressions.makeValue(1, types.getArchType()));
loop.getStart().getPredecessor().insertAfter(loopBound);
}
}

private Event newSpinTerminator(Expression guard, Function func) {
private Event newSpinTerminator(Expression guard, LoopData loop) {
final Function func = loop.getStart().getFunction();
final Event terminator = func instanceof Thread thread ?
EventFactory.newJump(guard, (Label) thread.getExit())
: EventFactory.newAbortIf(guard);
terminator.addTags(Tag.SPINLOOP, Tag.NONTERMINATION, Tag.NOOPT);
terminator.addTags(Tag.SPINLOOP, Tag.NONTERMINATION);
hernanponcedeleon marked this conversation as resolved.
Show resolved Hide resolved
terminator.copyAllMetadataFrom(loop.getStart());
return terminator;
}

private AnalysisStats collectStats(List<LoopData> loops) {
int numPotentialSpinLoops = 0;
int numStaticSpinLoops = 0;
for (LoopData loop : loops) {
if (!loop.isAlwaysSideEffectful) {
numPotentialSpinLoops++;
if (loop.sideEffects.isEmpty()) {
numStaticSpinLoops++;
}
}
}
return new AnalysisStats(numPotentialSpinLoops, numStaticSpinLoops);
int numStaticSpinLoops = Math.toIntExact(loops.stream().filter(LoopData::isSideEffectFree).count());
return new AnalysisStats(numStaticSpinLoops);
}

// ==================================================================================================
// Inner data structures

private static class LoopData {
private final LoopAnalysis.LoopInfo loopInfo;
private final List<Event> sideEffects = new ArrayList<>();
private boolean isAlwaysSideEffectful = false;
private final List<Event> globalSideEffects = new ArrayList<>();
private final List<Register> writtenLiveRegisters = new ArrayList<>();

public boolean isSideEffectFree() {
return writtenLiveRegisters.isEmpty() && globalSideEffects.isEmpty();
}

private LoopData(LoopAnalysis.LoopInfo loopInfo) {
this.loopInfo = loopInfo;
Expand All @@ -248,11 +207,10 @@ public String toString() {
}
}

private record AnalysisStats(int numPotentialSpinLoops, int numStaticSpinLoops) {
private record AnalysisStats(int numStaticSpinLoops) {

private AnalysisStats add(AnalysisStats stats) {
return new AnalysisStats(this.numPotentialSpinLoops + stats.numPotentialSpinLoops,
this.numStaticSpinLoops + stats.numStaticSpinLoops);
return new AnalysisStats(this.numStaticSpinLoops + stats.numStaticSpinLoops);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@
import java.util.EnumSet;

import static com.dat3m.dartagnan.configuration.Arch.C11;
import static com.dat3m.dartagnan.configuration.Property.*;
import static com.dat3m.dartagnan.utils.ResourceHelper.getTestResourcePath;
import static com.dat3m.dartagnan.utils.Result.*;
import static org.junit.Assert.assertEquals;
import static com.dat3m.dartagnan.configuration.Property.*;

@RunWith(Parameterized.class)
public class VMMLocksTest extends AbstractCTest {
Expand Down Expand Up @@ -82,7 +82,7 @@ public static Iterable<Object[]> data() throws IOException {
{"mutex_musl-rel2rx_futex", C11, FAIL},
{"mutex_musl-rel2rx_unlock", C11, FAIL},
{"seqlock", C11, PASS},
{"clh_mutex", C11, UNKNOWN},
{"clh_mutex", C11, PASS},
{"clh_mutex-acq2rx", C11, FAIL},
{"ticket_awnsb_mutex", C11, PASS},
{"ticket_awnsb_mutex-acq2rx", C11, FAIL},
Expand Down
Loading