Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow bounds to be saved to and loaded from files #759

Merged
merged 14 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dartagnan/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
<dependency>
<groupId>org.apache.maven</groupId>
<artifactId>maven-model</artifactId>
<version>3.3.9</version>
<version>${maven-model.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>${commons-csv.version}</version>
</dependency>

<!-- Z3 dependency (OS independent) -->
Expand Down
69 changes: 60 additions & 9 deletions dartagnan/src/main/java/com/dat3m/dartagnan/Dartagnan.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.dat3m.dartagnan.program.event.core.Assert;
import com.dat3m.dartagnan.program.event.core.CondJump;
import com.dat3m.dartagnan.program.event.core.Load;
import com.dat3m.dartagnan.program.processing.LoopUnrolling;
import com.dat3m.dartagnan.utils.Result;
import com.dat3m.dartagnan.utils.Utils;
import com.dat3m.dartagnan.utils.options.BaseOptions;
Expand All @@ -34,6 +35,10 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.CharSource;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVParser;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.maven.model.io.xpp3.MavenXpp3Reader;
Expand All @@ -51,15 +56,15 @@

import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.math.BigInteger;
import java.nio.file.Path;
import java.util.*;

import static com.dat3m.dartagnan.GlobalSettings.getOrCreateOutputDirectory;
import static com.dat3m.dartagnan.configuration.OptionInfo.collectOptions;
import static com.dat3m.dartagnan.configuration.OptionNames.PHANTOM_REFERENCES;
import static com.dat3m.dartagnan.configuration.OptionNames.TARGET;
import static com.dat3m.dartagnan.configuration.OptionNames.*;
import static com.dat3m.dartagnan.configuration.Property.*;
import static com.dat3m.dartagnan.program.analysis.SyntacticContextAnalysis.*;
import static com.dat3m.dartagnan.utils.GitInfo.*;
Expand Down Expand Up @@ -333,17 +338,26 @@ public static String generateResultSummary(VerificationTask task, ProverEnvironm
}
} else if (result == UNKNOWN && modelChecker.hasModel()) {
// We reached unrolling bounds.
summary.append("=========== Not fully unrolled loops ============\n");
final List<Event> reachedBounds = new ArrayList<>();
for (Event ev : p.getThreadEventsWithAllTags(Tag.BOUND)) {
final boolean isReached = TRUE.equals(model.evaluate(encCtx.execution(ev)));
if (isReached) {
summary
.append("\t")
.append(synContext.getSourceLocationWithContext(ev, true))
.append("\n");
if (TRUE.equals(model.evaluate(encCtx.execution(ev)))) {
reachedBounds.add(ev);
}
}
summary.append("=========== Not fully unrolled loops ============\n");
for (Event bound : reachedBounds) {
summary
.append("\t")
.append(synContext.getSourceLocationWithContext(bound, true))
.append("\n");
}
summary.append("=================================================\n");

try {
increaseBoundAndDump(reachedBounds, task.getConfig());
} catch (IOException e) {
logger.warn("Failed to save bounds file: {}", e.getLocalizedMessage());
}
}
summary.append(result).append("\n");
} else {
Expand Down Expand Up @@ -398,6 +412,43 @@ public static String generateResultSummary(VerificationTask task, ProverEnvironm
return summary.toString();
}

private static void increaseBoundAndDump(List<Event> boundEvents, Configuration config) throws IOException {
final File boundsFile = new File(config.hasProperty(BOUNDS_SAVE_PATH) ?
config.getProperty(BOUNDS_SAVE_PATH) :
GlobalSettings.getBoundsFile());

// Parse old entries
final List<CSVRecord> entries;
try (CSVParser parser = CSVParser.parse(new FileReader(boundsFile), CSVFormat.DEFAULT)) {
entries = parser.getRecords();
}

// Compute update for entries
final Map<Integer, Integer> loopId2UpdatedBound = new HashMap<>();
for (Event e : boundEvents) {
assert e instanceof CondJump;
final CondJump loopJump = (CondJump) e;
final int loopId = LoopUnrolling.getPersistentLoopId(loopJump);
final int bound = LoopUnrolling.getUnrollingBoundAnnotation(loopJump);
loopId2UpdatedBound.put(loopId, bound + 1);
}

// Write new entries
try (CSVPrinter csvPrinter = new CSVPrinter(new FileWriter(boundsFile, false), CSVFormat.DEFAULT)) {
for (CSVRecord entry : entries) {
final int entryId = Integer.parseInt(entry.get(0));
if (!loopId2UpdatedBound.containsKey(entryId)) {
csvPrinter.printRecord(entry);
} else {
final String[] content = entry.values();
content[1] = String.valueOf(loopId2UpdatedBound.get(entryId));
csvPrinter.printRecord(Arrays.asList(content));
}
}
csvPrinter.flush();
}
}

private static void printWarningIfThreadStartFailed(Program p, EncodingContext encoder, ProverEnvironment prover)
throws SolverException {
for (Event e : p.getThreadEvents()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ public static String getCatDirectory() {
return env + "/cat";
}

public static String getBoundsFile() {
return getOutputDirectory() + "/bounds.csv";
}

hernanponcedeleon marked this conversation as resolved.
Show resolved Hide resolved
public static String getOrCreateOutputDirectory() throws IOException {
String path = getOutputDirectory();
Files.createDirectories(Paths.get(path));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ public class OptionNames {
// Base Options
public static final String PROPERTY = "property";
public static final String BOUND = "bound";
public static final String BOUNDS_LOAD_PATH = "bound.load";
public static final String BOUNDS_SAVE_PATH = "bound.save";
public static final String TARGET = "target";
public static final String METHOD = "method";
public static final String SOLVER = "solver";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.dat3m.dartagnan.program.event.metadata;

public record UnrollingBound(int value) implements Metadata { }
Original file line number Diff line number Diff line change
@@ -1,25 +1,37 @@
package com.dat3m.dartagnan.program.processing;

import com.dat3m.dartagnan.GlobalSettings;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.program.Function;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.analysis.SyntacticContextAnalysis;
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.EventUser;
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.event.core.CondJump;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.lang.svcomp.LoopBound;
import com.dat3m.dartagnan.program.event.metadata.UnrollingBound;
import com.dat3m.dartagnan.program.event.metadata.UnrollingId;
import com.google.common.base.Preconditions;
import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.sosy_lab.common.configuration.*;

import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.*;

import static com.dat3m.dartagnan.configuration.OptionNames.BOUND;
import static com.dat3m.dartagnan.configuration.OptionNames.*;

@Options
public class LoopUnrolling implements ProgramProcessor {
Expand All @@ -39,15 +51,33 @@ public class LoopUnrolling implements ProgramProcessor {
@IntegerOption(min = 1)
private int bound = 1;

public int getUnrollingBound() { return bound; }
public int getUnrollingBound() {
return bound;
}

public void setUnrollingBound(int bound) {
Preconditions.checkArgument(bound >= 1, "The unrolling bound must be positive.");
this.bound = bound;
}

@Option(name = BOUNDS_LOAD_PATH,
description = "Path to the CSV file containing loop bounds.",
secure = true)
private String boundsLoadPath = "";

@Option(name = BOUNDS_SAVE_PATH,
description = "Path to the CSV file to save loop bounds.",
secure = true)
private String boundsSavePath = GlobalSettings.getBoundsFile();

// =====================================================================

private LoopUnrolling() { }
// We use this once for loading bounds from files (if any), and then to track
// all computed loop bounds for later storing them into a file (if desired).
private Map<Function, Map<CondJump, Integer>> globalLoopBoundsMap = new HashMap<>();

private LoopUnrolling() {
}

private LoopUnrolling(Configuration config) throws InvalidConfigurationException {
this();
Expand All @@ -69,16 +99,21 @@ public void run(Program program) {
logger.warn("Skipped unrolling: Program is already unrolled.");
return;
}

globalLoopBoundsMap = loadLoopBoundsMapFromFile(program, boundsLoadPath);

final int defaultBound = this.bound;
program.getFunctions().forEach(this::run);
program.getThreads().forEach(this::run);
program.markAsUnrolled(defaultBound);
IdReassignment.newInstance().run(program); // Reassign ids because of newly created events

dumpLoopBoundsMapToFile(program, globalLoopBoundsMap, boundsSavePath);
globalLoopBoundsMap = null; // Save up some memory

logger.info("Program unrolled {} times", defaultBound);
}


private void run(Function function) {
function.getEvents().forEach(e -> e.setMetadata(new UnrollingId(e.getGlobalId()))); // Track ids before unrolling
unrollLoopsInFunction(function, bound);
Expand All @@ -95,7 +130,6 @@ private void unrollLoopsInFunction(Function func, int defaultBound) {
}

private Map<CondJump, Integer> computeLoopBoundsMap(Function func, int defaultBound) {

LoopBound curBoundAnnotation = null;
final Map<CondJump, Integer> loopBoundsMap = new HashMap<>();
for (Event event : func.getEvents()) {
Expand All @@ -119,6 +153,15 @@ private Map<CondJump, Integer> computeLoopBoundsMap(Function func, int defaultBo
}
}
}

// Merge with loaded bounds if those exist.
if(globalLoopBoundsMap.containsKey(func)) {
final Map<CondJump, Integer> loopBoundsMapFromFile = globalLoopBoundsMap.get(func);
loopBoundsMapFromFile.forEach((key, value) -> loopBoundsMap.merge(key, value, Math::max));
}
// Remember bounds for function for dumping.
globalLoopBoundsMap.put(func, loopBoundsMap);

return loopBoundsMap;
}

Expand Down Expand Up @@ -146,6 +189,7 @@ private void unrollLoop(CondJump loopBackJump, int bound) {
boundEvent.getPredecessor().insertAfter(endOfLoopMarker);

boundEvent.copyAllMetadataFrom(loopBackJump);
boundEvent.setMetadata(new UnrollingBound(bound));
endOfLoopMarker.copyAllMetadataFrom(loopBackJump);

} else {
Expand Down Expand Up @@ -195,4 +239,82 @@ private Event newBoundEvent(Function func) {
return boundEvent;
}

// ------------------------------------------------------------------------
// Functions related to loading and storing bound maps

private boolean pathIsSpecified(String path) {
return !path.isEmpty();
}

public static int getPersistentLoopId(CondJump loopBackjump) {
final UnrollingId id = loopBackjump.getMetadata(UnrollingId.class);
return id != null ? id.value() : loopBackjump.getGlobalId();
}

public static int getUnrollingBoundAnnotation(CondJump boundEvent) {
Preconditions.checkArgument(boundEvent.hasTag(Tag.BOUND));
return boundEvent.getMetadata(UnrollingBound.class).value();
}

private Map<Function, Map<CondJump, Integer>> loadLoopBoundsMapFromFile(Program program, String filePath) {
if (!pathIsSpecified(filePath)) {
return new HashMap<>();
}
if (!Files.exists(Path.of(filePath))) {
logger.warn("There is no bounds file at path {} . Using default bounds.", filePath);
return new HashMap<>();
}

// Compute mapping from ids to loop events
final Map<Integer, CondJump> idToJump = new HashMap<>();
program.getFunctions().forEach(f -> f.getEvents(CondJump.class).forEach(
jump -> idToJump.put(getPersistentLoopId(jump), jump))
);

// Read CSV file to find bounds for loop events
final Map<Function, Map<CondJump, Integer>> loopBoundsMapPerFunction = new HashMap<>();
try (Reader reader = new FileReader(filePath)) {
Iterable<CSVRecord> records = CSVFormat.DEFAULT.parse(reader);
for (CSVRecord record : records) {
final int loopId = Integer.parseInt(record.get(0));
final int bound = Integer.parseInt(record.get(1));
final CondJump loopJump = idToJump.get(loopId);
if (loopJump == null) {
logger.warn("Loaded bounds file does not match with the program. Ignoring file.");
loopBoundsMapPerFunction.clear();
break;
}
loopBoundsMapPerFunction
.computeIfAbsent(loopJump.getFunction(), key -> new HashMap<>())
.put(loopJump, bound);
}
} catch (IOException e) {
logger.warn("Failed to read bounds file: {}", e.getLocalizedMessage());
}

return loopBoundsMapPerFunction;
}

private void dumpLoopBoundsMapToFile(Program program, Map<Function, Map<CondJump, Integer>> loopBounds, String filePath) {
if (!pathIsSpecified(filePath)) {
return;
}

final SyntacticContextAnalysis synContext = SyntacticContextAnalysis.newInstance(program);
try (CSVPrinter csvPrinter = new CSVPrinter( new FileWriter(filePath, false), CSVFormat.DEFAULT)) {
for (Map<CondJump, Integer> loopBoundsMap : loopBounds.values()) {
for (Map.Entry<CondJump, Integer> entry : loopBoundsMap.entrySet()) {
final CondJump loopJump = entry.getKey();
final int loopId = getPersistentLoopId(loopJump);
final int loopBound = entry.getValue();
final String sourceLoc = synContext.getSourceLocationWithContext(loopJump, false);
csvPrinter.printRecord(loopId, loopBound, sourceLoc);
}
}
csvPrinter.flush();
} catch (IOException e) {
logger.warn("Failed to save bounds file: {}", e.getLocalizedMessage());
}
}

}
2 changes: 2 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
<guava.version>32.1.2-jre</guava.version>
<junit.version>4.13.2</junit.version>
<log4j.version>2.23.0</log4j.version>
<maven-model.version>3.3.9</maven-model.version>
<commons-csv.version>1.12.0</commons-csv.version>
<mockito.version>5.11.0</mockito.version>
<rsyntaxtextarea.version>3.3.4</rsyntaxtextarea.version>

Expand Down
Loading