Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

w&b error correction interpolation algorithm implemented #4

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions honeybadgermpc/batch_reconstruction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# from honeybadgermpc.wb_interpolate import makeEncoderDecoder
from honeybadgermpc.wb_interpolate import decoding_message_with_none_elements
from honeybadgermpc.field import GF
from honeybadgermpc.polynomial import polynomialsOver
# from honeybadgermpc.test-asyncio-simplerouter import simple_router
# import asyncio

"""
batch_reconstruction function
input:
shared_secrets: an array of points representing shared secrets S1 - St+1
p: prime number used in the field
t: degree t polynomial
n: total number of nodes n=3t+1
id: id of the specific node running batch_reconstruction function
"""


async def batch_reconstruction(shared_secrets, p, t, n, myid, send, recv):
print("my id %d" % myid)
print(shared_secrets)
# construct the first polynomial f(x,i) = [S1]ti + [S2]ti x + … [St+1]ti xt
Fp = GF(p)
Poly = polynomialsOver(Fp)
tmp_poly = Poly(shared_secrets)

# Evaluate and send f(j,i) for each other participating party Pj
for i in range(n):
send(i, [Fp(myid+1), tmp_poly(Fp(i+1))])

# Interpolate the polynomial, but we don't need to wait for getting all the values, we can start with 2t+1 values
tmp_gathered_results = []
for j in range(n):
# TODO: can we assume that if received, the values are non-none?
(i, o) = await recv()
print("{} gets {} from {}".format(myid, o, i))
tmp_gathered_results.append(o)
if t == 1:
start_interpolation = j
else:
start_interpolation = j + 1
if start_interpolation >= (2*t + 1):
print("{} is in first interpolation".format(myid))
print(tmp_gathered_results)
# interpolate with error correction to get f(j,y)
Solved, P1 = decoding_message_with_none_elements(t, tmp_gathered_results, p)
if Solved:
break

# Evaluate and send f(j,y) for each other participating party Pj
for i in range(n):
send(i, [myid + 1, P1.coeffs[0]])

# Interpolate the polynomial to get f(x,0)
tmp_gathered_results2 = []
for j in range(n):
# TODO: can we assume that here the received values are non-none?
(i, o) = await recv()
print("{} gets {} from {}".format(myid, o, i))
tmp_gathered_results2.append(o)
if t == 1:
start_interpolation = j
else:
start_interpolation = j + 1
if start_interpolation >= (2*t + 1):
# interpolate with error correction to get f(x,0)
print("{} is in second interpolation".format(myid))
Solved, P2 = decoding_message_with_none_elements(t, tmp_gathered_results2, p)
if Solved:
break

# return the result
if Solved:
print("I am {} and the secret polynomial is {}".format(myid, P2))
return Solved, P2
else:
print("I am {} and I failed decoding the shared secrets".format(myid))
return Solved, None
119 changes: 119 additions & 0 deletions honeybadgermpc/linearsolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@

# compute the reduced-row echelon form of a matrix in place


def rref(matrix):
if not matrix:
return

numRows = len(matrix)
numCols = len(matrix[0])

i, j = 0, 0
while True:
if i >= numRows or j >= numCols:
break

if matrix[i][j] == 0:
nonzeroRow = i
while nonzeroRow < numRows and matrix[nonzeroRow][j] == 0:
nonzeroRow += 1

if nonzeroRow == numRows:
j += 1
continue

temp = matrix[i]
matrix[i] = matrix[nonzeroRow]
matrix[nonzeroRow] = temp

pivot = matrix[i][j]
matrix[i] = [x / pivot for x in matrix[i]]

for otherRow in range(0, numRows):
if otherRow == i:
continue
if matrix[otherRow][j] != 0:
matrix[otherRow] = [y - matrix[otherRow][j] * x
for (x, y) in zip(matrix[i], matrix[otherRow])]

i += 1
j += 1

return matrix


# check if a row-reduced system has no solution
# if there is no solution, return (True, dont-care)
# if there is a solution, return (False, i) where i is the index of the last nonzero row
def noSolution(A):
i = -1
while all(x == 0 for x in A[i]):
i -= 1

lastNonzeroRow = A[i]
if all(x == 0 for x in lastNonzeroRow[:-1]):
return True, 0

return False, i


# determine if the given column is a pivot column (contains all zeros except a single 1)
# and return the row index of the 1 if it exists
def isPivotColumn(A, j):
i = 0
while A[i][j] == 0 and i < len(A):
i += 1

if i == len(A):
return (False, i)

if A[i][j] != 1:
return (False, i)
else:
pivotRow = i

i += 1
while i < len(A):
if A[i][j] != 0:
return (False, pivotRow)
i += 1

return (True, pivotRow)


