Skip to content

Commit

Permalink
Fix hax extraction
Browse files Browse the repository at this point in the history
  • Loading branch information
jschneider-bensch committed Dec 18, 2024
1 parent 6f3e276 commit 00424f6
Show file tree
Hide file tree
Showing 5 changed files with 521 additions and 270 deletions.
87 changes: 50 additions & 37 deletions libcrux-ml-dsa/proofs/fstar/extraction/Libcrux_ml_dsa.Sample.fst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,34 @@ let update_seed (seed: t_Array u8 (sz 66)) (domain_separator: u16) =
let hax_temp_output:t_Array u8 (sz 66) = seed in
domain_separator, hax_temp_output <: (u16 & t_Array u8 (sz 66))

let update_matrix
(#v_SIMDUnit: Type0)
(v_ROWS_IN_A v_COLUMNS_IN_A: usize)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
i1:
Libcrux_ml_dsa.Simd.Traits.t_Operations v_SIMDUnit)
(m:
t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A)
(i j: usize)
(v: Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit)
=
let m:t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize m
i
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (m.[ i ]
<:
t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
j
v
<:
t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
in
m

let rejection_sample_less_than_eta_equals_2_
(#v_SIMDUnit: Type0)
(#[FStar.Tactics.Typeclasses.tcresolve ()]
Expand Down Expand Up @@ -976,7 +1004,7 @@ let sample_up_to_four_ring_elements
t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A)
(rand_stack: t_Array (t_Array u8 (sz 840)) (sz 4))
(rand_stack0 rand_stack1 rand_stack2 rand_stack3: t_Array u8 (sz 840))
(tmp_stack: t_Slice (t_Array i32 (sz 263)))
(indices: t_Array (u8 & u8) (sz 4))
(elements_requested: usize)
Expand Down Expand Up @@ -1042,10 +1070,6 @@ let sample_up_to_four_ring_elements
(seed2 <: t_Slice u8)
(seed3 <: t_Slice u8)
in
let rand_stack0:t_Array u8 (sz 840) = rand_stack.[ sz 0 ] in
let rand_stack1:t_Array u8 (sz 840) = rand_stack.[ sz 1 ] in
let rand_stack2:t_Array u8 (sz 840) = rand_stack.[ sz 2 ] in
let rand_stack3:t_Array u8 (sz 840) = rand_stack.[ sz 3 ] in
let tmp0, tmp1, tmp2, tmp3, tmp4:(v_Shake128 & t_Array u8 (sz 840) & t_Array u8 (sz 840) &
t_Array u8 (sz 840) &
t_Array u8 (sz 840)) =
Expand All @@ -1067,11 +1091,9 @@ let sample_up_to_four_ring_elements
let sampled1:usize = sz 0 in
let sampled2:usize = sz 0 in
let sampled3:usize = sz 0 in
let tmp0, out:(t_Array u8 (sz 840) & t_Array u8 (sz 840)) = rand_stack0 in
let rand_stack0:t_Array u8 (sz 840) = tmp0 in
let tmp0, tmp1, out:(usize & t_Array i32 (sz 263) & bool) =
rejection_sample_less_than_field_modulus #v_SIMDUnit
(out <: t_Slice u8)
(rand_stack0 <: t_Slice u8)
sampled0
(tmp_stack.[ sz 0 ] <: t_Array i32 (sz 263))
in
Expand All @@ -1080,11 +1102,9 @@ let sample_up_to_four_ring_elements
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize tmp_stack (sz 0) tmp1
in
let done0:bool = out in
let tmp0, out:(t_Array u8 (sz 840) & t_Array u8 (sz 840)) = rand_stack1 in
let rand_stack1:t_Array u8 (sz 840) = tmp0 in
let tmp0, tmp1, out:(usize & t_Array i32 (sz 263) & bool) =
rejection_sample_less_than_field_modulus #v_SIMDUnit
(out <: t_Slice u8)
(rand_stack1 <: t_Slice u8)
sampled1
(tmp_stack.[ sz 1 ] <: t_Array i32 (sz 263))
in
Expand All @@ -1093,11 +1113,9 @@ let sample_up_to_four_ring_elements
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize tmp_stack (sz 1) tmp1
in
let done1:bool = out in
let tmp0, out:(t_Array u8 (sz 840) & t_Array u8 (sz 840)) = rand_stack2 in
let rand_stack2:t_Array u8 (sz 840) = tmp0 in
let tmp0, tmp1, out:(usize & t_Array i32 (sz 263) & bool) =
rejection_sample_less_than_field_modulus #v_SIMDUnit
(out <: t_Slice u8)
(rand_stack2 <: t_Slice u8)
sampled2
(tmp_stack.[ sz 2 ] <: t_Array i32 (sz 263))
in
Expand All @@ -1106,11 +1124,9 @@ let sample_up_to_four_ring_elements
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize tmp_stack (sz 2) tmp1
in
let done2:bool = out in
let tmp0, out:(t_Array u8 (sz 840) & t_Array u8 (sz 840)) = rand_stack3 in
let rand_stack3:t_Array u8 (sz 840) = tmp0 in
let tmp0, tmp1, out:(usize & t_Array i32 (sz 263) & bool) =
rejection_sample_less_than_field_modulus #v_SIMDUnit
(out <: t_Slice u8)
(rand_stack3 <: t_Slice u8)
sampled3
(tmp_stack.[ sz 3 ] <: t_Array i32 (sz 263))
in
Expand Down Expand Up @@ -1246,10 +1262,9 @@ let sample_up_to_four_ring_elements
(bool & bool & bool & bool & usize & usize & usize & usize & v_Shake128 &
t_Slice (t_Array i32 (sz 263))))
in
let matrix, hax_temp_output:(t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A &
Prims.unit) =
let matrix:t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A =
Rust_primitives.Hax.Folds.fold_range (sz 0)
elements_requested
(fun matrix temp_1_ ->
Expand All @@ -1272,28 +1287,26 @@ let sample_up_to_four_ring_elements
let matrix:t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize matrix
update_matrix #v_SIMDUnit
v_ROWS_IN_A
v_COLUMNS_IN_A
matrix
(cast (i <: u8) <: usize)
(Rust_primitives.Hax.Monomorphized_update_at.update_at_usize (matrix.[ cast (i <: u8)
<:
usize ]
<:
t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit)
v_COLUMNS_IN_A)
(cast (j <: u8) <: usize)
(Libcrux_ml_dsa.Polynomial.impl__from_i32_array #v_SIMDUnit
(tmp_stack.[ k ] <: t_Slice i32)
<:
Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit)
(cast (j <: u8) <: usize)
(Libcrux_ml_dsa.Polynomial.impl__from_i32_array #v_SIMDUnit
(tmp_stack.[ k ] <: t_Slice i32)
<:
t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit)
v_COLUMNS_IN_A)
Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit)
in
matrix)
in
matrix, rand_stack, tmp_stack
let hax_temp_output:Prims.unit = () <: Prims.unit in
matrix, rand_stack0, rand_stack1, rand_stack2, rand_stack3, tmp_stack
<:
(t_Array (t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A &
t_Array (t_Array u8 (sz 840)) (sz 4) &
t_Array u8 (sz 840) &
t_Array u8 (sz 840) &
t_Array u8 (sz 840) &
t_Array u8 (sz 840) &
t_Slice (t_Array i32 (sz 263)))
22 changes: 20 additions & 2 deletions libcrux-ml-dsa/proofs/fstar/extraction/Libcrux_ml_dsa.Sample.fsti
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@ val generate_domain_separator: (u8 & u8) -> Prims.Pure u16 Prims.l_True (fun _ -
val update_seed (seed: t_Array u8 (sz 66)) (domain_separator: u16)
: Prims.Pure (u16 & t_Array u8 (sz 66)) Prims.l_True (fun _ -> Prims.l_True)

val update_matrix
(#v_SIMDUnit: Type0)
(v_ROWS_IN_A v_COLUMNS_IN_A: usize)
{| i1: Libcrux_ml_dsa.Simd.Traits.t_Operations v_SIMDUnit |}
(m:
t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A)
(i j: usize)
(v: Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit)
: Prims.Pure
(t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A) Prims.l_True (fun _ -> Prims.l_True)

val rejection_sample_less_than_eta_equals_2_
(#v_SIMDUnit: Type0)
{| i1: Libcrux_ml_dsa.Simd.Traits.t_Operations v_SIMDUnit |}
Expand Down Expand Up @@ -122,13 +137,16 @@ val sample_up_to_four_ring_elements
t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A)
(rand_stack: t_Array (t_Array u8 (sz 840)) (sz 4))
(rand_stack0 rand_stack1 rand_stack2 rand_stack3: t_Array u8 (sz 840))
(tmp_stack: t_Slice (t_Array i32 (sz 263)))
(indices: t_Array (u8 & u8) (sz 4))
(elements_requested: usize)
: Prims.Pure
(t_Array
(t_Array (Libcrux_ml_dsa.Polynomial.t_PolynomialRingElement v_SIMDUnit) v_COLUMNS_IN_A)
v_ROWS_IN_A &
t_Array (t_Array u8 (sz 840)) (sz 4) &
t_Array u8 (sz 840) &
t_Array u8 (sz 840) &
t_Array u8 (sz 840) &
t_Array u8 (sz 840) &
t_Slice (t_Array i32 (sz 263))) Prims.l_True (fun _ -> Prims.l_True)
Loading

0 comments on commit 00424f6

Please sign in to comment.