Skip to content

Commit

Permalink
Make efficient user of semaphores to handle multiple proofs efficiently
Browse files Browse the repository at this point in the history
  • Loading branch information
akshay111meher committed Dec 13, 2024
1 parent dcd0b76 commit 7ef231a
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 17 deletions.
24 changes: 11 additions & 13 deletions listener/src/job_creator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,8 @@ impl JobCreator {
stop_handle.store(true, Ordering::Release);
});

let proof_semaphore = Arc::new(Semaphore::new(self.max_threads)); // ensures that only `max_threads` number of proofs are flushed to generator
let valid_proof_semaphore = Arc::new(Semaphore::new(self.max_threads)); // ensures that only `max_threads` number of proofs are flushed to generator
let invalid_inputs_semaphore = Arc::new(Semaphore::new(5));
let transaction_semaphore = Arc::new(Semaphore::new(1)); // ensures 1 transaction is published at a time

loop {
Expand Down Expand Up @@ -688,7 +689,8 @@ impl JobCreator {
let markets_clone = Arc::clone(&markets);
// code inside thread starts here

let proof_semaphore = proof_semaphore.clone();
let valid_proof_semaphore = valid_proof_semaphore.clone();
let invalid_inputs_semaphore = invalid_inputs_semaphore.clone();
let transaction_semaphore = transaction_semaphore.clone();

let skip_input_verification = self.skip_input_verification.clone();
Expand All @@ -714,19 +716,15 @@ impl JobCreator {
skip_input_verification,
};

let proof_permit = proof_semaphore
.acquire()
.await
.expect("Failed to acquire proof semaphore");

let proof = match proof_generator::generate_proof(generate_proof_args).await
let proof = match proof_generator::generate_proof(
generate_proof_args,
valid_proof_semaphore,
invalid_inputs_semaphore,
)
.await
{
Ok(proof) => {
drop(proof_permit);
proof
}
Ok(proof) => proof,
Err(err) => {
drop(proof_permit);
log::error!("Error generating proof for ask: {}", event.ask_id);
log::error!("{}", err.to_string());
return log::error!("{}", err);
Expand Down
11 changes: 9 additions & 2 deletions listener/src/proof_generator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::io::Read;
use std::sync::Arc;
use std::time::Instant;
use std::{thread, time::Duration};
use tokio::sync::Semaphore;

mod confidential_provers;
mod non_confidential_prover;
Expand All @@ -38,6 +39,8 @@ pub struct GenerateProofParams<'a> {
//Generating proof for the input
pub async fn generate_proof(
generate_proof_params: GenerateProofParams<'_>,
valid_proof_semaphore: Arc<Semaphore>,
invalid_inputs_semaphore: Arc<Semaphore>,
) -> Result<Proof, Box<dyn std::error::Error>> {
let (public_inputs, decoded_secret_input, market_id, parsed_ask_created_log, markets) =
fetch_decoded_secret(generate_proof_params.clone())
Expand Down Expand Up @@ -79,7 +82,9 @@ pub async fn generate_proof(
generate_proof_params.skip_input_verification,
);

confidential_prover.get_proof().await
confidential_prover
.get_proof(valid_proof_semaphore, invalid_inputs_semaphore)
.await
} else {
// market without confidential inputs
let ivs_url = &markets.get(&market_id.to_string()).unwrap().ivs_url;
Expand All @@ -101,7 +106,9 @@ pub async fn generate_proof(
generate_proof_params.skip_input_verification,
);

non_confidential_prover.get_proof().await
non_confidential_prover
.get_proof(valid_proof_semaphore, invalid_inputs_semaphore)
.await
}
}

Expand Down
28 changes: 26 additions & 2 deletions listener/src/proof_generator/prover.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use ethers::types::Bytes;
use reqwest::Client;
use serde::{de::DeserializeOwned, Serialize};
use std::error::Error;
use std::{error::Error, sync::Arc};
use tokio::sync::Semaphore;

#[derive(Debug, Clone)]
pub enum Proof {
Expand All @@ -24,15 +25,31 @@ pub trait Prover {

fn should_skip_input_verification(&self) -> bool;

async fn get_proof(&self) -> Result<Proof, Box<dyn Error>> {
async fn get_proof(
&self,
valid_proof_semaphore: Arc<Semaphore>,
invalid_inputs_semaphore: Arc<Semaphore>,
) -> Result<Proof, Box<dyn Error>> {
if self.should_skip_input_verification() {
let proof_permit = valid_proof_semaphore
.acquire()
.await
.expect("Failed to acquire proof semaphore");
let proof = self.generate_proof().await?;
drop(proof_permit); // not needed but explicity dropping it
return Ok(Proof::ValidProof(proof.proof.into()));
}

let check_input = self.check_inputs().await?;
if check_input.valid {
let proof_permit = valid_proof_semaphore
.acquire()
.await
.expect("Failed to acquire proof semaphore");

let proof = self.generate_proof().await?;

drop(proof_permit); // not needed but explicity dropping it
let check_proof = self.verify_inputs_and_proof(proof.proof.as_ref()).await;
match check_proof {
Ok(data) => {
Expand All @@ -49,7 +66,14 @@ pub trait Prover {
}
Ok(Proof::ValidProof(proof.proof.into()))
} else {
let invalid_inputs_permit = invalid_inputs_semaphore
.acquire()
.await
.expect("Failed to acquire proof semaphore");

let proof = self.generate_attestation_for_invalid_inputs().await?;

drop(invalid_inputs_permit);
Ok(Proof::InvalidProof(proof.proof.into()))
}
}
Expand Down

0 comments on commit 7ef231a

Please sign in to comment.