Skip to content

Commit

Permalink
Bias add improvements (#23)
Browse files Browse the repository at this point in the history
* Initial commit

* Fixed bugs

* Big fixes

* Slight bug fix
  • Loading branch information
ShantanuThakoor authored Apr 28, 2018
1 parent 54ef90b commit 2fcbce7
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 11 deletions.
32 changes: 32 additions & 0 deletions maraboupy/MarabouNetwork.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,38 @@ def addMaxConstraint(self, elements, v):
"""
self.maxList += [(elements, v)]

def lowerBoundExists(self, x):
"""
Function to check whether lower bound for a variable is known
Arguments:
x: (int) variable to check
"""
return x in self.lowerBounds

def upperBoundExists(self, x):
"""
Function to check whether upper bound for a variable is known
Arguments:
x: (int) variable to check
"""
return x in self.upperBounds

def participatesInPLConstraint(self, x):
"""
Function to check whether variable participates in any piecewise linear constraint in this network
Arguments:
x: (int) variable to check
"""
# ReLUs
fs, bs = zip(*self.reluList)
if x in fs or x in bs:
return True
# Max constraints
for elems, var in self.maxList:
if x in elems or x==var:
return True
return False

def getMarabouQuery(self):
"""
Function to convert network into Marabou Query
Expand Down
49 changes: 38 additions & 11 deletions maraboupy/MarabouNetworkTF.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ def __init__(self, filename, inputName=None, outputName=None, savedModel=False,
savedModelTags: (list of strings) If loading a SavedModel, the user must specify tags used.
"""
super().__init__()
self.biasAddRelations = list()
self.readFromPb(filename, inputName, outputName, savedModel, savedModelTags)
self.processBiasAddRelations()

def clear(self):
"""
Expand All @@ -36,6 +38,7 @@ def clear(self):
self.inputOp = None
self.outputOp = None
self.sess = None
self.biasAddRelations = list()

def readFromPb(self, filename, inputName, outputName, savedModel, savedModelTags):
"""
Expand Down Expand Up @@ -237,17 +240,41 @@ def biasAddEquations(self, op):
assert len(prevVars)==len(curVars) and len(curVars)==len(prevConsts)
### END getting inputs ###

### Generate actual equations ###
for i in range(len(curVars)):
e = MarabouUtils.Equation()
e.addAddend(1, prevVars[i])
e.addAddend(-1, curVars[i])
e.setScalar(-prevConsts[i])
aux = self.getNewVariable()
self.setLowerBound(aux, 0.0)
self.setUpperBound(aux, 0.0)
e.markAuxiliaryVariable(aux)
self.addEquation(e)
### Do not generate equations, as these can be eliminated ###
for i in range(len(prevVars)):
# prevVars = curVars - prevConst
self.biasAddRelations += [(prevVars[i], curVars[i], -prevConsts[i])]

def processBiasAddRelations(self):
"""
Either add an equation representing a bias add,
Or eliminate one of the two variables in every other relation
"""
biasAddUpdates = dict()
for (x, xprime, c) in self.biasAddRelations:
# x = xprime + c
# replace x only if it does not occur anywhere else in the system
if self.lowerBoundExists(x) or self.upperBoundExists(x) or self.participatesInPLConstraint(x):
e = MarabouUtils.Equation()
e.addAddend(1.0, x)
e.addAddend(-1.0, xprime)
e.setScalar(c)
aux = self.getNewVariable()
self.setLowerBound(aux, 0.0)
self.setUpperBound(aux, 0.0)
e.markAuxiliaryVariable(aux)
self.addEquation(e)
else:
biasAddUpdates[x] = (xprime, c)
self.setLowerBound(x, 0.0)
self.setUpperBound(x, 0.0)

for equ in self.equList:
participating = equ.getParticipatingVariables()
for x in participating:
if x in biasAddUpdates: # if a variable to remove is part of this equation
xprime, c = biasAddUpdates[x]
equ.replaceVariable(x, xprime, c)

def conv2DEquations(self, op):
"""
Expand Down
35 changes: 35 additions & 0 deletions maraboupy/MarabouUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ def __init__(self):
Construct empty equation
"""
self.addendList = []
self.participatingVariables = set()
self.auxVar = None
self.scalar = None

Expand All @@ -26,6 +27,7 @@ def addAddend(self, c, x):
x: (int) variable number of variable in addend
"""
self.addendList += [(c, x)]
self.participatingVariables.update([x])

def markAuxiliaryVariable(self, aux):
"""
Expand All @@ -35,6 +37,39 @@ def markAuxiliaryVariable(self, aux):
"""
self.auxVar = aux

def getParticipatingVariables(self):
"""
Returns set of variables participating in this equation
"""
return self.participatingVariables

def participatingVariable(self, var):
"""
Check if the variable participates in this equation
Arguments:
var: (int) variable number to check
"""
return var in self.getParticipatingVariables()

def replaceVariable(self, x, xprime, c):
"""
Replace x with xprime + c
Arguments:
x: (int) old variable to be replaced in this equation
xprime: (int) new variable to be added, does not participate in this equation
c: (float) difference between old and new variable
"""
assert self.participatingVariable(x)
assert not self.participatingVariable(xprime)
assert self.auxVar != x and self.auxVar != xprime
for i in range(len(self.addendList)):
if self.addendList[i][1] == x:
coeff = self.addendList[i][0]
self.addendList[i] = (coeff, xprime)
self.setScalar(self.scalar - coeff*c)
self.participatingVariables.remove(x)
self.participatingVariables.update([xprime])

def addEquality(network, vars, coeffs, scalar):
"""
Function to conveniently add equality constraint to network
Expand Down

0 comments on commit 2fcbce7

Please sign in to comment.