Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hint function enhancements #253

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions examples/fixture/asm/kimchi/hint.asm
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,11 @@ DoubleGeneric<1,0,0,0,-2>
DoubleGeneric<2,0,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,-1>
DoubleGeneric<1,0,-1,0,2>
DoubleGeneric<1,-1>
DoubleGeneric<1,0,0,0,-16>
DoubleGeneric<1,0,0,0,-32>
DoubleGeneric<1,0,0,0,-4>
DoubleGeneric<1,-1>
DoubleGeneric<1,0,0,0,-3>
DoubleGeneric<1,0,-1,0,1>
Expand All @@ -20,11 +23,12 @@ DoubleGeneric<1,0,-1,0,1>
DoubleGeneric<1,0,0,0,-1>
DoubleGeneric<1,0,0,0,-1>
DoubleGeneric<1,-1>
(0,0) -> (18,0)
(1,0) -> (2,0) -> (9,1)
(0,0) -> (21,0)
(1,0) -> (2,0) -> (12,1)
(4,0) -> (6,1)
(4,2) -> (5,1)
(5,0) -> (7,1) -> (18,1)
(9,0) -> (11,0)
(14,1) -> (15,0)
(15,2) -> (16,0)
(5,0) -> (7,0) -> (21,1)
(7,2) -> (8,1)
(12,0) -> (14,0)
(17,1) -> (18,0)
(18,2) -> (19,0)
14 changes: 8 additions & 6 deletions examples/fixture/asm/r1cs/hint.asm
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
2 == (v_3) * (1)
2 * v_5 == (v_4) * (1)
v_5 == (v_6) * (1)
v_4 == (v_7) * (1)
v_4 + 2 == (v_7) * (1)
16 == (v_8) * (1)
v_2 == (v_9) * (1)
3 == (v_10) * (1)
1 == (v_11) * (1)
1 == (v_12) * (1)
1 == (-1 * v_13 + 1) * (1)
32 == (v_9) * (1)
4 == (v_10) * (1)
v_2 == (v_11) * (1)
3 == (v_12) * (1)
1 == (v_13) * (1)
1 == (v_14) * (1)
1 == (-1 * v_15 + 1) * (1)
1 == (v_16) * (1)
v_4 == (v_1) * (1)
38 changes: 30 additions & 8 deletions examples/hint.no
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ struct Thing {
yy: Field,
}

fn init_arr(const LEN: Field) -> [Field; LEN] {
return [0; LEN];
}

hint fn cst_div(const lhs: Field, const rhs: Field) -> Field {
return lhs / rhs;
}

hint fn mul(lhs: Field, rhs: Field) -> Field {
return lhs * rhs;
}
Expand All @@ -17,19 +25,23 @@ hint fn div(lhs: Field, rhs: Field) -> Field {
}

hint fn ite(lhs: Field, rhs: Field) -> Field {
return if lhs != rhs { lhs } else { rhs };
return if lhs != rhs { lhs + 2 } else { rhs * 2 };
}

