Skip to content

Commit

Permalink
fix uint range check
Browse files Browse the repository at this point in the history
  • Loading branch information
dark64 committed May 16, 2023
1 parent 633ee82 commit 53f2a7d
Show file tree
Hide file tree
Showing 11 changed files with 287 additions and 28 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ members = [
"zokrates_profiler",
]

[profile.dev]
opt-level = 1
# [profile.dev]
# opt-level = 1
4 changes: 2 additions & 2 deletions zokrates_ast/src/common/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub enum RuntimeError {
BranchIsolation,
ConstantLtBitness,
ConstantLtSum,
LtFinalSum,
LtSum,
LtSymetric,
Or,
Xor,
Expand Down Expand Up @@ -83,7 +83,7 @@ impl fmt::Display for RuntimeError {
BranchIsolation => "Branch isolation failed",
ConstantLtBitness => "Bitness check failed in constant Lt check",
ConstantLtSum => "Sum check failed in constant Lt check",
LtFinalSum => "Sum check failed in final Lt check",
LtSum => "Sum check failed in Lt check",
LtSymetric => "Symetrical check failed in Lt check",
Or => "Or check failed",
Xor => "Xor check failed",
Expand Down
132 changes: 111 additions & 21 deletions zokrates_codegen/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
lhs_flattened: FlatExpression<T>,
rhs_flattened: FlatExpression<T>,
bit_width: usize,
error: RuntimeError,
) -> FlatExpression<T> {
FlatExpression::add(
self.eq_check(
Expand All @@ -815,6 +816,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
lhs_flattened,
rhs_flattened,
bit_width,
error,
),
)
}
Expand All @@ -826,6 +828,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
lhs_flattened: FlatExpression<T>,
rhs_flattened: FlatExpression<T>,
bit_width: usize,
error: RuntimeError,
) -> FlatExpression<T> {
match (lhs_flattened, rhs_flattened) {
(x, FlatExpression::Value(constant)) => {
Expand All @@ -841,7 +844,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let lhs_id = self.define(lhs_flattened, statements_flattened);
let rhs_id = self.define(rhs_flattened, statements_flattened);

// shifted_sub := 2**safe_width + lhs - rhs
// shifted_sub := 2**bit_width + lhs - rhs
let shifted_sub = FlatExpression::add(
FlatExpression::value(T::from(2).pow(bit_width)),
FlatExpression::sub(
Expand All @@ -857,7 +860,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
sub_width,
sub_width,
statements_flattened,
RuntimeError::IncompleteDynamicRange,
error,
);

FlatExpression::sub(
Expand Down Expand Up @@ -910,6 +913,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
lhs_flattened,
rhs_flattened,
safe_width,
RuntimeError::IncompleteDynamicRange,
)
}
BooleanExpression::BoolEq(e) => {
Expand Down Expand Up @@ -1003,6 +1007,7 @@ impl<'ast, T: Field> Flattener<'ast, T> {
lhs_flattened,
rhs_flattened,
bit_width,
RuntimeError::LtSum,
)
}
BooleanExpression::UintLe(e) => {
Expand Down Expand Up @@ -2016,18 +2021,39 @@ impl<'ast, T: Field> Flattener<'ast, T> {
let res = match self.bits_cache.entry(e.field.clone().unwrap()) {
Entry::Occupied(entry) => {
let res: Vec<_> = entry.get().clone();
// if we already know a decomposition, its number of elements has to be smaller or equal to `to`
assert!(res.len() <= to);

// we then pad it with zeroes on the left (big endian) to return `to` bits
if res.len() == to {
res
} else {
(0..to - res.len())
if res.len() > to {
// if the result is bigger than `to`, we sum higher bits up to `to`
let bit_sum = res[..res.len() - to]
.iter()
.cloned()
.fold(FlatExpression::from(T::zero()), |acc, e| {
FlatExpression::add(acc, e)
});

// sum of higher bits must be zero
statements_flattened.push_back(FlatStatement::condition(
FlatExpression::value(T::from(0)),
bit_sum,
error,
));

// truncate to the `to` lowest bits
let res = res[res.len() - to..].to_vec();
assert_eq!(res.len(), to);

return res;
}

// if result is smaller than `to` we pad it with zeroes on the left (big endian) to return `to` bits
if res.len() < to {
return (0..to - res.len())
.map(|_| FlatExpression::value(T::zero()))
.chain(res)
.collect()
.collect();
}

res
}
Entry::Vacant(_) => {
let bits = (0..from).map(|_| self.use_sym()).collect::<Vec<_>>();
Expand Down Expand Up @@ -2554,7 +2580,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(lhs, rhs) => {
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // dynamic comparison is not complete
let e = self.lt_check(statements_flattened, lhs, rhs, safe_width);
let e = self.lt_check(
statements_flattened,
lhs,
rhs,
safe_width,
RuntimeError::IncompleteDynamicRange,
);
statements_flattened.push_back(FlatStatement::condition(
e,
FlatExpression::value(T::one()),
Expand Down Expand Up @@ -2584,7 +2616,13 @@ impl<'ast, T: Field> Flattener<'ast, T> {
(lhs, rhs) => {
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // dynamic comparison is not complete
let e = self.le_check(statements_flattened, lhs, rhs, safe_width);
let e = self.le_check(
statements_flattened,
lhs,
rhs,
safe_width,
RuntimeError::IncompleteDynamicRange,
);
statements_flattened.push_back(FlatStatement::condition(
e,
FlatExpression::value(T::one()),
Expand All @@ -2593,32 +2631,84 @@ impl<'ast, T: Field> Flattener<'ast, T> {
}
}
}
BooleanExpression::UintLe(e) => {
BooleanExpression::UintLt(e) => {
let bitwidth = e.left.bitwidth as usize;
let lhs = self
.flatten_uint_expression(statements_flattened, *e.left)
.get_field_unchecked();

let rhs = self
.flatten_uint_expression(statements_flattened, *e.right)
.get_field_unchecked();

match (lhs, rhs) {
(e, FlatExpression::Value(c)) => self.enforce_constant_le_check(
(e, FlatExpression::Value(c)) => self.enforce_constant_lt_check(
statements_flattened,
e,
c.value,
error.into(),
),
// c <= e <=> p - 1 - e <= p - 1 - c
(FlatExpression::Value(c), e) => self.enforce_constant_le_check(
// c < e <=> 2^bw - 1 - e < 2^bw - 1 - c
(FlatExpression::Value(c), e) => {
let max = T::from(2u32).pow(bitwidth) - T::one();
self.enforce_constant_lt_check(
statements_flattened,
FlatExpression::sub(max.into(), e),
max - c.value,
error.into(),
)
}
(lhs, rhs) => {
let e = self.lt_check(
statements_flattened,
lhs,
rhs,
bitwidth,
RuntimeError::LtSum,
);
statements_flattened.push_back(FlatStatement::condition(
e,
FlatExpression::value(T::one()),
error.into(),
));
}
}
}
BooleanExpression::UintLe(e) => {
let bitwidth = e.left.bitwidth as usize;
let lhs = self
.flatten_uint_expression(statements_flattened, *e.left)
.get_field_unchecked();

let rhs = self
.flatten_uint_expression(statements_flattened, *e.right)
.get_field_unchecked();

match (lhs, rhs) {
(e, FlatExpression::Value(c)) => self.enforce_constant_le_check(
statements_flattened,
FlatExpression::sub(T::max_value().into(), e),
T::max_value() - c.value,
e,
c.value,
error.into(),
),
// c < e <=> 2^bw - 1 - e < 2^bw - 1 - c
(FlatExpression::Value(c), e) => {
let max = T::from(2u32).pow(bitwidth) - T::one();
self.enforce_constant_le_check(
statements_flattened,
FlatExpression::sub(max.into(), e),
max - c.value,
error.into(),
)
}
(lhs, rhs) => {
let bit_width = T::get_required_bits();
let safe_width = bit_width - 2; // dynamic comparison is not complete
let e = self.le_check(statements_flattened, lhs, rhs, safe_width);
let e = self.le_check(
statements_flattened,
lhs,
rhs,
bitwidth,
RuntimeError::LtSum,
);
statements_flattened.push_back(FlatStatement::condition(
e,
FlatExpression::value(T::one()),
Expand Down
57 changes: 57 additions & 0 deletions zokrates_core_test/tests/tests/range_check/assert_le_u8.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
{
"entry_point": "./tests/tests/range_check/assert_le_u8.zok",
"max_constraint_count": 12,
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": ["0x00"]
},
"output": {
"Ok": {
"value": []
}
}
},
{
"input": {
"values": ["0x01"]
},
"output": {
"Ok": {
"value": []
}
}
},
{
"input": {
"values": ["0x02"]
},
"output": {
"Ok": {
"value": []
}
}
},
{
"input": {
"values": ["0x0f"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"error": {
"SourceAssertion": {
"file": "./tests/tests/range_check/assert_le_u8.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
}
]
}
3 changes: 3 additions & 0 deletions zokrates_core_test/tests/tests/range_check/assert_le_u8.zok
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def main(u8 x) {
assert(x <= 2);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
{
"entry_point": "./tests/tests/range_check/assert_le_u8_dynamic.zok",
"max_constraint_count": 31,
"curves": ["Bn128"],
"tests": [
{
"input": {
"values": ["0x01", "0x02"]
},
"output": {
"Ok": {
"value": []
}
}
},
{
"input": {
"values": ["0x02", "0x02"]
},
"output": {
"Ok": {
"value": []
}
}
},
{
"input": {
"values": ["0x04", "0x02"]
},
"output": {
"Err": {
"UnsatisfiedConstraint": {
"error": {
"SourceAssertion": {
"file": "./tests/tests/range_check/assert_le_u8_dynamic.zok",
"position": {
"line": 2,
"col": 5
}
}
}
}
}
}
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
def main(u8 a, u8 b) {
assert(a <= b);
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"entry_point": "./tests/tests/range_check/assert_lt_u8.zok",
"max_constraint_count": 9,
"max_constraint_count": 10,
"curves": ["Bn128"],
"tests": [
{
Expand Down
3 changes: 1 addition & 2 deletions zokrates_core_test/tests/tests/range_check/assert_lt_u8.zok
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
def main(field x) {
def main(u8 x) {
assert(x < 2);
return;
}
Loading

0 comments on commit 53f2a7d

Please sign in to comment.