Skip to content

Commit

Permalink
feat(circom): print timing info in more details
Browse files Browse the repository at this point in the history
  • Loading branch information
chokobole committed Jul 8, 2024
1 parent 9f9184b commit 0da2d85
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 16 deletions.
1 change: 1 addition & 0 deletions vendors/circom/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ tachyon_cc_binary(
"//circomlib/json:prime_field",
"//circomlib/wtns",
"//circomlib/zkey",
"@com_google_absl//absl/strings",
"@kroma_network_tachyon//tachyon/base/console",
"@kroma_network_tachyon//tachyon/base/files:file_path_flag",
"@kroma_network_tachyon//tachyon/base/flag:flag_parser",
Expand Down
96 changes: 80 additions & 16 deletions vendors/circom/prover_main.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#include "absl/strings/substitute.h"

#include "circomlib/circuit/quadratic_arithmetic_program.h"
#include "circomlib/json/groth16_proof.h"
#include "circomlib/json/json.h"
Expand Down Expand Up @@ -44,16 +46,62 @@ class FlagValueTraits<Curve> {

namespace circom {

struct TimeInfo {
base::TimeDelta parse_zkey;
base::TimeDelta parse_wtns;
base::TimeDelta prove;
base::TimeDelta verify;

static TimeInfo Max(const TimeInfo& a, const TimeInfo& b) {
return TimeInfo{
std::max(a.parse_zkey, b.parse_zkey),
std::max(a.parse_wtns, b.parse_wtns),
std::max(a.prove, b.prove),
std::max(a.verify, b.verify),
};
}

TimeInfo operator+(const TimeInfo& other) const {
return {
parse_zkey + other.parse_zkey,
parse_wtns + other.parse_wtns,
prove + other.prove,
verify + other.verify,
};
}

TimeInfo& operator+=(const TimeInfo& other) { return *this = *this + other; }

TimeInfo operator/(size_t num_runs) const {
return {
parse_zkey / num_runs,
parse_wtns / num_runs,
prove / num_runs,
verify / num_runs,
};
}

std::string ToString() const {
return absl::Substitute(
"{Parse ZKey: $0 s, Parse Wtns: $1 s, Prove: $2 s, Verify: $3 s}",
parse_zkey.InSecondsF(), parse_wtns.InSecondsF(), prove.InSecondsF(),
verify.InSecondsF());
}
};

template <typename Curve>
void CreateProof(const base::FilePath& zkey_path,
const base::FilePath& witness_path,
const base::FilePath& proof_path,
const base::FilePath& public_path, bool no_zk, bool verify) {
TimeInfo CreateProof(const base::FilePath& zkey_path,
const base::FilePath& witness_path,
const base::FilePath& proof_path,
const base::FilePath& public_path, bool no_zk,
bool verify) {
using F = typename Curve::G1Curve::ScalarField;
using Domain = math::UnivariateEvaluationDomain<F, SIZE_MAX>;

Curve::Init();

TimeInfo time_info;
base::TimeTicks start = base::TimeTicks::Now();
zk::r1cs::groth16::ProvingKey<Curve> proving_key;
zk::r1cs::ConstraintMatrices<F> constraint_matrices;
{
Expand All @@ -64,9 +112,17 @@ void CreateProof(const base::FilePath& zkey_path,
constraint_matrices = std::move(*zkey).TakeConstraintMatrices().ToNative();
}

base::TimeTicks end = base::TimeTicks::Now();
time_info.parse_zkey = end - start;
start = end;

std::unique_ptr<Wtns<F>> wtns = ParseWtns<F>(witness_path);
CHECK(wtns);

end = base::TimeTicks::Now();
time_info.parse_wtns = end - start;
start = end;

absl::Span<const F> full_assignments = wtns->GetWitnesses();

std::unique_ptr<Domain> domain =
Expand All @@ -93,6 +149,10 @@ void CreateProof(const base::FilePath& zkey_path,
full_assignments.subspan(1));
}

end = base::TimeTicks::Now();
time_info.prove = end - start;
start = end;

zk::r1cs::groth16::PreparedVerifyingKey<Curve> prepared_verifying_key =
std::move(proving_key).TakeVerifyingKey().ToPreparedVerifyingKey();
absl::Span<const F> public_inputs = full_assignments.subspan(
Expand All @@ -102,8 +162,13 @@ void CreateProof(const base::FilePath& zkey_path,
public_inputs));
}

end = base::TimeTicks::Now();
time_info.verify = end - start;
end = start;

CHECK(WriteToJson(proof, proof_path));
CHECK(WriteToJson(public_inputs, public_path));
return time_info;
}

} // namespace circom
Expand Down Expand Up @@ -161,32 +226,31 @@ int RealMain(int argc, char** argv) {
tachyon_cerr << "num_runs should be positive" << std::endl;
return 1;
}
base::TimeDelta total_time;
base::TimeDelta max_time;
circom::TimeInfo total_time;
circom::TimeInfo max_time;
for (size_t i = 0; i < num_runs; ++i) {
base::TimeTicks start = base::TimeTicks::Now();
circom::TimeInfo time_info;
switch (curve) {
case Curve::kBN254:
circom::CreateProof<math::bn254::BN254Curve>(
time_info = circom::CreateProof<math::bn254::BN254Curve>(
zkey_path, witness_path, proof_path, public_path, no_zk, verify);
break;
case Curve::kBLS12_381:
#if !TACHYON_CUDA
circom::CreateProof<math::bls12_381::BLS12_381Curve>(
time_info = circom::CreateProof<math::bls12_381::BLS12_381Curve>(
zkey_path, witness_path, proof_path, public_path, no_zk, verify);
#endif
break;
}
base::TimeDelta time_taken = base::TimeTicks::Now() - start;
total_time += time_taken;
max_time = std::max(max_time, time_taken);
std::cout << "Run " << (i + 1) << ", Time Taken: " << time_taken
total_time += time_info;
max_time = circom::TimeInfo::Max(max_time, time_info);
std::cout << "Run " << (i + 1) << ", Time Taken: " << time_info.ToString()
<< std::endl;
}

base::TimeDelta avg_time = total_time / num_runs;
std::cout << "Average Time Taken: " << avg_time << std::endl;
std::cout << "Maximum Time Taken: " << max_time << std::endl;
circom::TimeInfo avg_time = total_time / num_runs;
std::cout << "Average Time Taken: " << avg_time.ToString() << std::endl;
std::cout << "Maximum Time Taken: " << max_time.ToString() << std::endl;

return 0;
}
Expand Down

0 comments on commit 0da2d85

Please sign in to comment.