Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow bounds to be saved to and loaded from files #759

Merged
merged 14 commits into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion dartagnan/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@
<dependency>
<groupId>org.apache.maven</groupId>
<artifactId>maven-model</artifactId>
<version>3.3.9</version>
<version>${maven-model.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-csv</artifactId>
<version>${commons-csv.version}</version>
</dependency>

<!-- Z3 dependency (OS independent) -->
Expand Down
50 changes: 50 additions & 0 deletions dartagnan/src/main/java/com/dat3m/dartagnan/Dartagnan.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import com.dat3m.dartagnan.program.event.core.Assert;
import com.dat3m.dartagnan.program.event.core.CondJump;
import com.dat3m.dartagnan.program.event.core.Load;
import com.dat3m.dartagnan.program.event.metadata.UnrollingBound;
import com.dat3m.dartagnan.program.event.metadata.UnrollingId;
import com.dat3m.dartagnan.utils.Result;
import com.dat3m.dartagnan.utils.Utils;
import com.dat3m.dartagnan.utils.options.BaseOptions;
Expand All @@ -34,6 +36,10 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.io.CharSource;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.maven.model.io.xpp3.MavenXpp3Reader;
Expand All @@ -51,13 +57,17 @@

import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.math.BigInteger;
import java.nio.file.Path;
import java.util.*;

import static com.dat3m.dartagnan.GlobalSettings.getOrCreateOutputDirectory;
import static com.dat3m.dartagnan.configuration.OptionInfo.collectOptions;
import static com.dat3m.dartagnan.configuration.OptionNames.BOUNDS_SAVE_PATH;
import static com.dat3m.dartagnan.configuration.OptionNames.PHANTOM_REFERENCES;
import static com.dat3m.dartagnan.configuration.OptionNames.TARGET;
import static com.dat3m.dartagnan.configuration.Property.*;
Expand Down Expand Up @@ -341,6 +351,7 @@ public static String generateResultSummary(VerificationTask task, ProverEnvironm
.append("\t")
.append(synContext.getSourceLocationWithContext(ev, true))
.append("\n");
increaseBoundAndDump(ev, task.getConfig());
}
}
summary.append("=================================================\n");
Expand Down Expand Up @@ -398,6 +409,45 @@ public static String generateResultSummary(VerificationTask task, ProverEnvironm
return summary.toString();
}

private static void increaseBoundAndDump(Event ev, Configuration config) {

String evId = String.valueOf(ev.getMetadata(UnrollingId.class).value());
String incBound = String.valueOf(ev.getMetadata(UnrollingBound.class).value() + 1);

// We read from and write to the same CSV file,
// thus we need to split this in two loops
List<String[]> modifiedRecords = new ArrayList<>();
// We read the file written by the LoopUnrolling pass,
// thus we use BOUNDS_SAVE_PATH also for the reader
try (Reader reader = new FileReader(config.hasProperty(BOUNDS_SAVE_PATH) ?
config.getProperty(BOUNDS_SAVE_PATH) :
GlobalSettings.getBoundsFile())) {
for (CSVRecord record : CSVFormat.DEFAULT.parse(reader)) {
ThomasHaas marked this conversation as resolved.
Show resolved Hide resolved
String nextId = record.get(0);
String nextBound = record.get(1);
String sourceLoc = record.get(2);
if (nextId.equals(evId)) {
nextBound = incBound;
}
modifiedRecords.add(new String[] { nextId, nextBound, sourceLoc });
}
} catch (IOException e) {
e.printStackTrace();
}

try (Writer writer = new FileWriter(config.hasProperty(BOUNDS_SAVE_PATH) ?
config.getProperty(BOUNDS_SAVE_PATH) :
GlobalSettings.getBoundsFile(), false);
CSVPrinter csvPrinter = new CSVPrinter(writer, CSVFormat.DEFAULT)) {
for (String[] record : modifiedRecords) {
csvPrinter.printRecord(record[0], record[1], record[2]);
}
csvPrinter.flush();
} catch (IOException e) {
e.printStackTrace();
}
}

