Skip to content

Commit

Permalink
feat(hints): Implement hint on sub_reduced_a_and_reduced_b (#1090)
Browse files Browse the repository at this point in the history
* Remove wrong hint

* Implement hint SUB_REDUCED_A_AND_REDUCED_B

* Test hint

* Add changelog entry

* Prevent sub overflow

* Remove unused import

* Fix memory hole count

* Add suggestion + unit tests

* Fix memory hole count

* Fix benchmark file

* fix typo
  • Loading branch information
fmoletta authored Apr 28, 2023
1 parent 4bc7c54 commit eaea41f
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 342 deletions.
65 changes: 32 additions & 33 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,38 @@

#### Upcoming Changes

* Implement hint on field_arithmetic lib [#1090](https://github.com/lambdaclass/cairo-rs/pull/1090)

`BuiltinHintProcessor` now supports the following hints:

```python
%{
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)

def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

a = pack(ids.a, num_bits_shift = 128)
b = pack(ids.b, num_bits_shift = 128)
p = pack(ids.p, num_bits_shift = 128)

res = (a - b) % p


res_split = split(res, num_bits_shift=128, length=3)

ids.res.d0 = res_split[0]
ids.res.d1 = res_split[1]
ids.res.d2 = res_split[2]
%}
```

* Add missing hint on cairo_secp lib [#1089](https://github.com/lambdaclass/cairo-rs/pull/1089):
`BuiltinHintProcessor` now supports the following hint:

Expand Down Expand Up @@ -1203,39 +1235,6 @@
ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0
```

```python
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)

def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

def pack2(z, num_bits_shift: int) -> int:
limbs = (z.b01, z.b23, z.b45)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))

a = pack(ids.a, num_bits_shift = 128)
div = pack2(ids.div, num_bits_shift = 128)
quotient, remainder = divmod(a, div)

quotient_split = split(quotient, num_bits_shift=128, length=3)
assert len(quotient_split) == 3

ids.quotient.d0 = quotient_split[0]
ids.quotient.d1 = quotient_split[1]
ids.quotient.d2 = quotient_split[2]

remainder_split = split(remainder, num_bits_shift=128, length=3)
ids.remainder.d0 = remainder_split[0]
ids.remainder.d1 = remainder_split[1]
ids.remainder.d2 = remainder_split[2]
```

```python
from starkware.python.math_utils import isqrt

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin
from starkware.cairo.common.bool import TRUE
from cairo_programs.uint384 import u384, Uint384, Uint384_expand
from cairo_programs.uint384 import u384, Uint384
from cairo_programs.uint384_extension import u384_ext
from cairo_programs.field_arithmetic import field_arithmetic

Expand Down
67 changes: 67 additions & 0 deletions cairo_programs/field_arithmetic.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,17 @@ from cairo_programs.uint384_extension import u384_ext, Uint768

