-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introducing maraboupy: a python module that wraps Marabou and enables easy interaction with Tensorflow networks and .nnet files. * Python Wrapper -Abstract classes needed for maraboupy -TF Parser -NNet Parser -Example scripts -Example Notebooks -Example TF network proto buffers. * Refactored addReluConstraint into MarabouCore * Left all original Marabou code untouched and moved all needed modifications to maraboucore.cpp in maraboupy. * Maraboupy now uses the same Rules.mk as Marabou and uses its own Makefile for its flags.
- Loading branch information
Showing
67 changed files
with
2,709 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
ROOT_DIR = ../ | ||
|
||
SUBDIRS += \ | ||
|
||
PYBIND11_INCLUDES = $(shell python3 -m pybind11 --includes) | ||
|
||
LOCAL_INCLUDES += \ | ||
.. \ | ||
$(BASIS_FACTORIZATION_DIR) \ | ||
$(CONFIGURATION_DIR) \ | ||
$(ENGINE_DIR) \ | ||
$(PYBIND11_INCLUDES) \ | ||
|
||
LINK_FLAGS += \ | ||
|
||
LOCAL_LIBRARIES += \ | ||
|
||
CFLAGS += \ | ||
-DDEBUG_ON \ | ||
|
||
SUFFIX = $(shell python3-config --extension-suffix) | ||
API_NAME = $(addprefix MarabouCore, $(SUFFIX)) | ||
|
||
SOURCES += \ | ||
GlobalConfiguration.cpp \ | ||
\ | ||
Errno.cpp \ | ||
Error.cpp \ | ||
FloatUtils.cpp \ | ||
MString.cpp \ | ||
TimeUtils.cpp \ | ||
\ | ||
BasisFactorizationFactory.cpp \ | ||
BlandsRule.cpp \ | ||
ConstraintMatrixAnalyzer.cpp \ | ||
CostFunctionManager.cpp \ | ||
CostFunctionManagerFactory.cpp \ | ||
DantzigsRule.cpp \ | ||
DegradationChecker.cpp \ | ||
Engine.cpp \ | ||
EngineState.cpp \ | ||
EntrySelectionStrategy.cpp \ | ||
Equation.cpp \ | ||
EtaMatrix.cpp \ | ||
ForrestTomlinFactorization.cpp \ | ||
FreshVariables.cpp \ | ||
InputQuery.cpp \ | ||
LPElement.cpp \ | ||
LUFactorization.cpp \ | ||
PermutationMatrix.cpp \ | ||
PiecewiseLinearCaseSplit.cpp \ | ||
PiecewiseLinearConstraint.cpp \ | ||
PrecisionRestorer.cpp \ | ||
Preprocessor.cpp \ | ||
ProjectedSteepestEdge.cpp \ | ||
ProjectedSteepestEdgeFactory.cpp \ | ||
ReluConstraint.cpp \ | ||
RowBoundTightener.cpp \ | ||
RowBoundTightenerFactory.cpp \ | ||
SmtCore.cpp \ | ||
Statistics.cpp \ | ||
Tableau.cpp \ | ||
TableauFactory.cpp \ | ||
TableauRow.cpp \ | ||
TableauState.cpp \ | ||
\ | ||
MarabouCore.cpp \ | ||
|
||
TARGET = $(API_NAME) | ||
|
||
include ../Rules.mk | ||
|
||
COMPILE = c++ -std=c++11 -O3 -fPIC $(PYBIND11_INCLUDES) | ||
LINK = c++ -std=c++11 -O3 -fPIC -shared -Wl,-undefined,dynamic_lookup $(PYBIND11_INCLUDES) | ||
|
||
$(TARGET): $(OBJECTS) | ||
@echo "LD\t" $@ | ||
@$(LINK) $(LINK_FLAGS) -o $@ $^ $(addprefix -l, $(SYSTEM_LIBRARIES)) $(addprefix -l, $(LOCAL_LIBRARIES)) | ||
|
||
vpath %.cpp $(BASIS_FACTORIZATION_DIR) | ||
vpath %.cpp $(CONFIGURATION_DIR) | ||
vpath %.cpp $(ENGINE_DIR) | ||
vpath %.cpp $(ENGINE_REAL_DIR) | ||
vpath %.cpp $(COMMON_DIR) | ||
vpath %.cpp $(COMMON_REAL_DIR) | ||
|
||
# | ||
# Local Variables: | ||
# compile-command: "make -C ../../.. " | ||
# tags-file-name: "../../../TAGS" | ||
# c-basic-offset: 4 | ||
# End: | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#Marabou File | ||
from .MarabouNetworkNNet import * | ||
from .MarabouNetworkTF import * | ||
|
||
def read_nnet(filename): | ||
""" | ||
Constructs a MarabouNetworkNnet object from a .nnet file | ||
Args: | ||
filename: (string) path to the .nnet file. | ||
Returns: | ||
marabouNetworkNNet: (MarabouNetworkNNet) representing network | ||
""" | ||
return MarabouNetworkNNet(filename) | ||
|
||
|
||
def read_tf(filename, inputName=None, outputName=None): | ||
""" | ||
Constructs a MarabouNetworkTF object from a frozen Tensorflow protobuf | ||
Args: | ||
filename: (string) path to the .nnet file. | ||
inputName: (string) optional, name of operation corresponding to input | ||
outputName: (string) optional, name of operation corresponding to output | ||
Returns: | ||
marabouNetworkTF: (MarabouNetworkTF) representing network | ||
""" | ||
return MarabouNetworkTF(filename) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
#include <map> | ||
#include <vector> | ||
#include <string> | ||
#include <unistd.h> | ||
#include <sys/stat.h> | ||
#include <sys/types.h> | ||
#include <fcntl.h> | ||
#include "Engine.h" | ||
#include "InputQuery.h" | ||
#include "ReluplexError.h" | ||
#include "FloatUtils.h" | ||
#include "PiecewiseLinearConstraint.h" | ||
#include "ReluConstraint.h" | ||
|
||
namespace py = pybind11; | ||
|
||
int redirectOutputToFile(std::string outputFilePath){ | ||
// Redirect standard output to a file | ||
int outputFile = open(outputFilePath.c_str(), O_WRONLY | O_CREAT | O_TRUNC, 0644); | ||
if ( outputFile < 0 ) | ||
{ | ||
printf( "Error redirecting output to file\n"); | ||
exit( 1 ); | ||
} | ||
|
||
int outputStream = dup( STDOUT_FILENO ); | ||
if (outputStream < 0) | ||
{ | ||
printf( "Error duplicating standard output\n" ); | ||
exit(1); | ||
} | ||
|
||
if ( dup2( outputFile, STDOUT_FILENO ) < 0 ) | ||
{ | ||
printf("Error duplicating to standard output\n"); | ||
exit(1); | ||
} | ||
|
||
close( outputFile ); | ||
return outputStream; | ||
} | ||
|
||
void restoreOutputStream(int outputStream) | ||
{ | ||
// Restore standard output | ||
fflush( stdout ); | ||
if (dup2( outputStream, STDOUT_FILENO ) < 0){ | ||
printf( "Error restoring output stream\n" ); | ||
exit( 1 ); | ||
} | ||
close(outputStream); | ||
} | ||
|
||
void addReluConstraint(InputQuery& ipq, unsigned var1, unsigned var2){ | ||
PiecewiseLinearConstraint* r = new ReluConstraint(var1, var2); | ||
ipq.addPiecewiseLinearConstraint(r); | ||
} | ||
|
||
std::map<int, double> solve(InputQuery inputQuery, std::string redirect=""){ | ||
// Arguments: InputQuery object, filename to redirect output | ||
// Returns: map from variable number to value | ||
std::map<int, double> ret; | ||
int output=-1; | ||
if(redirect.length()>0) | ||
output=redirectOutputToFile(redirect); | ||
try{ | ||
Engine engine; | ||
if(!engine.processInputQuery(inputQuery)) return ret; | ||
|
||
if(!engine.solve()) return ret; | ||
|
||
engine.extractSolution(inputQuery); | ||
for(unsigned int i=0; i<inputQuery.getNumberOfVariables(); i++) | ||
ret[i] = inputQuery.getSolutionValue(i); | ||
} | ||
catch(const ReluplexError &e){ | ||
printf( "Caught a ReluplexError. Code: %u. Message: %s\n", e.getCode(), e.getUserMessage() ); | ||
return ret; | ||
} | ||
if(output != -1) | ||
restoreOutputStream(output); | ||
return ret; | ||
} | ||
|
||
// Code necessary to generate Python library | ||
// Describes which classes and functions are exposed to API | ||
PYBIND11_MODULE(MarabouCore, m) { | ||
m.doc() = "Marabou API Library"; | ||
m.def("solve", &solve, "Takes in a description of the InputQuery and returns the solution"); | ||
m.def("addReluConstraint", &addReluConstraint, "Add a Relu constraint to the InputQuery"); | ||
py::class_<InputQuery>(m, "InputQuery") | ||
.def(py::init()) | ||
.def("setUpperBound", &InputQuery::setUpperBound) | ||
.def("setLowerBound", &InputQuery::setLowerBound) | ||
.def("getUpperBound", &InputQuery::getUpperBound) | ||
.def("getLowerBound", &InputQuery::getLowerBound) | ||
.def("setNumberOfVariables", &InputQuery::setNumberOfVariables) | ||
.def("addEquation", &InputQuery::addEquation); | ||
py::class_<Equation>(m, "Equation") | ||
.def(py::init()) | ||
.def("addAddend", &Equation::addAddend) | ||
.def("setScalar", &Equation::setScalar) | ||
.def("markAuxiliaryVariable", &Equation::markAuxiliaryVariable); | ||
} |
Oops, something went wrong.