private static void printWarningIfThreadStartFailed(Program p, EncodingContext encoder, ProverEnvironment prover)
throws SolverException {
for (Event e : p.getThreadEvents()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ public static String getCatDirectory() {
return env + "/cat";
}

public static String getBoundsFile() {
return getOutputDirectory() + "/bounds.csv";
}

hernanponcedeleon marked this conversation as resolved.
Show resolved Hide resolved
public static String getOrCreateOutputDirectory() throws IOException {
String path = getOutputDirectory();
Files.createDirectories(Paths.get(path));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ public class OptionNames {
// Base Options
public static final String PROPERTY = "property";
public static final String BOUND = "bound";
public static final String BOUNDS_LOAD_PATH = "bound.load";
public static final String BOUNDS_SAVE_PATH = "bound.save";
public static final String TARGET = "target";
public static final String METHOD = "method";
public static final String SOLVER = "solver";
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package com.dat3m.dartagnan.program.event.metadata;

public record UnrollingBound(int value) implements Metadata { }
Original file line number Diff line number Diff line change
@@ -1,25 +1,38 @@
package com.dat3m.dartagnan.program.processing;

import com.dat3m.dartagnan.GlobalSettings;
import com.dat3m.dartagnan.expression.ExpressionFactory;
import com.dat3m.dartagnan.program.Function;
import com.dat3m.dartagnan.program.Program;
import com.dat3m.dartagnan.program.Thread;
import com.dat3m.dartagnan.program.analysis.SyntacticContextAnalysis;
import com.dat3m.dartagnan.program.event.Event;
import com.dat3m.dartagnan.program.event.EventFactory;
import com.dat3m.dartagnan.program.event.EventUser;
import com.dat3m.dartagnan.program.event.Tag;
import com.dat3m.dartagnan.program.event.core.CondJump;
import com.dat3m.dartagnan.program.event.core.Label;
import com.dat3m.dartagnan.program.event.lang.svcomp.LoopBound;
import com.dat3m.dartagnan.program.event.metadata.UnrollingBound;
import com.dat3m.dartagnan.program.event.metadata.UnrollingId;
import com.google.common.base.Preconditions;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.sosy_lab.common.configuration.*;

import org.apache.commons.csv.CSVFormat;
import org.apache.commons.csv.CSVPrinter;
import org.apache.commons.csv.CSVRecord;

import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.io.Reader;
import java.io.Writer;
import java.util.*;
import java.util.function.Predicate;

import static com.dat3m.dartagnan.configuration.OptionNames.BOUND;
import static com.dat3m.dartagnan.configuration.OptionNames.*;

@Options
public class LoopUnrolling implements ProgramProcessor {
Expand All @@ -45,6 +58,16 @@ public void setUnrollingBound(int bound) {
this.bound = bound;
}

@Option(name = BOUNDS_LOAD_PATH,
description = "Path to the CSV file containing loop bounds.",
secure = true)
private String bounds_load_path = "";

@Option(name = BOUNDS_SAVE_PATH,
description = "Path to the CSV file to save loop bounds.",
secure = true)
private String bounds_save_path = GlobalSettings.getBoundsFile();

// =====================================================================

private LoopUnrolling() { }
Expand All @@ -69,6 +92,7 @@ public void run(Program program) {
logger.warn("Skipped unrolling: Program is already unrolled.");
return;
}
createBoundsFileIfMissing();
final int defaultBound = this.bound;
program.getFunctions().forEach(this::run);
program.getThreads().forEach(this::run);
Expand All @@ -78,7 +102,6 @@ public void run(Program program) {
logger.info("Program unrolled {} times", defaultBound);
}


private void run(Function function) {
function.getEvents().forEach(e -> e.setMetadata(new UnrollingId(e.getGlobalId()))); // Track ids before unrolling
unrollLoopsInFunction(function, bound);
Expand All @@ -89,13 +112,36 @@ private void unrollLoopsInFunction(Function func, int defaultBound) {
return;
}
final Map<CondJump, Integer> loopBoundsMap = computeLoopBoundsMap(func, defaultBound);
Map<CondJump, Integer> mergedBounds = new HashMap<>(loopBoundsMap);
if(!bounds_load_path.isEmpty()) {
final Map<CondJump, Integer> loopBoundsMapFromFile = loadLoopBoundsMapFromFile(func);
loopBoundsMapFromFile.forEach((key, value) -> mergedBounds.merge(key, value, Math::max));
}
func.getEvents(CondJump.class).stream()
.filter(loopBoundsMap::containsKey)
.forEach(j -> unrollLoop(j, loopBoundsMap.get(j)));
.filter(mergedBounds::containsKey)
.forEach(j -> unrollLoop(j, mergedBounds.get(j)));
}

private Map<CondJump, Integer> loadLoopBoundsMapFromFile(Function func) {
Map<CondJump, Integer> loopBoundsMapFromFile = new HashMap<>();
try (Reader reader = new FileReader(bounds_load_path)) {
Iterable<CSVRecord> records = CSVFormat.DEFAULT.parse(reader);
for (CSVRecord record : records) {
int nexId = Integer.parseInt(record.get(0));
int nextBound = Integer.parseInt(record.get(1));
Predicate<Event> predicate = e -> e.getGlobalId() == nexId;
if(func.getEvents(CondJump.class).stream().anyMatch(predicate)) {
CondJump loop = func.getEvents(CondJump.class).stream().filter(predicate).findAny().get();
loopBoundsMapFromFile.put(loop, nextBound);
}
}
} catch (IOException e) {
e.printStackTrace();
}
return loopBoundsMapFromFile;
}

private Map<CondJump, Integer> computeLoopBoundsMap(Function func, int defaultBound) {

LoopBound curBoundAnnotation = null;
final Map<CondJump, Integer> loopBoundsMap = new HashMap<>();
for (Event event : func.getEvents()) {
Expand Down Expand Up @@ -128,6 +174,7 @@ private void unrollLoop(CondJump loopBackJump, int bound) {
Preconditions.checkArgument(loopBegin.getLocalId() < loopBackJump.getLocalId(),
"The jump does not belong to a loop.");

dumpBoundIfMissing(loopBackJump, bound);
int iterCounter = 0;
while (++iterCounter <= bound) {
if (iterCounter == bound) {
Expand All @@ -146,6 +193,7 @@ private void unrollLoop(CondJump loopBackJump, int bound) {
boundEvent.getPredecessor().insertAfter(endOfLoopMarker);

boundEvent.copyAllMetadataFrom(loopBackJump);
boundEvent.setMetadata(new UnrollingBound(bound));
endOfLoopMarker.copyAllMetadataFrom(loopBackJump);

} else {
Expand Down Expand Up @@ -195,4 +243,38 @@ private Event newBoundEvent(Function func) {
return boundEvent;
}

private void createBoundsFileIfMissing() {
File file = new File(bounds_save_path);
if (!file.exists()) {
try {
file.createNewFile();
} catch (IOException e) {
e.printStackTrace();
}
}
}

private void dumpBoundIfMissing(Event jump, Integer bound) {
String evId = String.valueOf(jump.getMetadata(UnrollingId.class).value());
final SyntacticContextAnalysis synContext = SyntacticContextAnalysis.newInstance(jump.getFunction().getProgram());
String sourceLoc = synContext.getSourceLocationWithContext(jump, false);
try (Reader reader = new FileReader(bounds_load_path);
Writer writer = new FileWriter(bounds_save_path, true);
CSVPrinter csvPrinter = new CSVPrinter(writer, CSVFormat.DEFAULT)) {
boolean found = false;
for (CSVRecord record : CSVFormat.DEFAULT.parse(reader)) {
String nextId = record.get(0);
if (found = nextId.equals(evId)) {
break;
}
}
if (!found) {
csvPrinter.printRecord(evId, bound.toString(), sourceLoc);
}
csvPrinter.flush();
} catch (IOException e) {
e.printStackTrace();
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you sure that you just want to print the stack trace and then keep going? The code will just keep running with an incomplete file and will likely trigger the same error a few more times since the file is accessed multiple times (this commenet applies to all the cases where you do this).
I would rather throw a terminating exception or log a warning (only once!) and reset whatever has been computed so far as if no file was provided (e.g., when computing the bounds).

}

}
2 changes: 2 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
<guava.version>32.1.2-jre</guava.version>
<junit.version>4.13.2</junit.version>
<log4j.version>2.23.0</log4j.version>
<maven-model.version>3.3.9</maven-model.version>
<commons-csv.version>1.12.0</commons-csv.version>
<mockito.version>5.11.0</mockito.version>
<rsyntaxtextarea.version>3.3.4</rsyntaxtextarea.version>

Expand Down
Loading