Skip to content

Commit

Permalink
fix formatting errors
Browse files Browse the repository at this point in the history
  • Loading branch information
changxu2 committed Aug 2, 2018
1 parent 3ed8ac5 commit 4360a47
Show file tree
Hide file tree
Showing 4 changed files with 218 additions and 188 deletions.
169 changes: 86 additions & 83 deletions honeybadgermpc/linearsolver.py
Original file line number Diff line number Diff line change
@@ -1,116 +1,119 @@

# compute the reduced-row echelon form of a matrix in place


def rref(matrix):
if not matrix: return
if not matrix:
return

numRows = len(matrix)
numCols = len(matrix[0])
numRows = len(matrix)
numCols = len(matrix[0])

i,j = 0,0
while True:
if i >= numRows or j >= numCols:
break
i, j = 0, 0
while True:
if i >= numRows or j >= numCols:
break

if matrix[i][j] == 0:
nonzeroRow = i
while nonzeroRow < numRows and matrix[nonzeroRow][j] == 0:
nonzeroRow += 1
if matrix[i][j] == 0:
nonzeroRow = i
while nonzeroRow < numRows and matrix[nonzeroRow][j] == 0:
nonzeroRow += 1

if nonzeroRow == numRows:
j += 1
continue
if nonzeroRow == numRows:
j += 1
continue

temp = matrix[i]
matrix[i] = matrix[nonzeroRow]
matrix[nonzeroRow] = temp
temp = matrix[i]
matrix[i] = matrix[nonzeroRow]
matrix[nonzeroRow] = temp

pivot = matrix[i][j]
matrix[i] = [x / pivot for x in matrix[i]]
pivot = matrix[i][j]
matrix[i] = [x / pivot for x in matrix[i]]

for otherRow in range(0, numRows):
if otherRow == i:
continue
if matrix[otherRow][j] != 0:
matrix[otherRow] = [y - matrix[otherRow][j]*x
for (x,y) in zip(matrix[i], matrix[otherRow])]
for otherRow in range(0, numRows):
if otherRow == i:
continue
if matrix[otherRow][j] != 0:
matrix[otherRow] = [y - matrix[otherRow][j] * x
for (x, y) in zip(matrix[i], matrix[otherRow])]

i += 1; j+= 1
i += 1
j += 1

return matrix
return matrix


# check if a row-reduced system has no solution
# if there is no solution, return (True, dont-care)
# if there is a solution, return (False, i) where i is the index of the last nonzero row
def noSolution(A):
i = -1
while all(x == 0 for x in A[i]):
i -= 1
i = -1
while all(x == 0 for x in A[i]):
i -= 1

lastNonzeroRow = A[i]
if all(x == 0 for x in lastNonzeroRow[:-1]):
return True, 0
lastNonzeroRow = A[i]
if all(x == 0 for x in lastNonzeroRow[:-1]):
return True, 0

return False, i
return False, i


# determine if the given column is a pivot column (contains all zeros except a single 1)
# and return the row index of the 1 if it exists
def isPivotColumn(A, j):
i = 0
while A[i][j] == 0 and i < len(A):
i += 1

if i == len(A):
return (False, i)
i = 0
while A[i][j] == 0 and i < len(A):
i += 1

if A[i][j] != 1:
return (False, i)
else:
pivotRow = i
if i == len(A):
return (False, i)

i += 1
while i < len(A):
if A[i][j] != 0:
return (False, pivotRow)
i += 1
if A[i][j] != 1:
return (False, i)
else:
pivotRow = i

return (True, pivotRow)
i += 1
while i < len(A):
if A[i][j] != 0:
return (False, pivotRow)
i += 1

return (True, pivotRow)


# return any solution of the system, with free variables set to the given value
def someSolution(system, freeVariableValue=1):
rref(system)

hasNoSolution, lastNonzeroRowIndex = noSolution(system)
if hasNoSolution:
raise Exception("No solution")

numVars = len(system[0]) - 1 # last row is constants
variableValues = [0] * numVars

freeVars = set()
pivotVars = set()
rowIndexToPivotColumnIndex = dict()
pivotRowIndex = dict()

for j in range(numVars):
isPivot, rowOfPivot = isPivotColumn(system, j)
if isPivot:
rowIndexToPivotColumnIndex[rowOfPivot] = j
pivotRowIndex[j] = rowOfPivot
pivotVars.add(j)
else:
freeVars.add(j)

for j in freeVars:
variableValues[j] = freeVariableValue

for j in pivotVars:
theRow = pivotRowIndex[j]
variableValues[j] = (system[theRow][-1] -
sum(system[theRow][i] * variableValues[i] for i in freeVars))