# return any solution of the system, with free variables set to the given value
def someSolution(system, freeVariableValue=1):
rref(system)

hasNoSolution, lastNonzeroRowIndex = noSolution(system)
if hasNoSolution:
raise Exception("No solution")

numVars = len(system[0]) - 1 # last row is constants
variableValues = [0] * numVars

freeVars = set()
pivotVars = set()
rowIndexToPivotColumnIndex = dict()
pivotRowIndex = dict()

for j in range(numVars):
isPivot, rowOfPivot = isPivotColumn(system, j)
if isPivot:
rowIndexToPivotColumnIndex[rowOfPivot] = j
pivotRowIndex[j] = rowOfPivot
pivotVars.add(j)
else:
freeVars.add(j)

for j in freeVars:
variableValues[j] = freeVariableValue

for j in pivotVars:
theRow = pivotRowIndex[j]
variableValues[j] = (system[theRow][-1] -
sum(system[theRow][i] *
variableValues[i] for i in freeVars))

return variableValues
70 changes: 65 additions & 5 deletions honeybadgermpc/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from functools import reduce
import sys
import time
from itertools import zip_longest


def strip_trailing_zeros(a):
if len(a) == 0:
return []
for i in range(len(a), 0, -1):
if a[i-1] != 0:
break
Expand All @@ -24,7 +27,8 @@ def __init__(self, coeffs):
self.coeffs = strip_trailing_zeros(coeffs)
self.field = field

def isZero(self): return self.coeffs == []
def isZero(self):
return self.coeffs == [] or (len(self.coeffs) == 1 and self.coeffs[0] == 0)

def __repr__(self):
if self.isZero():
Expand Down Expand Up @@ -64,24 +68,80 @@ def interpolate_fft(cls, ys, omega):
assert n & (n-1) == 0, "n must be power of two"
assert type(omega) is field
assert omega ** n == 1, "must be an n'th root of unity"
assert omega ** (n//2) != 1, "must be a primitive n'th root of unity"
assert omega ** (n //
2) != 1, "must be a primitive n'th root of unity"
coeffs = [b/n for b in fft_helper(ys, 1/omega, field)]
return cls(coeffs)

def evaluate_fft(self, omega, n):
assert n & (n-1) == 0, "n must be power of two"
assert type(omega) is field
assert omega ** n == 1, "must be an n'th root of unity"
assert omega ** (n//2) != 1, "must be a primitive n'th root of unity"
assert omega ** (n //
2) != 1, "must be a primitive n'th root of unity"
return fft(self, n, omega)

@classmethod
def random(cls, degree, y0=None):
coeffs = [field(random.randint(0, field.modulus-1)) for _ in range(degree+1)]
coeffs = [field(random.randint(0, field.modulus-1))
for _ in range(degree+1)]
if y0 is not None:
coeffs[0] = y0
return cls(coeffs)

# the valuation only gives 0 to the zero polynomial, i.e. 1+degree
def __abs__(self): return len(self.coeffs)

def __iter__(self): return iter(self.coeffs)

def __sub__(self, other): return self + (-other)

def __neg__(self): return Polynomial([-a for a in self])

def __len__(self): return len(self.coeffs)

def __add__(self, other):
newCoefficients = [sum(x) for x in zip_longest(
self, other, fillvalue=self.field(0))]
return Polynomial(newCoefficients)

def __mul__(self, other):
if self.isZero() or other.isZero():
return Zero()

newCoeffs = [self.field(0)
for _ in range(len(self) + len(other) - 1)]

for i, a in enumerate(self):
for j, b in enumerate(other):
newCoeffs[i+j] += a*b
return Polynomial(newCoeffs)

def degree(self): return abs(self) - 1

def leadingCoefficient(self): return self.coeffs[-1]

def __divmod__(self, divisor):
quotient, remainder = Zero(), self
divisorDeg = divisor.degree()
divisorLC = divisor.leadingCoefficient()

while remainder.degree() >= divisorDeg:
monomialExponent = remainder.degree() - divisorDeg
monomialZeros = [self.field(0)
for _ in range(monomialExponent)]
monomialDivisor = Polynomial(
monomialZeros + [remainder.leadingCoefficient() / divisorLC])

quotient += monomialDivisor
remainder -= monomialDivisor * divisor
print(remainder.coeffs)

return quotient, remainder

def Zero():
return Polynomial([])

_poly_cache[field] = Polynomial
return Polynomial

Expand Down Expand Up @@ -191,7 +251,7 @@ def test_correctness(poly, omega, fft_helper_result):
y = poly(omega**i)
c += 1
sys.stdout.write("%d / %d points verified!" % (c,
total_verification_points))
total_verification_points))
char = "\r" if c < len(sample) else "\n"
sys.stdout.write(char)
sys.stdout.flush()
Expand Down
Loading