A ZKML framework, built in pure Python on top of the Libra protocol, for proving ONNX and general NumPy computations.
In a nutshell, zkgraph aims to give non-cryptographers a tool for verifiable inference over public data and ML models with public parameters. With zkgraph, you can prove statements such as:
- I have a valid execution of a machine learning model with public inputs and public parameters.
- I have a valid sequence of operations performed by a known family of algorithms with public inputs.
If you are here, you have at least considered the premise that decentralization can shape the future of human society, redefining how we see money, property, the social contract, and governance, and, at the most basic level, offering clever ways to tackle collective human greed through prosocial consensus mechanisms. With that in mind, let's talk about how we do ML today:
- Companies train models. Some release open-source weights and datasets, but all of them need massive amounts of compute, which becomes a constant excuse for abusive pricing and abusive data practices. You know how it goes: Company WTF has an enormous global platform for doing AI on any task. Sometimes they get careless with data and sell every data point you generated on their platform to someone else or, even better, use it for their own benefit. And you? Ha! You get to pay for a new premium plan to use a new model that can implement the code for the machine that goes PING much faster than the previous version, among other things.
Until now, the basic premise that lets large corporations exploit your data is that you do not have a planetary, scalable supercomputer capable of automatically reaching consensus about programs and computations, one that uses money both as the principal asset for acquiring goods and as the main reward for prosocial behavior.
- Wait?! That's a blockchain! I do have, then!
- But hey, blockchains are shitty at heavy computation; how can anyone run an ML model on-chain?
Zkgraph uses the Libra protocol to generate proofs of execution for the computational graphs produced by ML algorithms. Because Libra is a zkSNARK, the claim that a computation is valid can be packaged as a proof transcript (just a file or a chunk of bytes, you name it) and verified by a third-party application such as a bank API, or a BLOCKCHAIN SMART CONTRACT! Ha!
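GKR-based protocols such as Libra are built around the sumcheck protocol. To give a feel for what a proof transcript contains, here is a minimal sketch of a sumcheck over a multilinear polynomial, made non-interactive with Fiat-Shamir. It is not zkgraph's code and omits zero-knowledge masking and GKR layering. The field (the Mersenne prime 2^61 - 1), the SHA-256 challenge derivation (zkgraph itself uses a Merlin/Strobe-128 transcript), and every function name here are assumptions for illustration. The verifier's final check reads the whole table; a real protocol replaces that with a reduction to the next circuit layer or a polynomial commitment opening.

```python
import hashlib
from typing import List, Tuple

P = 2**61 - 1  # illustrative prime field (Mersenne prime), not zkgraph's field


def _encode(*values: int) -> bytes:
    # Fixed-width big-endian encoding of field elements for the transcript.
    return b"".join(v.to_bytes(32, "big") for v in values)


def _challenge(transcript: bytes) -> int:
    # Fiat-Shamir: hash everything sent so far into a "random" field element.
    return int.from_bytes(hashlib.sha256(transcript).digest(), "big") % P


def _num_vars(table: List[int]) -> int:
    if len(table) == 0 or len(table) & (len(table) - 1):
        raise ValueError("table length must be a power of two")
    return len(table).bit_length() - 1


def mle_eval(table: List[int], point: List[int]) -> int:
    """Evaluate the multilinear extension of `table` at `point`.
    Variable i corresponds to the i-th most significant bit of the index."""
    n = len(point)
    total = 0
    for idx, val in enumerate(table):
        weight = 1
        for i, r in enumerate(point):
            bit = (idx >> (n - 1 - i)) & 1
            weight = weight * (r if bit else 1 - r) % P
        total = (total + val * weight) % P
    return total


def prove(table: List[int]) -> Tuple[int, List[Tuple[int, int]]]:
    """Return the claimed sum and one (g(0), g(1)) message per round."""
    n = _num_vars(table)
    claim = sum(table) % P
    transcript = _encode(claim)
    current = [v % P for v in table]
    rounds = []
    for _ in range(n):
        half = len(current) // 2
        lo, hi = current[:half], current[half:]
        # Each round polynomial is linear, so two evaluations determine it.
        g0, g1 = sum(lo) % P, sum(hi) % P
        rounds.append((g0, g1))
        transcript += _encode(g0, g1)
        r = _challenge(transcript)
        # Fix the current variable to r by folding the two halves together.
        current = [(a + r * (b - a)) % P for a, b in zip(lo, hi)]
    return claim, rounds


def verify(table: List[int], claim: int, rounds: List[Tuple[int, int]]) -> bool:
    n = _num_vars(table)
    if len(rounds) != n or not 0 <= claim < P:
        return False
    transcript = _encode(claim)
    point = []
    for g0, g1 in rounds:
        if not (0 <= g0 < P and 0 <= g1 < P) or (g0 + g1) % P != claim:
            return False
        transcript += _encode(g0, g1)
        r = _challenge(transcript)  # recomputed exactly as the prover did
        point.append(r)
        claim = (g0 + r * (g1 - g0)) % P  # g(r) for the linear round polynomial
    # Final oracle check against the multilinear extension.
    return claim == mle_eval(table, point)
```

For example, `prove([3, 1, 4, 1, 5, 9, 2, 6])` returns the claimed sum 31 plus three round messages, and `verify` rejects the same messages paired with any other claimed sum. The proof is just a short list of field elements, which is why it can be shipped as bytes and checked elsewhere.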
The point is that you don't have to run the code on-chain; you can run it wherever you want, off-chain. If you have a protocol for attesting to work, statements, or predicates, and that protocol can verify those claims ON-CHAIN, you can bridge heavy computational workloads and the blockchain.
Zkgraph, by using ZKSNARKs, can act as a consensus bridge for different decentralized applications. Let's see an example:
It can serve as a consensus protocol for an automated crypto trading bot. How?
- Users can enlist in the protocol and run computations to predict the best moment to buy or sell assets.
- Using zkgraph, they send the protocol a zero-knowledge proof that the ML computation they ran to predict the best moment to buy an asset is valid, so the protocol can trust that the calculation was performed correctly.
- The protocol relies on these actors to perform computation on its behalf, and a smart contract can reward their prosocial behavior when they provide a ZKP (zero-knowledge proof) attesting to meaningful work.
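The flow above can be modeled in a few lines of Python. This is a hypothetical sketch, not a zkgraph or smart-contract API: `RewardPool`, `submit`, and the toy verifier below are illustrative names. In practice, `verify` would be zkgraph's verifier (off-chain) or a generated Solidity verifier contract (on-chain).

```python
import hashlib
from typing import Callable, Dict, Set


class RewardPool:
    """Hypothetical model of a reward contract that pays for verified work."""

    def __init__(self, verify: Callable[[bytes], bool], reward: int):
        self.verify = verify  # any proof checker: bytes -> accept/reject
        self.reward = reward
        self.balances: Dict[str, int] = {}
        self.seen: Set[str] = set()

    def submit(self, worker: str, proof: bytes) -> bool:
        digest = hashlib.sha256(proof).hexdigest()
        if digest in self.seen:  # reject replays of an already-rewarded proof
            return False
        if not self.verify(proof):  # only verified work is rewarded
            return False
        self.seen.add(digest)
        self.balances[worker] = self.balances.get(worker, 0) + self.reward
        return True


# Usage with a stand-in verifier (a real one would check a zkgraph proof).
pool = RewardPool(verify=lambda p: p.startswith(b"valid"), reward=10)
pool.submit("alice", b"valid-proof-1")
```

Hashing the proof blocks trivial replays; a real deployment would also bind each proof to the submitter's address so it cannot be copied and resubmitted under another name.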
Using zkgraph is easy: all you need is an ONNX graph and some floating-point inputs. Take a look at our example:
```python
import os
import subprocess
import time

import numpy as np
import onnx
import onnxruntime

from zkgraph.graph.engine import Value
from zkgraph.ops.from_onnx import from_onnx
from zkgraph.ops.onnx_utils import generate_small_iris_onnx_model
from zkgraph.polynomials.field import dequantization
from zkgraph.prover.prover import ZkProver
from zkgraph.verifier.verifier import ZkVerifier

# Feature flags: USE_PCS=1 enables the MKZG polynomial commitment scheme,
# USE_NOIR=1 builds the Noir on-chain verifier after verification.
use_mkzg = int(os.environ.get("USE_PCS", 0))
use_noir = int(os.environ.get("USE_NOIR", 0))


def add_intermediate_layers_as_outputs(onnx_model):
    """Take an ONNX model and return the same model with all intermediate
    node outputs exposed as model outputs.

    Useful for testing that every node is calculated correctly.
    """
    shape_info = onnx.shape_inference.infer_shapes(onnx_model)
    value_info_protos = list(shape_info.graph.value_info)
    onnx_model.graph.output.extend(value_info_protos)
    onnx.checker.check_model(onnx_model)
    return onnx_model


def main():
    np.random.seed(42)

    # Generate the small Iris model on the first run.
    if "iris_model.onnx" not in os.listdir("tests/assets/"):
        generate_small_iris_onnx_model(onnx_output_path="tests/assets/iris_model.onnx")
    onnx_model = add_intermediate_layers_as_outputs(
        onnx.load("tests/assets/iris_model.onnx")
    )

    # Create a dummy input
    dummy_input = np.random.randn(1, 2).astype(np.float32)
    print(f"Dummy input shape: {dummy_input.shape}")
    print(f"Dummy input: {dummy_input}")

    # Reference run through an ONNX Runtime inference session
    session = onnxruntime.InferenceSession(onnx_model.SerializeToString())
    input_name = session.get_inputs()[0].name
    onnx_outputs = session.run(None, {input_name: dummy_input})

    # The same computation traced as a zkgraph computational graph
    zerok_outputs = from_onnx(onnx_model, dummy_input)
    graph_output = np.sum(zerok_outputs[0])
    print(f"ONNX output: {onnx_outputs[0]}")
    dequantized = [dequantization(o.data) for o in zerok_outputs[0][0]]
    print(f"zkgraph output (dequantized): {dequantized}")
    print(
        f"Graph output: {graph_output}, dequantized: {dequantization(graph_output.data)}"
    )

    # Compile the graph into a layered circuit for the prover and verifier
    layered_circuit, _ = Value.compile_layered_circuit(graph_output)

    start = time.time()
    if use_mkzg:
        public_parameters = {
            "r_pp": "./tests/assets/random_polynomial_r_powers_of_tau.ptau",
            "zk_pp": "./tests/assets/zk_sumcheck_powers_of_tau.ptau",
        }
        prover = ZkProver(
            layered_circuit, mkzg=True, public_parameters=public_parameters
        )
        verifier = ZkVerifier(
            layered_circuit, mkzg=True, public_parameters=public_parameters
        )
    else:
        prover = ZkProver(layered_circuit, mkzg=False)
        verifier = ZkVerifier(layered_circuit, mkzg=False)
    assert prover.prove()
    print(f"Time to prove: {time.time() - start}")

    # Serialize the proof transcript so it can be sent to a verifier
    proof_transcript = prover.proof_transcript.to_bytes()

    start = time.time()
    verifier.run_verifier(proof_transcript=proof_transcript)
    print(f"Time to verify: {time.time() - start}")

    if use_noir:
        # Produce the transcript for the Noir circuit, then execute it, prove,
        # write the verification key, verify, and emit the Solidity contract
        # with the Barretenberg CLI (bb).
        verifier.get_noir_transcript()
        subprocess.call(
            "cd onchain_verifier/"
            " && nargo execute iris"
            " && bb prove -b ./target/onchain_verifier.json -w ./target/iris.gz -o ./target/proof"
            " && bb write_vk -b target/onchain_verifier.json -o ./target/vk"
            " && bb verify -k ./target/vk -p ./target/proof"
            " && bb contract",
            shell=True,
        )


if __name__ == "__main__":
    main()
```
For the BNB Q3 hackathon, we built a small but working version of the Libra protocol that lets users generate ZKPs for ONNX graphs. It is not production-ready: we have not run an audit, and we use pure Python, which is as slow as f*. We abstracted away complexity with a symmetric quantization scheme that represents floating-point numbers as finite field elements modulo a large prime (we use the prime modulus of the BLS12-381 curve). We wrote the hackathon code assuming public inputs and models with public weights.
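To illustrate the idea of representing floats as field elements, here is a minimal sketch of symmetric fixed-point quantization. It is not zkgraph's `dequantization` implementation: the modulus (the BLS12-381 scalar-field order), the power-of-two scale factor `SCALE`, and the helper names are assumptions. Negative numbers wrap around to the upper half of the field, and products must be rescaled because they carry `SCALE**2`.

```python
# Hypothetical sketch: symmetric fixed-point quantization into a prime field.
# Assumptions (not taken from zkgraph): BLS12-381 scalar-field order as the
# modulus and a scale factor of 2**16.
P = 0x73EDA753299D7D483339D80809A1D80553BDA402FFFE5BFEFFFFFFFF00000001
SCALE = 2**16


def quantize(x: float) -> int:
    """Map a float to a field element: round(x * SCALE) mod P.
    Negative values wrap around to the upper half of the field."""
    return round(x * SCALE) % P


def to_signed(a: int) -> int:
    """Interpret a field element as a signed integer in (-P/2, P/2]."""
    return a - P if a > P // 2 else a


def dequantize(a: int) -> float:
    """Inverse of quantize, up to rounding error."""
    return to_signed(a) / SCALE


def field_add(a: int, b: int) -> int:
    # Sums keep the same scale, so no rescaling is needed.
    return (a + b) % P


def field_mul(a: int, b: int) -> int:
    # The raw product carries SCALE**2, so divide by SCALE once.
    # Floor division rounds toward negative infinity. In a real circuit this
    # division is not a native field operation and needs its own constraints.
    prod = to_signed(a) * to_signed(b)
    return (prod // SCALE) % P


print(dequantize(field_mul(quantize(1.5), quantize(-2.0))))  # -3.0
```

The symmetric scheme keeps zero at field element 0 and treats the field as a signed range, so additions are exact and multiplications lose at most one unit in the last fixed-point place.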
Install the dependencies with Poetry:
```bash
poetry install
```
For testing purposes, we provide an ONNX graph at tests/assets/iris_model.onnx to help you get up to speed with the framework.
To see it in action, run the example script:
```bash
python main.py
```
To use the MKZG polynomial commitment scheme instead, run:
```bash
USE_PCS=1 python main.py
```
We also created a simple on-chain verifier using Noir, a ZK DSL. The Noir code implements the assertions performed by the zkgraph verifier. Through proof recursion, a Noir proof can attest to the validity of the original computation, and Noir's tooling lets us verify that proof on-chain. For this hackathon, we built the on-chain verifier mainly as a toy experiment outlining how one can create an on-chain verifier for GKR-based protocols such as Libra.
- Run the instructions presented here.
You must complete the last step to be able to generate the contract!!
To generate the Solidity verifier contract, run the script like this:
```bash
USE_NOIR=1 python main.py
```
Zkgraph is NOT production-ready, and we can currently generate proofs only for ops like GEMM, ReLU, and CNN layers. We plan to expand this in the future, and this hackathon is our open invitation to the ZK geeks out there who want to build or leverage a protocol with linear-time prover complexity. GKR-based protocols also lend themselves to parallel and decentralized implementations, which we believe can make them more practical than other SOTA protocols such as HyperPlonk.
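As a rough picture of what a supported op looks like once lowered to field arithmetic, here is a hypothetical sketch of a quantized dense layer (GEMM plus bias) followed by ReLU. The modulus, the scale factor, and all helper names are assumptions rather than zkgraph internals, and the sketch only shows the arithmetic semantics: inside a real circuit, the sign test behind ReLU and the rescaling division need extra constraints (for example bit decomposition or lookups).

```python
# Hypothetical sketch: quantized GEMM + ReLU with field arithmetic.
# P (BLS12-381 scalar-field order) and SCALE are assumptions, not zkgraph's.
from typing import List

P = 0x73EDA753299D7D483339D80809A1D80553BDA402FFFE5BFEFFFFFFFF00000001
SCALE = 2**16


def q(x: float) -> int:
    return round(x * SCALE) % P


def signed(a: int) -> int:
    return a - P if a > P // 2 else a


def dq(a: int) -> float:
    return signed(a) / SCALE


def gemm(x: List[int], w: List[List[int]], b: List[int]) -> List[int]:
    """y_j = sum_i x_i * w[i][j] + b_j, rescaled from SCALE**2 back to SCALE."""
    out = []
    for j in range(len(b)):
        acc = sum(signed(x[i]) * signed(w[i][j]) for i in range(len(x)))
        out.append((acc // SCALE + signed(b[j])) % P)
    return out


def relu(v: List[int]) -> List[int]:
    # "Negative" means the signed representative is below zero.
    return [a if signed(a) > 0 else 0 for a in v]


x = [q(1.0), q(-2.0)]
w = [[q(0.5), q(1.0)], [q(0.25), q(-1.0)]]
b = [q(-1.0), q(0.5)]
print([dq(a) for a in gemm(x, w, b)])        # [-1.0, 3.5]
print([dq(a) for a in relu(gemm(x, w, b))])  # [0.0, 3.5]
```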
We implemented a few papers for the code we created:
We implemented everything in pure Python so that anybody can read the code and understand it. Unfortunately, this makes it super slow, especially when using BLS12-381 pairings. We plan to address this after the hackathon. Once again, we invite the community to work with us on these features: we plan to implement CUDA extensions and redesign the protocol in C++, making cross-platform use easier, such as generating ZKPs in the browser via WebAssembly (wasm).
We were also inspired by these codebases:
- Libra original code (base logic for the GKR)
- Plookup (for the Merlin/Strobe-128 transcript we used)
Copyright (c) 2024 AE Studio. (MIT License)