return variableValues

rref(system)

hasNoSolution, lastNonzeroRowIndex = noSolution(system)
if hasNoSolution:
raise Exception("No solution")

numVars = len(system[0]) - 1 # last row is constants
variableValues = [0] * numVars

freeVars = set()
pivotVars = set()
rowIndexToPivotColumnIndex = dict()
pivotRowIndex = dict()

for j in range(numVars):
isPivot, rowOfPivot = isPivotColumn(system, j)
if isPivot:
rowIndexToPivotColumnIndex[rowOfPivot] = j
pivotRowIndex[j] = rowOfPivot
pivotVars.add(j)
else:
freeVars.add(j)

for j in freeVars:
variableValues[j] = freeVariableValue

for j in pivotVars:
theRow = pivotRowIndex[j]
variableValues[j] = (system[theRow][-1] -
sum(system[theRow][i] *
variableValues[i] for i in freeVars))

return variableValues
36 changes: 24 additions & 12 deletions honeybadgermpc/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,42 +68,52 @@ def interpolate_fft(cls, ys, omega):
assert n & (n-1) == 0, "n must be power of two"
assert type(omega) is field
assert omega ** n == 1, "must be an n'th root of unity"
assert omega ** (n//2) != 1, "must be a primitive n'th root of unity"
assert omega ** (n //
2) != 1, "must be a primitive n'th root of unity"
coeffs = [b/n for b in fft_helper(ys, 1/omega, field)]
return cls(coeffs)

def evaluate_fft(self, omega, n):
assert n & (n-1) == 0, "n must be power of two"
assert type(omega) is field
assert omega ** n == 1, "must be an n'th root of unity"
assert omega ** (n//2) != 1, "must be a primitive n'th root of unity"
assert omega ** (n //
2) != 1, "must be a primitive n'th root of unity"
return fft(self, n, omega)

@classmethod
def random(cls, degree, y0=None):
coeffs = [field(random.randint(0, field.modulus-1)) for _ in range(degree+1)]
coeffs = [field(random.randint(0, field.modulus-1))
for _ in range(degree+1)]
if y0 is not None:
coeffs[0] = y0
return cls(coeffs)

def __abs__(self): return len(self.coeffs) # the valuation only gives 0 to the zero polynomial, i.e. 1+degree
# the valuation only gives 0 to the zero polynomial, i.e. 1+degree
def __abs__(self): return len(self.coeffs)

def __iter__(self): return iter(self.coeffs)

def __sub__(self, other): return self + (-other)

def __neg__(self): return Polynomial([-a for a in self])

def __len__(self): return len(self.coeffs)

def __add__(self, other):
newCoefficients = [sum(x) for x in zip_longest(self, other, fillvalue=self.field(0))]
newCoefficients = [sum(x) for x in zip_longest(
self, other, fillvalue=self.field(0))]
return Polynomial(newCoefficients)

def __mul__(self, other):
if self.isZero() or other.isZero():
return Zero()

newCoeffs = [self.field(0) for _ in range(len(self) + len(other) - 1)]
newCoeffs = [self.field(0)
for _ in range(len(self) + len(other) - 1)]

for i,a in enumerate(self):
for j,b in enumerate(other):
for i, a in enumerate(self):
for j, b in enumerate(other):
newCoeffs[i+j] += a*b
return Polynomial(newCoeffs)

Expand All @@ -118,8 +128,10 @@ def __divmod__(self, divisor):

while remainder.degree() >= divisorDeg:
monomialExponent = remainder.degree() - divisorDeg
monomialZeros = [self.field(0) for _ in range(monomialExponent)]
monomialDivisor = Polynomial(monomialZeros + [remainder.leadingCoefficient() / divisorLC])
monomialZeros = [self.field(0)
for _ in range(monomialExponent)]
monomialDivisor = Polynomial(
monomialZeros + [remainder.leadingCoefficient() / divisorLC])

quotient += monomialDivisor
remainder -= monomialDivisor * divisor
Expand All @@ -128,7 +140,7 @@ def __divmod__(self, divisor):
return quotient, remainder

def Zero():
return Polynomial([])
return Polynomial([])

_poly_cache[field] = Polynomial
return Polynomial
Expand Down Expand Up @@ -239,7 +251,7 @@ def test_correctness(poly, omega, fft_helper_result):
y = poly(omega**i)
c += 1
sys.stdout.write("%d / %d points verified!" % (c,
total_verification_points))
total_verification_points))
char = "\r" if c < len(sample) else "\n"
sys.stdout.write(char)
sys.stdout.flush()
Expand Down
Loading

0 comments on commit 4360a47

Please sign in to comment.