From a4e1fc5c6cf090657900163f3b1d3dd0f2b8156b Mon Sep 17 00:00:00 2001 From: Jonas Nick Date: Wed, 8 Jan 2020 21:30:55 +0000 Subject: [PATCH] Optionally output circuit and assignment in libsecp-zkp bulletproof format --- README.md | 6 +- purify.py | 319 ++++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 287 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index 5271289..e12f035 100644 --- a/README.md +++ b/README.md @@ -76,14 +76,14 @@ For a message *m* and key *(z1, z2)*, *Purify((z1purify.py tool implements this: - $ ./purify.py eval 11427c7268288dddf0cd24af3d30524fd817a91e103e7e02eb28b78db81cb350b3d2562f45fa8ecd711d1becc02fa348cf2187429228e7aac6644a3da2824e93 01234567 + $ ./purify.py eval 01234567 11427c7268288dddf0cd24af3d30524fd817a91e103e7e02eb28b78db81cb350b3d2562f45fa8ecd711d1becc02fa348cf2187429228e7aac6644a3da2824e93 eval: afae82108c66397451ce376bc95751c398e40eaf8c768d1b18cc9dd4161cee35 ## Verification using arithmetic circuits The purify.py can also construct arithmetic circuits that verify the Purify evaluation as well as correctness of public keys. Specifically: - $ ./purify.py verifier 01234567 >verifier.py + $ ./purify.py verifier 01234567 9343f981e9c40546061e63f9f4e6f61541c483c8aae8fe27180c490f0faf584d5036a5952b01200d8b0fdb49c83d5f8dcc8ae434e77785c576720d18897bbea5 >verifier.py This generates a Python function verifier(pubkey, output, v) that takes as input the *x* value from above, the output from the evaluation, and an assignment for all of the circuit's secret variables. It is specific for the message 01234567 in this case. @@ -104,6 +104,8 @@ These are indeed the public key and the evaluation. The third argument to **Note that this does not actually implement any zero-knowledge proofs. It only derives the relations that would need to be proven, and the secret values they're over in specific instances.** +Alternatively, by adding the --bulletproofs-outfile flag to the prove and verifier commands, the output is a format that can be used in the [libsecp256k1-zkp](https://github.com/ElementsProject/secp256k1-zkp/pull/16) bulletproofs module (see https://github.com/jonasnick/secp256k1-zkp/tree/bulletproof-musig-dn for benchmarking purify with bulletproofs). + ## Example parameters The code in this repository has parameters that correspond to the order of the secp256k1 group: diff --git a/purify.py b/purify.py index 8fcae92..13e8898 100755 --- a/purify.py +++ b/purify.py @@ -4,7 +4,9 @@ import hmac import hashlib import secrets -from math import ceil +import copy +from math import ceil, log +import argparse # Parameters generated using gen_params.sage for Curve25519 #P = 0x1000000000000000000000000000000014DEF9DEA2F79CD65812631A5CF5D3ED @@ -302,7 +304,7 @@ def __str__(self): if (factor == 1): terms.append(varname) else: - terms.append("%i*%s" % (factor, varname)) + terms.append("%i * %s" % (factor, varname)) if len(terms) == 1: return terms[0] else: @@ -319,6 +321,12 @@ def evaluate(self, m): return None return ret % P + # split in constant and non-constant part + def split(self): + e = Expr(0) + e.linear = self.linear + return (Expr(self.const), e) + class Transcript: def __init__(self): self.varmap = dict() @@ -398,6 +406,184 @@ def evaluate(self, e): # self.eqs.append(bitsum - e) # return ret +# A transcript that can be turned into libsecp256k1-zkp circuit and assignment format +class BulletproofsTranscript: + def __init__(self, transcript, n_bits): + # Number of bit constraints. We don't need to explicitly state them for + # bulletproofs. + self.n_bits = n_bits + # libsecp-zkp bulletproofs require power of 2 muls + self.n_muls = 2**ceil(log(len(transcript.muls), 2)) + # Simple assignments of wires, for example (L0, v[0]) + self.assignments = [] + # Assignments of wires as linear combination of other wires, for example (L1, L0 + v[1]) + self.linear_assignments = [] + # Filled with n_bits many bit constraints. + self.bit_constraints = [] + # Constraints we will encode + self.constraints = [] + # Map from "v[i]" coming from the transcript to a bulletproofs variable name (i.e. "Li", "Ri", "Oi") + self.vtoA = {} + # There's a single commitment in purify + self.n_commitments = 1 + + for (i, (l, r, o)) in enumerate(transcript.muls): + # Need to copy because the muls elements are the same expressions + # sometimes, but we rely on being able to change the expressions + # independently + self.add_mul("L", i, copy.deepcopy(l)) + self.add_mul("R", i, copy.deepcopy(r)) + self.add_mul("O", i, copy.deepcopy(o)) + for i in range(len(transcript.muls), self.n_muls): + self.add_mul("L", i, Expr(0)) + self.add_mul("R", i, Expr(0)) + self.add_mul("O", i, Expr(0)) + + # Replaces "v[i]" in an expr with the corresponding "Li", "Ri", "Oi" + def replace_expr_v_with_bp_var(self, e): + e.linear = list(map(lambda x: x if not x[0] in self.vtoA else (self.vtoA[x[0]], x[1]), e.linear)) + + # Returns whether expr is a simple assignment + def replace_and_insert(self, expr, s): + if len(expr.linear) >= 1: + self.replace_expr_v_with_bp_var(expr) + if expr.const == 0 and len(expr.linear) == 1 and not expr.linear[0][0] in self.vtoA: + self.vtoA[expr.linear[0][0]] = s + if "v[" in expr.linear[0][0]: + return True + return False + + def add_mul(self, s, i, expr): + varname = s + str(i) + is_assignment = self.replace_and_insert(expr, varname) + if is_assignment: + self.assignments += [(varname, expr)] + else: + # Split the expression, because only the constant part must be on + # the right hand side of the equation. + c, l = expr.split() + e = Expr(varname) + lhs = e - l + self.linear_assignments += [(varname, expr)] + # Skip bit constraints + if len(self.bit_constraints) < 2*self.n_bits: + self.bit_constraints += [(lhs, c)] + else: + self.constraints += [(lhs, c)] + + def add_pubkey_and_out(self, pubkey, P1x, P2x, out): + def a(pk, Px): + self.replace_expr_v_with_bp_var(Px) + c, l = Px.split() + tup = (l, pk - c) + self.constraints += [tup] + a(pubkey % P, P1x) + a(pubkey // P, P2x) + self.replace_expr_v_with_bp_var(out) + # Add constraint to for commitment + self.constraints += [(out - Expr("V0"), Expr(0))] + + # Return circuit in bulletproofs module plaintext format + def plaintext_circuit(self): + ret = "%i,%i,%i,%i;" % (self.n_muls, self.n_commitments, self.n_bits, len(self.constraints)) + i = 0 + for cons in self.constraints: + cons0 = str(cons[0]) + cons1 = str(cons[1]) + # Remove unnecessary parantheses from Expression string after + # verifying they don't do anything + assert(cons0.count("(") == 0 or (cons0.count("(") == 1 and cons0[0] == "(" and cons0[-1] == ")")) + assert(cons1.count("(") == 0 or (cons1.count("(") == 1 and cons1[0] == "(" and cons1[-1] == ")")) + cons0 = cons0.replace("(","").replace(")","") + cons1 = cons1.replace("(","").replace(")","") + ret += "%s = %s;" % (cons0, cons1) + return ret + + def write_circuit(self, f): + version = 1 + f.write(version.to_bytes(4, byteorder='little')) + f.write(self.n_commitments.to_bytes(4, byteorder='little')) + f.write(self.n_muls.to_bytes(8, byteorder='little')) + f.write(self.n_bits.to_bytes(8, byteorder='little')) + f.write(len(self.constraints).to_bytes(8, byteorder='little')) + + # Copied from libsecp + def secp256k1_bulletproofs_encoding_width(n): + if n < 0x100: + return 1 + if n < 0x10000: + return 2; + if n < 0x100000000: + return 4; + return 8; + row_width = secp256k1_bulletproofs_encoding_width(self.n_muls) + row_size = 0 + # In these "matrices" every row corresponds to a wire (f.e. wl[0] is + # L0). Every entry in the row is a tuple of the constraints index this + # wire is added to, and the factor its multiplied with before that. + wl = [[]] * self.n_muls + wr = [[]] * self.n_muls + wo = [[]] * self.n_muls + wv = [[]] * self.n_commitments + + def add_entry(w, var, constraint_idx, factor): + var_idx = int(var[1:]) + w[var_idx] = w[var_idx] + [(constraint_idx, factor)] + + for (i, (left, _)) in enumerate(self.constraints): + for summand in left.linear: + if "L" == summand[0][0]: + add_entry(wl, summand[0], i, summand[1]) + elif "R" == summand[0][0]: + add_entry(wr, summand[0], i, summand[1]) + elif "O" == summand[0][0]: + add_entry(wo, summand[0], i, summand[1]) + elif "V" == summand[0][0]: + add_entry(wv, summand[0], i, summand[1]) + + for row in wl + wr + wo + wv: + row_width = secp256k1_bulletproofs_encoding_width(self.n_muls); + f.write(len(row).to_bytes(row_width, byteorder='little')) + for entry in row: + f.write(entry[0].to_bytes(row_width, byteorder='little')) + f.write(b'\x20') + f.write(entry[1].to_bytes(32, byteorder='little')) + + # Write constant part (right hand side) + for (_, right) in self.constraints: + f.write(b'\x20') + f.write(right.const.to_bytes(32, byteorder='little')) + + def evaluate(self, m, commitment): + m["V0"] = commitment + for (v, A) in self.vtoA.items(): + m[A] = m[v] + for assign in self.assignments + self.linear_assignments: + m[assign[0]] = assign[1].evaluate(m) + for i in range(self.n_muls): + if (m["L%i" %i] * m["R%i" % i]) % P != m["O%i" % i]: + return False + for con in self.constraints + self.bit_constraints: + if con[0].evaluate(m) != con[1].evaluate(m): + return False + return True + + # m must have been called with self.evaluate + def write_assignment(self, m, f): + version = 1 + f.write(version.to_bytes(4, byteorder='little')) + f.write(self.n_commitments.to_bytes(4, byteorder='little')) + f.write(self.n_muls.to_bytes(8, byteorder='little')) + def write(s): + for i in range(self.n_muls): + f.write(b'\x20') + f.write(m["%s%s" % (s, i)].to_bytes(32, byteorder='little')) + write("L") + write("R") + write("O") + f.write(b'\x20') + f.write(m["V0"].to_bytes(32, byteorder='little')) + def hmac_sha256(key, data): return hmac.new(key, data, hashlib.sha256).digest() @@ -583,31 +769,81 @@ def circuit_main(trans, M1, M2, z1=None, z2=None): z2bitvals = key_to_bits(z2, N2.bit_length() - 1) z1bits = [trans.boolean(trans.secret(z1bitval)) for z1bitval in z1bitvals] z2bits = [trans.boolean(trans.secret(z2bitval)) for z2bitval in z2bitvals] + # number of bit constraints + n_bits = len(z1bits) + len(z2bits) out_P1x = circuit_ec_multiply_x(E1, trans, G1, z1bits) out_P2x = circuit_ec_multiply_x(E2, trans, G2, z2bits) out_x1 = circuit_ec_multiply_x(E1, trans, M1, z1bits) out_x2 = circuit_ec_multiply_x(E2, trans, M2, z2bits) - return (circuit_combine(trans, out_x1, out_x2), out_P1x, out_P2x) - -if len(sys.argv) < 2: - print("Usage: %s gen []: generate a key" % __file__) - print(" %s eval : evaluate the PRF" % __file__) - print(" %s verifier : output verifier circuit for a given message" % __file__) - print(" %s prove : produce input for verifier" % __file__) -elif sys.argv[1] == "gen": - if len(sys.argv) == 2: + return (circuit_combine(trans, out_x1, out_x2), out_P1x, out_P2x, n_bits) + +# verifier command with python output +def verifier_cmd_python(trans, P1x, P2x, out): + print("def verify(pubkey, output, v):") + print(" P = %i" % P) + print(" # %i multiplications" % len(trans.muls)) + for (a, b, m) in trans.muls: + print(" assert((%s * %s - %s) %% P == 0)" % (a, b, m)) + print(" # %i linear equations" % len(trans.eqs)) + for (eq) in trans.eqs: + print(" assert((%s) %% P == 0)" % eq) + print(" # Verify public key") + print(" assert(%s %% P == pubkey %% P)" % P1x) + print(" assert(%s %% P == pubkey // P)" % P2x) + print(" # Verify output") + print(" assert(output == %s %% P)" % out) + +# verifier command with bulletproofs output +def verifier_cmd_bulletproofs(fname, trans, n_bits, pubkey, P1x, P2x, out): + b_trans = BulletproofsTranscript(trans, n_bits) + b_trans.add_pubkey_and_out(pubkey, P1x, P2x, out) + with open(fname, 'wb') as f: + b_trans.write_circuit(f) + +# prove command with python output +def prove_cmd_python(trans, pubkey, out_native): + print("verify(0x%x, 0x%x, [%s])" % (pubkey, out_native, ",".join("%s" % (trans.varmap["v[%i]" % i]) for i in range(len(trans.varmap))))) + +# prove command with bulletproofs output +def prove_cmd_bulletproofs(fname, trans, n_bits, pubkey, P1x, P2x, out, out_native): + b_trans = BulletproofsTranscript(trans, n_bits) + b_trans.add_pubkey_and_out(pubkey, P1x, P2x, out) + assert(b_trans.evaluate(trans.varmap, out_native)) + with open(fname, 'wb') as f: + b_trans.write_assignment(trans.varmap, f) + +arg_parser = argparse.ArgumentParser(description='A PRF with low multiplicative complexity', usage='''%s [] + The available commands are: + %s gen [--seckey ]: generate a key + %s eval : evaluate the PRF + %s verifier [--bulletproofs-outfile ]: output verifier circuit for a given message + %s prove [--bulletproofs-outfile ]: produce input for verifier + ''' % ((__file__,)*5)) +arg_parser.add_argument('cmd', choices=['gen', 'eval', 'verifier', 'prove']) +args = arg_parser.parse_args(sys.argv[1:2]) + +if args.cmd == "gen": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('--seckey', required=False) + args = arg_parser.parse_args(sys.argv[2:]) + if args.seckey is None: z = secrets.randbelow((N1 - 1) // 2 * (N2 - 1) // 2) else: - z = int(sys.argv[2], 16) + z = int(args.seckey, 16) z1, z2 = unpack_secret(z) P1 = E1.affine(E1.mul(G1, z1)) P2 = E2.affine(E2.mul(G2, z2)) print("z=%x # private key" % z) print("x=%x # public key" % pack_public(P1[0], P2[0])) -elif sys.argv[1] == "eval": - z = int(sys.argv[2], 16) - m = bytes.fromhex(sys.argv[3]) +elif args.cmd == "eval": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('hexmsg') + arg_parser.add_argument('seckey') + args = arg_parser.parse_args(sys.argv[2:]) + + z = int(args.seckey, 16) + m = bytes.fromhex(args.hexmsg) z1, z2 = unpack_secret(z) M1 = hash_to_curve(b"Eval/1/" + m, E1) M2 = hash_to_curve(b"Eval/2/" + m, E2) @@ -615,28 +851,34 @@ def circuit_main(trans, M1, M2, z1=None, z2=None): Q2 = E2.affine(E2.mul(M2, z2)) out = combine(Q1[0], Q2[0]) print("eval: %x" % out) -elif sys.argv[1] == "verifier": - m = bytes.fromhex(sys.argv[2]) +elif args.cmd == "verifier": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('hexmsg') + arg_parser.add_argument('pubkey') + arg_parser.add_argument('--bulletproofs-outfile') + args = arg_parser.parse_args(sys.argv[2:]) + + m = bytes.fromhex(args.hexmsg) + pubkey = int(args.pubkey, 16) M1 = hash_to_curve(b"Eval/1/" + m, E1) M2 = hash_to_curve(b"Eval/2/" + m, E2) trans = Transcript() - out, P1x, P2x = circuit_main(trans, M1, M2) - print("def verify(pubkey, output, v):") - print(" P = %i" % P) - print(" # %i multiplications" % len(trans.muls)) - for (a, b, m) in trans.muls: - print(" assert((%s * %s - %s) %% P == 0)" % (a, b, m)) - print(" # %i linear equations" % len(trans.eqs)) - for (eq) in trans.eqs: - print(" assert((%s) %% P == 0)" % eq) - print(" # Verify public key") - print(" assert(%s %% P == pubkey %% P)" % P1x) - print(" assert(%s %% P == pubkey // P)" % P2x) - print(" # Verify output") - print(" assert(output == %s %% P)" % out) -elif sys.argv[1] == "prove": - m = bytes.fromhex(sys.argv[2]) - z = int(sys.argv[3], 16) + out, P1x, P2x, n_bits = circuit_main(trans, M1, M2) + + if args.bulletproofs_outfile is None: + verifier_cmd_python(trans, P1x, P2x, out) + else: + verifier_cmd_bulletproofs(args.bulletproofs_outfile, trans, n_bits, pubkey, P1x, P2x, out) + +elif args.cmd == "prove": + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument('hexmsg') + arg_parser.add_argument('seckey') + arg_parser.add_argument('--bulletproofs-outfile') + args = arg_parser.parse_args(sys.argv[2:]) + + m = bytes.fromhex(args.hexmsg) + z = int(args.seckey, 16) z1, z2 = unpack_secret(z) M1 = hash_to_curve(b"Eval/1/" + m, E1) M2 = hash_to_curve(b"Eval/2/" + m, E2) @@ -646,11 +888,16 @@ def circuit_main(trans, M1, M2, z1=None, z2=None): Q2 = E2.affine(E2.mul(M2, z2)) out_native = combine(Q1[0], Q2[0]) trans = Transcript() - out, P1x, P2x = circuit_main(trans, M1, M2, z1, z2) + out, P1x, P2x, n_bits = circuit_main(trans, M1, M2, z1, z2) assert(trans.evaluate(P1x) == P1[0]) assert(trans.evaluate(P2x) == P2[0]) assert(trans.evaluate(out) == out_native) pubkey = pack_public(P1[0], P2[0]) - print("verify(0x%x, 0x%x, [%s])" % (pubkey, out_native, ",".join("%s" % (trans.varmap["v[%i]" % i]) for i in range(len(trans.varmap))))) + + if args.bulletproofs_outfile is None: + prove_cmd_python(trans, pubkey, out_native) + else: + prove_cmd_bulletproofs(args.bulletproofs_outfile, trans, n_bits, pubkey, P1x, P2x, out, out_native) + else: print("Unknown command")