Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Kyber512 and Kyber1024 reference implementation #109

Merged
merged 19 commits into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .github/workflows/hax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,15 @@ jobs:
# the function-extraction code.
# Extract functions from the remaining modules to test the
# module-extraction code.
./extract_to_fstar.py --crate-path specs/kyber --functions compress::compress compress::decompress compress::compress_d compress::decompress_d --modules ind_cpa root matrix ntt parameters sampling serialize
./extract_to_fstar.py --crate-path specs/kyber \
--functions hacspec_kyber::compress::compress \
hacspec_kyber::compress::decompress \
hacspec_kyber::compress::compress_d \
hacspec::kyber::compress::decompress_d \
--modules ind_cpa \
hacspec_kyber \
matrix ntt \
parameters \
sampling \
serialize \
--exclude-modules libcrux::hacl::sha3 libcrux::digest
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
.vscode
.DS_Store
benches/boringssl/build
__pycache__
50 changes: 30 additions & 20 deletions extract_to_fstar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys


def shell(command, expect=0, cwd=None, format_selection_string=False):
def shell(command, expect=0, cwd=None, format_filter_string=False):
subprocess_stdout = subprocess.DEVNULL

print("Command: ", end="")
Expand Down Expand Up @@ -45,7 +45,7 @@ def shell(command, expect=0, cwd=None, format_selection_string=False):
nargs="*",
dest="functions",
default="",
help="Space-separated list of functions to extract. The functions must be fully qualified from the crate root.",
help="Space-separated list of functions to extract. The function names must be fully qualified.",
)

parser.add_argument(
Expand All @@ -54,37 +54,47 @@ def shell(command, expect=0, cwd=None, format_selection_string=False):
dest="modules",
nargs="*",
default="",
help='Space-separated list of modules to extract. The modules must be fully qualified from the crate root. The special argument"root" can be used to extract the lib.rs file.',
help="Space-separated list of modules to extract. The module names must be fully qualified.",
)
parser.add_argument(
"--exclude-modules",
type=str,
dest="exclude_modules",
nargs="*",
default="",
help="Space-separated list of modules to exclude from extraction. The module names must be fully qualified.",
)

options = parser.parse_args()

if options.modules or options.functions:
if options.modules:
options.modules = " ".join(
[
"+::*" if module == "root" else "+" + module + "::*"
for module in options.modules
]
)
options.modules = " {}".format(options.modules)
filter_string = ""

if options.modules:
options.modules = " ".join(["+" + module + "::*" for module in options.modules])
filter_string += "{}".format(options.modules)
if options.functions:
options.functions = " ".join(["+" + function for function in options.functions])
filter_string += " {}".format(options.functions)

if options.exclude_modules:
options.exclude_modules = " ".join(
["-" + module + "::*" for module in options.exclude_modules]
)
filter_string += " {}".format(options.exclude_modules)

if options.functions:
options.functions = " ".join(
["+" + function for function in options.functions])
options.functions = " {}".format(options.functions)