// Functions for operating elements in a finite field F_p (i.e. modulo a prime p), with p of at most 384 bits
namespace field_arithmetic {
// Computes (a + b) modulo p .
func add{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (res: Uint384) {
let (sum: Uint384, carry) = u384.add(a, b);
let sum_with_carry: Uint768 = Uint768(sum.d0, sum.d1, sum.d2, carry, 0, 0);

let (
quotient: Uint768, remainder: Uint384
) = u384_ext.unsigned_div_rem_uint768_by_uint384(sum_with_carry, p);
return (remainder,);
}

// Computes a * b modulo p
func mul{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (res: Uint384) {
let (low: Uint384, high: Uint384) = u384.mul_d(a, b);
Expand Down Expand Up @@ -267,6 +278,49 @@ namespace field_arithmetic {
let (res: Uint384) = mul(a, b_inverse_mod_p, p);
return (res,);
}

// Computes (a - b) modulo p .
// NOTE: Expects a and b to be reduced modulo p (i.e. between 0 and p-1). The function will revert if a > p.
// NOTE: To reduce a, take the remainder of uint384_lin.unsigned_div_rem(a, p), and similarly for b.
// @dev First it computes res =(a-b) mod p in a hint and then checks outside of the hint that res + b = a modulo p
func sub_reduced_a_and_reduced_b{range_check_ptr}(a: Uint384, b: Uint384, p: Uint384) -> (
res: Uint384
) {
alloc_locals;
local res: Uint384;
%{
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)
def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
a = pack(ids.a, num_bits_shift = 128)
b = pack(ids.b, num_bits_shift = 128)
p = pack(ids.p, num_bits_shift = 128)
res = (a - b) % p
res_split = split(res, num_bits_shift=128, length=3)
ids.res.d0 = res_split[0]
ids.res.d1 = res_split[1]
ids.res.d2 = res_split[2]
%}
u384.check(res);
let (is_valid) = u384.lt(res, p);
assert is_valid = 1;
let (b_plus_res) = add(b, res, p);
assert b_plus_res = a;
return (res,);
}

}

func test_field_arithmetics_extension_operations{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}() {
Expand Down Expand Up @@ -365,9 +419,22 @@ func test_u256_get_square_root{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}()
return ();
}

func test_sub_reduced_a_and_reduced_b{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(){
let a = Uint384(1, 1, 1);
let b = Uint384(2, 2, 2);
let p = Uint384(7, 7, 7);
let (r) = field_arithmetic.sub_reduced_a_and_reduced_b(a, b, p);
assert r.d0 = 6;
assert r.d1 = 6;
assert r.d2 = 6;

return ();
}

func main{range_check_ptr: felt, bitwise_ptr: BitwiseBuiltin*}() {
test_field_arithmetics_extension_operations();
test_u256_get_square_root();
test_sub_reduced_a_and_reduced_b();

return ();
}
56 changes: 0 additions & 56 deletions cairo_programs/uint384.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -358,62 +358,6 @@ namespace u384 {
return (quotient=quotient, remainder=remainder);
}

// Unsigned integer division between two integers. Returns the quotient and the remainder.
func unsigned_div_rem_expanded{range_check_ptr}(a: Uint384, div: Uint384_expand) -> (
quotient: Uint384, remainder: Uint384
) {
alloc_locals;
local quotient: Uint384;
local remainder: Uint384;

let div2 = Uint384(div.b01, div.b23, div.b45);

%{
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
num = num >> num_bits_shift
return tuple(a)
def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
def pack2(z, num_bits_shift: int) -> int:
limbs = (z.b01, z.b23, z.b45)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
a = pack(ids.a, num_bits_shift = 128)
div = pack2(ids.div, num_bits_shift = 128)
quotient, remainder = divmod(a, div)
quotient_split = split(quotient, num_bits_shift=128, length=3)
assert len(quotient_split) == 3
ids.quotient.d0 = quotient_split[0]
ids.quotient.d1 = quotient_split[1]
ids.quotient.d2 = quotient_split[2]
remainder_split = split(remainder, num_bits_shift=128, length=3)
ids.remainder.d0 = remainder_split[0]
ids.remainder.d1 = remainder_split[1]
ids.remainder.d2 = remainder_split[2]
%}
check(quotient);
check(remainder);
let (res_mul: Uint384, carry: Uint384) = mul_expanded(quotient, div);
assert carry = Uint384(0, 0, 0);

let (check_val: Uint384, add_carry: felt) = _add_no_uint384_check(res_mul, remainder);
assert check_val = a;
assert add_carry = 0;

let (is_valid) = lt(remainder, div2);
assert is_valid = 1;
return (quotient=quotient, remainder=remainder);
}

func square_e{range_check_ptr}(a: Uint384) -> (low: Uint384, high: Uint384) {
alloc_locals;
let (a0, a1) = split_64(a.d0);
Expand Down
18 changes: 1 addition & 17 deletions cairo_programs/uint384_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,6 @@ func test_uint384_operations{range_check_ptr}() {
assert sum_res.d2 = 7;
assert carry = 1;

// Test unsigned_div_rem_expanded
let e = Uint384(83434123481193248, 82349321849739284, 839243219401320423);
let div_expand = Uint384_expand(
9283430921839492319493, 313248123482483248, 3790328402913840, 13, 78990, 109, 7
);
let (quotient: Uint384, remainder: Uint384) = u384.unsigned_div_rem_expanded{
range_check_ptr=range_check_ptr
}(a, div_expand);
assert quotient.d0 = 7699479077076334;
assert quotient.d1 = 0;
assert quotient.d2 = 0;

assert remainder.d0 = 340279955073565776659831804641277151872;
assert remainder.d1 = 340282366920938463463356863525615958397;
assert remainder.d2 = 16;

// Test sqrt
let f = Uint384(83434123481193248, 82349321849739284, 839243219401320423);
let (root) = u384.sqrt(f);
Expand All @@ -65,7 +49,7 @@ func test_uint384_operations{range_check_ptr}() {
let (sign_h) = u384.signed_nn(h);
assert sign_h = 0;

return ();
return();
}

func main{range_check_ptr: felt}() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use super::{
},
secp_utils::{ALPHA, ALPHA_V2, SECP_P, SECP_P_V2},
},
uint384::sub_reduced_a_and_reduced_b,
vrf::{
fq::{inv_mod_p_uint256, uint512_unsigned_div_rem},
inv_mod_p_uint512::inv_mod_p_uint512,
Expand Down Expand Up @@ -89,7 +90,7 @@ use crate::{
},
uint384::{
add_no_uint384_check, uint384_signed_nn, uint384_split_128, uint384_sqrt,
uint384_unsigned_div_rem, uint384_unsigned_div_rem_expanded,
uint384_unsigned_div_rem,
},
uint384_extension::unsigned_div_rem_uint768_by_uint384,
usort::{
Expand Down Expand Up @@ -721,16 +722,16 @@ impl HintProcessor for BuiltinHintProcessor {
hint_code::ADD_NO_UINT384_CHECK => {
add_no_uint384_check(vm, &hint_data.ids_data, &hint_data.ap_tracking, constants)
}
hint_code::UINT384_UNSIGNED_DIV_REM_EXPANDED => {
uint384_unsigned_div_rem_expanded(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::UINT384_SQRT => {
uint384_sqrt(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::UNSIGNED_DIV_REM_UINT768_BY_UINT384
| hint_code::UNSIGNED_DIV_REM_UINT768_BY_UINT384_STRIPPED => {
unsigned_div_rem_uint768_by_uint384(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::SUB_REDUCED_A_AND_REDUCED_B => {
sub_reduced_a_and_reduced_b(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
hint_code::UINT384_GET_SQUARE_ROOT => {
u384_get_square_root(vm, &hint_data.ids_data, &hint_data.ap_tracking)
}
Expand Down
54 changes: 24 additions & 30 deletions src/hint_processor/builtin_hint_processor/hint_code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1017,8 +1017,9 @@ sum_d1 = ids.a.d1 + ids.b.d1 + ids.carry_d0
ids.carry_d1 = 1 if sum_d1 >= ids.SHIFT else 0
sum_d2 = ids.a.d2 + ids.b.d2 + ids.carry_d1
ids.carry_d2 = 1 if sum_d2 >= ids.SHIFT else 0";
pub const UINT384_UNSIGNED_DIV_REM_EXPANDED: &str =
"def split(num: int, num_bits_shift: int, length: int):
pub const UINT384_SQRT: &str = "from starkware.python.math_utils import isqrt
def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
Expand All @@ -1029,28 +1030,16 @@ def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
def pack2(z, num_bits_shift: int) -> int:
limbs = (z.b01, z.b23, z.b45)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
a = pack(ids.a, num_bits_shift = 128)
div = pack2(ids.div, num_bits_shift = 128)
quotient, remainder = divmod(a, div)
quotient_split = split(quotient, num_bits_shift=128, length=3)
assert len(quotient_split) == 3
ids.quotient.d0 = quotient_split[0]
ids.quotient.d1 = quotient_split[1]
ids.quotient.d2 = quotient_split[2]
remainder_split = split(remainder, num_bits_shift=128, length=3)
ids.remainder.d0 = remainder_split[0]
ids.remainder.d1 = remainder_split[1]
ids.remainder.d2 = remainder_split[2]";
pub const UINT384_SQRT: &str = "from starkware.python.math_utils import isqrt
a = pack(ids.a, num_bits_shift=128)
root = isqrt(a)
assert 0 <= root < 2 ** 192
root_split = split(root, num_bits_shift=128, length=3)
ids.root.d0 = root_split[0]
ids.root.d1 = root_split[1]
ids.root.d2 = root_split[2]";

def split(num: int, num_bits_shift: int, length: int):
pub const SUB_REDUCED_A_AND_REDUCED_B: &str =
"def split(num: int, num_bits_shift: int, length: int):
a = []
for _ in range(length):
a.append( num & ((1 << num_bits_shift) - 1) )
Expand All @@ -1061,13 +1050,18 @@ def pack(z, num_bits_shift: int) -> int:
limbs = (z.d0, z.d1, z.d2)
return sum(limb << (num_bits_shift * i) for i, limb in enumerate(limbs))
a = pack(ids.a, num_bits_shift=128)
root = isqrt(a)
assert 0 <= root < 2 ** 192
root_split = split(root, num_bits_shift=128, length=3)
ids.root.d0 = root_split[0]
ids.root.d1 = root_split[1]
ids.root.d2 = root_split[2]";
a = pack(ids.a, num_bits_shift = 128)
b = pack(ids.b, num_bits_shift = 128)
p = pack(ids.p, num_bits_shift = 128)
res = (a - b) % p
res_split = split(res, num_bits_shift=128, length=3)
ids.res.d0 = res_split[0]
ids.res.d1 = res_split[1]
ids.res.d2 = res_split[2]";

pub const UNSIGNED_DIV_REM_UINT768_BY_UINT384: &str =
"def split(num: int, num_bits_shift: int, length: int):
Expand Down
Loading

1 comment on commit eaea41f

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Performance Alert ⚠️

Possible performance regression was detected for benchmark.
Benchmark result of this commit is worse than the previous benchmark result exceeding threshold 1.30.

Benchmark suite Current: eaea41f Previous: 4bc7c54 Ratio
parse program 26627444 ns/iter (± 1386493) 18555387 ns/iter (± 324640) 1.44
build runner 3904587 ns/iter (± 145666) 2445689 ns/iter (± 1084) 1.60

This comment was automatically generated by workflow using github-action-benchmark.

CC: @unbalancedparentheses

Please sign in to comment.