Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(circom): enable loading witness from json #407

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vendors/circom/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ def witness_gen_library(
"@com_google_absl//absl/container:flat_hash_map",
"@kroma_network_circom//circomlib/generated/common:common_hdrs",
"@kroma_network_circom//circomlib/generated/{}:fr".format(prime),
"@nlohmann_json//:json",
],
)
2 changes: 2 additions & 0 deletions vendors/circom/circomlib/circuit/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ tachyon_cc_unittest(
name = "adder_circuit_unittest",
srcs = ["adder_circuit_unittest.cc"],
data = [
"adder_data.json",
"//examples:adder.zkey",
"//examples:compile_adder",
],
Expand All @@ -67,6 +68,7 @@ tachyon_cc_unittest(
name = "multiplier_3_circuit_unittest",
srcs = ["multiplier_3_circuit_unittest.cc"],
data = [
"multiplier_3_data.json",
"//examples:compile_multiplier_3",
"//examples:multiplier_3.zkey",
],
Expand Down
22 changes: 21 additions & 1 deletion vendors/circom/circomlib/circuit/adder_circuit_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class AdderCircuitTest : public CircuitTest {
R1CSParser parser;
r1cs_ = parser.Parse(base::FilePath("examples/adder.r1cs"));
ASSERT_TRUE(r1cs_);
}

void LoadRandomWitness() {
circuit_.reset(new Circuit<F>(
r1cs_.get(), base::FilePath("examples/adder_cpp/adder.dat")));
std::vector<uint32_t> values = base::CreateVector(
Expand All @@ -24,19 +26,37 @@ class AdderCircuitTest : public CircuitTest {
ASSERT_EQ(public_inputs.size(), 1);
ASSERT_EQ(public_inputs[0], F(values[0] + values[1]));
}

void LoadWitnessFromJson() {
circuit_.reset(new Circuit<F>(
r1cs_.get(), base::FilePath("examples/adder_cpp/adder.dat")));
circuit_->witness_loader().Load(
base::FilePath("circomlib/circuit/adder_data.json"));
std::vector<F> public_inputs = circuit_->GetPublicInputs();
ASSERT_EQ(public_inputs.size(), 1);
ASSERT_EQ(public_inputs[0], F(7));
}
};

TEST_F(AdderCircuitTest, Synthesize) { this->SynthesizeTest(); }
TEST_F(AdderCircuitTest, Synthesize) {
LoadRandomWitness();
this->SynthesizeTest();

LoadWitnessFromJson();
this->SynthesizeTest();
}

TEST_F(AdderCircuitTest, Groth16ProveAndVerify) {
constexpr size_t kMaxDegree = 127;
LoadRandomWitness();
this->Groth16ProveAndVerifyTest<kMaxDegree,
zk::r1cs::QuadraticArithmeticProgram<F>>();
}

TEST_F(AdderCircuitTest, Groth16ProveAndVerifyUsingZkey) {
constexpr size_t kMaxDegree = 127;

LoadRandomWitness();
ZKeyParser parser;
std::unique_ptr<ZKey> zkey =
parser.Parse(base::FilePath("examples/adder.zkey"));
Expand Down
4 changes: 4 additions & 0 deletions vendors/circom/circomlib/circuit/adder_data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"a": "3",
"b": "4"
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ class Multiplier3CircuitTest : public CircuitTest {
R1CSParser parser;
r1cs_ = parser.Parse(base::FilePath("examples/multiplier_3.r1cs"));
ASSERT_TRUE(r1cs_);
}

void LoadRandomWitness() {
circuit_.reset(new Circuit<F>(
r1cs_.get(),
base::FilePath("examples/multiplier_3_cpp/multiplier_3.dat")));
Expand All @@ -23,19 +25,38 @@ class Multiplier3CircuitTest : public CircuitTest {
ASSERT_EQ(public_inputs.size(), 1);
ASSERT_EQ(public_inputs[0], values[0] * values[1] * values[2]);
}

void LoadWitnessFromJson() {
circuit_.reset(new Circuit<F>(
r1cs_.get(),
base::FilePath("examples/multiplier_3_cpp/multiplier_3.dat")));
circuit_->witness_loader().Load(
base::FilePath("circomlib/circuit/multiplier_3_data.json"));
std::vector<F> public_inputs = circuit_->GetPublicInputs();
ASSERT_EQ(public_inputs.size(), 1);
ASSERT_EQ(public_inputs[0], F(60));
}
};

TEST_F(Multiplier3CircuitTest, Synthesize) { this->SynthesizeTest(); }
TEST_F(Multiplier3CircuitTest, Synthesize) {
LoadRandomWitness();
this->SynthesizeTest();

LoadWitnessFromJson();
this->SynthesizeTest();
}

TEST_F(Multiplier3CircuitTest, Groth16ProveAndVerify) {
constexpr size_t kMaxDegree = 31;
LoadRandomWitness();
this->Groth16ProveAndVerifyTest<kMaxDegree,
zk::r1cs::QuadraticArithmeticProgram<F>>();
}

TEST_F(Multiplier3CircuitTest, Groth16ProveAndVerifyUsingZkey) {
constexpr size_t kMaxDegree = 3;

LoadRandomWitness();
ZKeyParser parser;
std::unique_ptr<ZKey> zkey =
parser.Parse(base::FilePath("examples/multiplier_3.zkey"));
Expand Down
3 changes: 3 additions & 0 deletions vendors/circom/circomlib/circuit/multiplier_3_data.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"in": ["3", "4", "5"]
}
4 changes: 4 additions & 0 deletions vendors/circom/circomlib/circuit/witness_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ class WitnessLoader {

void Load() { loadWitness(calc_wit_.get(), witness_); }

void Load(const base::FilePath& json) {
loadJson(calc_wit_.get(), json.value());
}

F Get(uint32_t i) const {
FrElement v;
calc_wit_->getWitness(i, &v);
Expand Down
2 changes: 1 addition & 1 deletion vendors/circom/circomlib/generated/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ This is taken and modified from [iden3/circom/code_producers/src/c_elements](htt

- Modified to remove errors from compilation.
- `loadCircuit` and `writeBinWitness` are moved to [common/calcwit.hpp](/vendors/circom//circomlib/generated/common/calcwit.hpp) and [common/calcwit.hpp](/vendors/circom//circomlib/generated/common/calcwit.cpp).
- `loadWitness` is created instead of `loadJson`, which loads the witness from the `absl::flat_hash_map<>`.
- `loadWitness` is created, which loads the witness from the `absl::flat_hash_map<>`.

See the following files for more details on the modifications.

Expand Down
106 changes: 106 additions & 0 deletions vendors/circom/circomlib/generated/common/calcwit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include <vector>
#include <chrono>
#include "calcwit.hpp"
#include "nlohmann/json.hpp"

using json = nlohmann::json;

extern void run(Circom_CalcWit* ctx);

Expand Down Expand Up @@ -213,6 +216,109 @@ Circom_Circuit* loadCircuit(std::string const &datFileName) {
return circuit;
}

bool check_valid_number(std::string & s, uint base){
bool is_valid = true;
if (base == 16){
for (uint i = 0; i < s.size(); i++){
is_valid &= (
('0' <= s[i] && s[i] <= '9') ||
('a' <= s[i] && s[i] <= 'f') ||
('A' <= s[i] && s[i] <= 'F')
);
}
} else{
for (uint i = 0; i < s.size(); i++){
is_valid &= ('0' <= s[i] && s[i] < char(int('0') + base));
}
}
return is_valid;
}

void json2FrElements (json val, std::vector<FrElement> & vval){
if (!val.is_array()) {
FrElement v;
std::string s_aux, s;
uint base;
if (val.is_string()) {
s_aux = val.get<std::string>();
std::string possible_prefix = s_aux.substr(0, 2);
if (possible_prefix == "0b" || possible_prefix == "0B"){
s = s_aux.substr(2, s_aux.size() - 2);
base = 2;
} else if (possible_prefix == "0o" || possible_prefix == "0O"){
s = s_aux.substr(2, s_aux.size() - 2);
base = 8;
} else if (possible_prefix == "0x" || possible_prefix == "0X"){
s = s_aux.substr(2, s_aux.size() - 2);
base = 16;
} else{
s = s_aux;
base = 10;
}
if (!check_valid_number(s, base)){
std::ostringstream errStrStream;
errStrStream << "Invalid number in JSON input: " << s_aux << "\n";
throw std::runtime_error(errStrStream.str() );
}
} else if (val.is_number()) {
double vd = val.get<double>();
std::stringstream stream;
stream << std::fixed << std::setprecision(0) << vd;
s = stream.str();
base = 10;
} else {
std::ostringstream errStrStream;
errStrStream << "Invalid JSON type\n";
throw std::runtime_error(errStrStream.str() );
}
Fr_str2element (&v, s.c_str(), base);
vval.push_back(v);
} else {
for (uint i = 0; i < val.size(); i++) {
json2FrElements (val[i], vval);
}
}
}

void loadJson(Circom_CalcWit *ctx, std::string filename) {
std::ifstream inStream(filename);
json j;
inStream >> j;

u64 nItems = j.size();
// printf("Items : %llu\n",nItems);
fakedev9999 marked this conversation as resolved.
Show resolved Hide resolved
if (nItems == 0){
ctx->tryRunCircuit();
}
for (json::iterator it = j.begin(); it != j.end(); ++it) {
// std::cout << it.key() << " => " << it.value() << '\n';
fakedev9999 marked this conversation as resolved.
Show resolved Hide resolved
u64 h = fnv1a(it.key());
std::vector<FrElement> v;
json2FrElements(it.value(),v);
uint signalSize = ctx->getInputSignalSize(h);
if (v.size() < signalSize) {
std::ostringstream errStrStream;
errStrStream << "Error loading signal " << it.key() << ": Not enough values\n";
throw std::runtime_error(errStrStream.str() );
}
if (v.size() > signalSize) {
std::ostringstream errStrStream;
errStrStream << "Error loading signal " << it.key() << ": Too many values\n";
throw std::runtime_error(errStrStream.str() );
}
for (uint i = 0; i<v.size(); i++){
try {
// std::cout << it.key() << "," << i << " => " << Fr_element2str(&(v[i])) << '\n';
fakedev9999 marked this conversation as resolved.
Show resolved Hide resolved
ctx->setInputSignal(h,i,v[i]);
} catch (std::runtime_error e) {
std::ostringstream errStrStream;
errStrStream << "Error setting signal: " << it.key() << "\n" << e.what();
throw std::runtime_error(errStrStream.str() );
}
}
}
}

void loadWitness(Circom_CalcWit *ctx, const absl::flat_hash_map<std::string, std::vector<FrElement>>& witness) {
size_t nItems = witness.size();
// printf("Items : %llu\n",nItems);
Expand Down
Loading
Loading