if filter_string:
shell(
[
"cargo",
"hax",
"into",
"-i",
"-**{}{}".format(options.functions, options.modules),
"-** {}".format(filter_string),
"fstar",
],
cwd=options.crate_path,
format_selection_string=True,
format_filter_string=True,
)
elif options.kyber_reference:
shell(
Expand All @@ -93,11 +103,11 @@ def shell(command, expect=0, cwd=None, format_selection_string=False):
"hax",
"into",
"-i",
"-** +kem::kyber::** -kem::kyber::arithmetic::mutable_operations::**",
"-** +libcrux::kem::kyber::** -libcrux::kem::kyber::arithmetic::mutable_operations::** -libcrux::hacl::sha3::** -libcrux::digest::**",
"fstar",
],
cwd=".",
format_selection_string=True,
format_filter_string=True,
)
else:
shell(["cargo", "hax", "into", "fstar"], cwd=options.crate_path)
4 changes: 2 additions & 2 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Arithmetic.fst
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ let impl_3: Core.Ops.Arith.t_Add t_KyberPolynomialRingElement t_KyberPolynomialR
Core.Ops.Range.f_end = Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT
})
<:
_.f_IntoIter)
(Core.Iter.Traits.Collect.impl (Core.Ops.Range.t_Range usize)).f_IntoIter)
result
(fun result i ->
{
Expand Down Expand Up @@ -94,7 +94,7 @@ let impl_4: Core.Ops.Arith.t_Sub t_KyberPolynomialRingElement t_KyberPolynomialR
Core.Ops.Range.f_end = Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT
})
<:
_.f_IntoIter)
(Core.Iter.Traits.Collect.impl (Core.Ops.Range.t_Range usize)).f_IntoIter)
result
(fun result i ->
{
Expand Down
20 changes: 9 additions & 11 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Compress.fst
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@ let compress
re

let decompress
(#v_COEFFICIENT_BITS: usize)
(re: Libcrux.Kem.Kyber.Arithmetic.t_KyberPolynomialRingElement)
(bits_per_compressed_coefficient: usize)
: Libcrux.Kem.Kyber.Arithmetic.t_KyberPolynomialRingElement =
let re:Libcrux.Kem.Kyber.Arithmetic.t_KyberPolynomialRingElement =
{
re with
Libcrux.Kem.Kyber.Arithmetic.f_coefficients
=
Core.Array.impl_23__map re.Libcrux.Kem.Kyber.Arithmetic.f_coefficients
(fun coefficient -> decompress_q coefficient bits_per_compressed_coefficient <: i32)
(fun coefficient -> decompress_q coefficient <: i32)
}
in
re
Expand All @@ -51,22 +51,20 @@ let compress_q (#v_COEFFICIENT_BITS: usize) (fe: u16) : i32 =
in
let two_pow_bit_size:u32 = 1ul <<! v_COEFFICIENT_BITS in
let compressed:u32 = (cast fe <: u32) *! (two_pow_bit_size <<! 1l <: u32) in
let compressed:Prims.unit =
compressed +! (cast Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <: u32)
in
let compressed:Prims.unit =
let compressed:u32 = compressed +! (cast Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <: u32) in
let compressed:u32 =
compressed /! (cast (Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <<! 1l <: i32) <: u32)
in
cast (compressed &. (two_pow_bit_size -! 1ul <: u32)) <: i32

let decompress_q (fe: i32) (to_bit_size: usize) : i32 =
let decompress_q (#v_COEFFICIENT_BITS: usize) (fe: i32) : i32 =
let _:Prims.unit =
if true
then
let _:Prims.unit =
if ~.(to_bit_size <=. Libcrux.Kem.Kyber.Constants.v_BITS_PER_COEFFICIENT <: bool)
if ~.(v_COEFFICIENT_BITS <=. Libcrux.Kem.Kyber.Constants.v_BITS_PER_COEFFICIENT <: bool)
then
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: to_bit_size <= BITS_PER_COEFFICIENT"
Rust_primitives.Hax.never_to_any (Core.Panicking.panic "assertion failed: COEFFICIENT_BITS <= BITS_PER_COEFFICIENT"

<:
Rust_primitives.Hax.t_Never)
Expand All @@ -76,6 +74,6 @@ let decompress_q (fe: i32) (to_bit_size: usize) : i32 =
let decompressed:u32 =
(cast fe <: u32) *! (cast Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <: u32)
in
let decompressed:u32 = (decompressed <<! 1l <: u32) +! (1ul <<! to_bit_size <: u32) in
let decompressed:Prims.unit = decompressed >>! (to_bit_size +! sz 1 <: usize) in
let decompressed:u32 = (decompressed <<! 1l <: u32) +! (1ul <<! v_COEFFICIENT_BITS <: u32) in
let decompressed:u32 = decompressed >>! (v_COEFFICIENT_BITS +! sz 1 <: usize) in
cast decompressed <: i32
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ let compare_ciphertexts_in_constant_time (#v_CIPHERTEXT_SIZE: usize) (lhs rhs: s
()
in
let (r: u8):u8 = 0uy in
let r:Prims.unit =
let r:u8 =
Core.Iter.Traits.Iterator.Iterator.fold (Core.Iter.Traits.Collect.f_into_iter ({
Core.Ops.Range.f_start = sz 0;
Core.Ops.Range.f_end = v_CIPHERTEXT_SIZE
})
<:
_.f_IntoIter)
(Core.Iter.Traits.Collect.impl (Core.Ops.Range.t_Range usize)).f_IntoIter)
r
(fun r i -> r |. ((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8) <: Prims.unit)
(fun r i -> r |. ((lhs.[ i ] <: u8) ^. (rhs.[ i ] <: u8) <: u8) <: u8)
in
is_non_zero r

Expand Down Expand Up @@ -101,7 +101,7 @@ let select_shared_secret_in_constant_time (lhs rhs: slice u8) (selector: u8) : a
Core.Ops.Range.f_end = Libcrux.Kem.Kyber.Constants.v_SHARED_SECRET_SIZE
})
<:
_.f_IntoIter)
(Core.Iter.Traits.Collect.impl (Core.Ops.Range.t_Range usize)).f_IntoIter)
out
(fun out i ->
Rust_primitives.Hax.update_at out
Expand All @@ -111,7 +111,7 @@ let select_shared_secret_in_constant_time (lhs rhs: slice u8) (selector: u8) : a
<:
u8)
<:
Prims.unit)
u8)
<:
array u8 (sz 32))
in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ let into_padded_array (#v_LEN: usize) (slice: slice u8) : array u8 v_LEN =
out

class t_UpdatingArray (#v_Self: Type) = {
[@@@ FStar.Tactics.Typeclasses.no_method]_super_447510783:t_UpdatingArray v_Self;
[@@@ FStar.Tactics.Typeclasses.no_method]_super_509883233:t_UpdatingArray v_Self;
f_push:v_Self -> slice u8 -> v_Self
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ let v_XOFx4 (#v_LEN #v_K: usize) (input: array (array u8 (sz 34)) v_K) : array (
Core.Ops.Range.f_end = v_K
})
<:
_.f_IntoIter)
(Core.Iter.Traits.Collect.impl (Core.Ops.Range.t_Range usize)).f_IntoIter)
out
(fun out i ->
Rust_primitives.Hax.update_at out
Expand Down
Loading
Loading