hint fn exp(const EXP: Field, val: Field) -> Field {
let mut res = val;

for num in 1..EXP {
res = res * val;
}
hint fn exp(const EXP: Field, base: Field) -> Field {
let res = base ** EXP;

return res;
}

hint fn lshift(val: Field, shift: Field) -> Field {
return val << shift;
}

hint fn rem(lhs: Field, rhs: Field) -> Field {
return lhs % rhs;
}

hint fn sub(lhs: Field, rhs: Field) -> Field {
return lhs - rhs;
}
Expand Down Expand Up @@ -63,11 +75,17 @@ fn main(pub public_input: Field, private_input: Field) -> Field {
assert_eq(zz, yy);

let ww = unsafe ite(xx, yy);
assert_eq(ww, xx);
assert_eq(ww, xx + 2);

let kk = unsafe exp(4, public_input);
assert_eq(kk, 16);

let k2 = unsafe lshift(public_input, 4);
assert_eq(k2, 32);

let ll = unsafe rem(kk, 12);
assert_eq(ll, 4);

let thing = unsafe multiple_inputs_outputs([public_input, 3]);
// have to include all the outputs from hint function, otherwise it throws vars not in circuit error.
// this is because each individual element in the hint output maps to a separate cell var in noname.
Expand All @@ -82,5 +100,9 @@ fn main(pub public_input: Field, private_input: Field) -> Field {
assert(!oo[1]);
assert(oo[2]);

// mast phase can fold the constant value using hint functions
let one = unsafe cst_div(2, 2);
let arr = init_arr(one);
assert_eq(arr[0], 0);
return xx;
}
22 changes: 16 additions & 6 deletions src/backends/kimchi/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ pub struct KimchiVesta {

/// This is how you compute the value of each variable during witness generation.
/// It is created during circuit generation.
pub(crate) vars_to_value: HashMap<usize, Value<Self>>,
pub(crate) vars_to_value: HashMap<usize, (Value<Self>, Span)>,

/// The execution trace table with vars as placeholders.
/// It is created during circuit generation,
Expand Down Expand Up @@ -303,7 +303,7 @@ impl Backend for KimchiVesta {
self.next_variable += 1;

// store it in the circuit_writer
self.vars_to_value.insert(var.index, val);
self.vars_to_value.insert(var.index, (val, span));

var
}
Expand Down Expand Up @@ -361,6 +361,16 @@ impl Backend for KimchiVesta {

for var in 0..self.next_variable {
if !written_vars.contains(&var) && !disable_safety_check {
let (val, span) = self
.vars_to_value
.get(&var)
.expect("a var should be in vars_to_value");

if matches!(val, Value::HintIR(..)) {
println!("a HintIR value not used in the circuit: {:?}", span);
continue;
}

if let Some(private_cell_var) = self
.private_input_cell_vars
.iter()
Expand All @@ -374,7 +384,7 @@ impl Backend for KimchiVesta {
);
Err(err)?;
} else {
Err(Error::new("contraint-finalization", ErrorKind::UnexpectedError("there's a bug in the circuit_writer, some cellvar does not end up being a cellvar in the circuit!"), Span::default()))?;
Err(Error::new("contraint-finalization", ErrorKind::UnexpectedError("there's a bug in the circuit_writer, some cellvar does not end up being a cellvar in the circuit!"), *span))?;
}
}
}
Expand All @@ -399,7 +409,7 @@ impl Backend for KimchiVesta {
let var_idx = pub_var.cvar().unwrap().index;
let prev = self
.vars_to_value
.insert(var_idx, Value::PublicOutput(Some(ret_var)));
.insert(var_idx, (Value::PublicOutput(Some(ret_var)), ret_var.span));
assert!(prev.is_some());
}
}
Expand All @@ -415,7 +425,7 @@ impl Backend for KimchiVesta {
var: &Self::Var,
) -> crate::error::Result<Self::Field> {
let val = self.vars_to_value.get(&var.index).unwrap();
self.compute_val(env, val, var.index)
self.compute_val(env, &val.0, var.index)
}

fn generate_witness<B: Backend>(
Expand All @@ -442,7 +452,7 @@ impl Backend for KimchiVesta {
// if it's a public output, defer it's computation
if matches!(
self.vars_to_value.get(&var.index),
Some(Value::PublicOutput(_))
Some((Value::PublicOutput(_), ..))
) {
public_outputs_vars
.entry(*var)
Expand Down
47 changes: 19 additions & 28 deletions src/backends/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use std::{fmt::Debug, str::FromStr};

use ::kimchi::o1_utils::FieldHelpers;
use ark_ff::{Field, One, PrimeField, Zero};
use circ::ir::term::precomp::PreComp;
use ark_ff::{Field, One, Zero};
use fxhash::FxHashMap;
use num_bigint::BigUint;

use crate::{
circuit_writer::VarInfo,
circuit_writer::{ir::IRWriter, VarInfo},
compiler::Sources,
constants::Span,
error::{Error, ErrorKind, Result},
Expand Down Expand Up @@ -204,14 +203,9 @@ pub trait Backend: Clone {

Ok(res)
}
Value::HintIR(t, named_vars) => {
let mut precomp = PreComp::new();
// For hint evaluation purpose, precomp only has only one output and no connections with other parts,
// so just use a dummy output var name.
precomp.add_output("x".to_string(), t.clone());

Value::HintIR(t, named_vars, logs) => {
// map the named vars to env
let env = named_vars
let ir_env = named_vars
.iter()
.map(|(name, var)| {
let val = match var {
Expand All @@ -225,24 +219,21 @@ pub trait Backend: Clone {
})
.collect::<FxHashMap<String, circ::ir::term::Value>>();

// evaluate and get the only one output
let eval_map = precomp.eval(&env);
let value = eval_map.get("x").unwrap();
// convert to field
let res = match value {
circ::ir::term::Value::Field(f) => {
let bytes = f.i().to_digits::<u8>(rug::integer::Order::Lsf);
Self::Field::from_le_bytes_mod_order(&bytes)
}
circ::ir::term::Value::Bool(b) => {
if *b {
Self::Field::one()
} else {
Self::Field::zero()
}
}
_ => panic!("unexpected output type"),
};
// evaluate logs
for log in logs {
// check the cache on env and log, and only evaluate if not in cache
let res: Vec<Self::Field> = IRWriter::<Self>::eval_ir(&ir_env, log);
// format and print out array
println!(
"log: {:#?}",
res.iter().map(|f| f.pretty()).collect::<Vec<String>>()
);
}

// evaluate the term
let res = IRWriter::<Self>::eval_ir(&ir_env, t)[0];

env.cached_values.insert(cache_key, res); // cache

Ok(res)
}
Expand Down
26 changes: 16 additions & 10 deletions src/backends/r1cs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ where
{
/// Constraints in the r1cs.
constraints: Vec<Constraint<F>>,
witness_vector: Vec<Value<Self>>,
witness_vector: Vec<(Value<Self>, Span)>,
/// Debug information for each constraint.
debug_info: Vec<DebugInfo>,
/// Debug information for var info.
Expand Down Expand Up @@ -384,7 +384,7 @@ where
span,
};

self.witness_vector.insert(var.index, val);
self.witness_vector.insert(var.index, (val, span));

LinearCombination::from(var)
}
Expand Down Expand Up @@ -419,8 +419,9 @@ where
// replace the computation of the public output vars with the actual variables being returned here
let var_idx = pub_var.cvar().unwrap().to_cell_var().index;
let prev = &self.witness_vector[var_idx];
assert!(matches!(prev, Value::PublicOutput(None)));
self.witness_vector[var_idx] = Value::PublicOutput(Some(ret_var));
assert!(matches!(prev.0, Value::PublicOutput(None)));
self.witness_vector[var_idx] =
(Value::PublicOutput(Some(ret_var.clone())), ret_var.span);
}
}

Expand All @@ -435,7 +436,7 @@ where
}

// check if every cell vars end up being a cell var in the circuit or public output
for (index, _) in self.witness_vector.iter().enumerate() {
for (index, (val, span)) in self.witness_vector.iter().enumerate() {
// Skip the first var which is always 1
// - In a linear combination, each of the vars can be paired with a coefficient.
// - The first var is assumed to be the factor of the constant of a linear combination.
Expand All @@ -444,6 +445,11 @@ where
}

if !written_vars.contains(&index) && !disable_safety_check {
// ignore HintIR val
if let Value::HintIR(..) = val {
println!("a HintIR value not used in the circuit: {:?}", span);
continue;
}
if let Some(private_cell_var) = self
.private_input_cell_vars
.iter()
Expand All @@ -458,7 +464,7 @@ where
Err(Error::new(
"constraint-finalization",
ErrorKind::UnexpectedError("there's a bug in the circuit_writer, some cellvar does not end up being a cellvar in the circuit!"),
Span::default(),
*span,
))?
}
}
Expand All @@ -478,7 +484,7 @@ where

for (var, factor) in &lc.terms {
let var_val = self.witness_vector.get(var.index).unwrap();
let calc = self.compute_val(env, var_val, var.index)? * factor;
let calc = self.compute_val(env, &var_val.0, var.index)? * factor;
val += calc;
}

Expand All @@ -501,11 +507,11 @@ where
.iter()
.enumerate()
.map(|(index, val)| {
match val {
match val.0 {
// Defer calculation for output vars.
// The reasoning behind this is to avoid deep recursion potentially triggered by the public output var at the beginning.
Value::PublicOutput(_) => Ok(F::zero()),
_ => self.compute_val(witness_env, val, index),
_ => self.compute_val(witness_env, &val.0, index),
}
})
.collect::<crate::error::Result<Vec<F>>>()?;
Expand Down Expand Up @@ -973,7 +979,7 @@ mod tests {

// first var should be initialized as 1
assert_eq!(r1cs.witness_vector.len(), 1);
match &r1cs.witness_vector[0] {
match &r1cs.witness_vector[0].0 {
crate::var::Value::Constant(cst) => {
assert_eq!(*cst, R1csBls12381Field::one());
}
Expand Down
Loading
Loading