From 2fcbce7c77cb10b7f4a076723fb0b1c95d789591 Mon Sep 17 00:00:00 2001 From: ShantanuThakoor Date: Sat, 28 Apr 2018 15:35:56 -0700 Subject: [PATCH] Bias add improvements (#23) * Initial commit * Fixed bugs * Big fixes * Slight bug fix --- maraboupy/MarabouNetwork.py | 32 +++++++++++++++++++++++ maraboupy/MarabouNetworkTF.py | 49 +++++++++++++++++++++++++++-------- maraboupy/MarabouUtils.py | 35 +++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 11 deletions(-) diff --git a/maraboupy/MarabouNetwork.py b/maraboupy/MarabouNetwork.py index e20558dfc..7a9a33dcb 100644 --- a/maraboupy/MarabouNetwork.py +++ b/maraboupy/MarabouNetwork.py @@ -79,6 +79,38 @@ def addMaxConstraint(self, elements, v): """ self.maxList += [(elements, v)] + def lowerBoundExists(self, x): + """ + Function to check whether lower bound for a variable is known + Arguments: + x: (int) variable to check + """ + return x in self.lowerBounds + + def upperBoundExists(self, x): + """ + Function to check whether upper bound for a variable is known + Arguments: + x: (int) variable to check + """ + return x in self.upperBounds + + def participatesInPLConstraint(self, x): + """ + Function to check whether variable participates in any piecewise linear constraint in this network + Arguments: + x: (int) variable to check + """ + # ReLUs + fs, bs = zip(*self.reluList) + if x in fs or x in bs: + return True + # Max constraints + for elems, var in self.maxList: + if x in elems or x==var: + return True + return False + def getMarabouQuery(self): """ Function to convert network into Marabou Query diff --git a/maraboupy/MarabouNetworkTF.py b/maraboupy/MarabouNetworkTF.py index aff34a280..21e041995 100644 --- a/maraboupy/MarabouNetworkTF.py +++ b/maraboupy/MarabouNetworkTF.py @@ -24,7 +24,9 @@ def __init__(self, filename, inputName=None, outputName=None, savedModel=False, savedModelTags: (list of strings) If loading a SavedModel, the user must specify tags used. """ super().__init__() + self.biasAddRelations = list() self.readFromPb(filename, inputName, outputName, savedModel, savedModelTags) + self.processBiasAddRelations() def clear(self): """ @@ -36,6 +38,7 @@ def clear(self): self.inputOp = None self.outputOp = None self.sess = None + self.biasAddRelations = list() def readFromPb(self, filename, inputName, outputName, savedModel, savedModelTags): """ @@ -237,17 +240,41 @@ def biasAddEquations(self, op): assert len(prevVars)==len(curVars) and len(curVars)==len(prevConsts) ### END getting inputs ### - ### Generate actual equations ### - for i in range(len(curVars)): - e = MarabouUtils.Equation() - e.addAddend(1, prevVars[i]) - e.addAddend(-1, curVars[i]) - e.setScalar(-prevConsts[i]) - aux = self.getNewVariable() - self.setLowerBound(aux, 0.0) - self.setUpperBound(aux, 0.0) - e.markAuxiliaryVariable(aux) - self.addEquation(e) + ### Do not generate equations, as these can be eliminated ### + for i in range(len(prevVars)): + # prevVars = curVars - prevConst + self.biasAddRelations += [(prevVars[i], curVars[i], -prevConsts[i])] + + def processBiasAddRelations(self): + """ + Either add an equation representing a bias add, + Or eliminate one of the two variables in every other relation + """ + biasAddUpdates = dict() + for (x, xprime, c) in self.biasAddRelations: + # x = xprime + c + # replace x only if it does not occur anywhere else in the system + if self.lowerBoundExists(x) or self.upperBoundExists(x) or self.participatesInPLConstraint(x): + e = MarabouUtils.Equation() + e.addAddend(1.0, x) + e.addAddend(-1.0, xprime) + e.setScalar(c) + aux = self.getNewVariable() + self.setLowerBound(aux, 0.0) + self.setUpperBound(aux, 0.0) + e.markAuxiliaryVariable(aux) + self.addEquation(e) + else: + biasAddUpdates[x] = (xprime, c) + self.setLowerBound(x, 0.0) + self.setUpperBound(x, 0.0) + + for equ in self.equList: + participating = equ.getParticipatingVariables() + for x in participating: + if x in biasAddUpdates: # if a variable to remove is part of this equation + xprime, c = biasAddUpdates[x] + equ.replaceVariable(x, xprime, c) def conv2DEquations(self, op): """ diff --git a/maraboupy/MarabouUtils.py b/maraboupy/MarabouUtils.py index a707a52f7..a353b2bfe 100644 --- a/maraboupy/MarabouUtils.py +++ b/maraboupy/MarabouUtils.py @@ -7,6 +7,7 @@ def __init__(self): Construct empty equation """ self.addendList = [] + self.participatingVariables = set() self.auxVar = None self.scalar = None @@ -26,6 +27,7 @@ def addAddend(self, c, x): x: (int) variable number of variable in addend """ self.addendList += [(c, x)] + self.participatingVariables.update([x]) def markAuxiliaryVariable(self, aux): """ @@ -35,6 +37,39 @@ def markAuxiliaryVariable(self, aux): """ self.auxVar = aux + def getParticipatingVariables(self): + """ + Returns set of variables participating in this equation + """ + return self.participatingVariables + + def participatingVariable(self, var): + """ + Check if the variable participates in this equation + Arguments: + var: (int) variable number to check + """ + return var in self.getParticipatingVariables() + + def replaceVariable(self, x, xprime, c): + """ + Replace x with xprime + c + Arguments: + x: (int) old variable to be replaced in this equation + xprime: (int) new variable to be added, does not participate in this equation + c: (float) difference between old and new variable + """ + assert self.participatingVariable(x) + assert not self.participatingVariable(xprime) + assert self.auxVar != x and self.auxVar != xprime + for i in range(len(self.addendList)): + if self.addendList[i][1] == x: + coeff = self.addendList[i][0] + self.addendList[i] = (coeff, xprime) + self.setScalar(self.scalar - coeff*c) + self.participatingVariables.remove(x) + self.participatingVariables.update([xprime]) + def addEquality(network, vars, coeffs, scalar): """ Function to conveniently add equality constraint to network