Skip to content

Commit

Permalink
Merge pull request #699 from cryspen/jonas/ml-dsa-crutch
Browse files Browse the repository at this point in the history
Temporary fixes for performance/build times/stack usage for ML-DSA AVX2
  • Loading branch information
franziskuskiefer authored Nov 28, 2024
2 parents d5e4a0f + f9097bb commit 77e9464
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 56 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ lto = "fat"
codegen-units = 1
panic = "abort"

[profile.dev.package."libcrux-ml-dsa"]
opt-level = 1

[lints.rust]
unexpected_cfgs = { level = "warn", check-cfg = [
'cfg(hax)',
Expand Down
15 changes: 6 additions & 9 deletions libcrux-ml-dsa/src/ml_dsa_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ pub(crate) fn generate_key_pair<
let (seed_for_error_vectors, seed_for_signing) =
seed_expanded.split_at(SEED_FOR_ERROR_VECTORS_SIZE);

let a_as_ntt = samplex4::matrix_A::<SIMDUnit, Shake128X4, ROWS_IN_A, COLUMNS_IN_A>(
into_padded_array(seed_for_a),
);
let a_as_ntt =
samplex4::matrix_A::<SIMDUnit, ROWS_IN_A, COLUMNS_IN_A>(into_padded_array(seed_for_a));

let (s1, s2) = samplex4::sample_s1_and_s2::<SIMDUnit, Shake256X4, ETA, COLUMNS_IN_A, ROWS_IN_A>(
into_padded_array(seed_for_error_vectors),
Expand Down Expand Up @@ -246,9 +245,8 @@ pub(crate) fn sign_internal<
SIGNING_KEY_SIZE,
>(signing_key);

let A_as_ntt = samplex4::matrix_A::<SIMDUnit, Shake128X4, ROWS_IN_A, COLUMNS_IN_A>(
into_padded_array(&seed_for_A),
);
let A_as_ntt =
samplex4::matrix_A::<SIMDUnit, ROWS_IN_A, COLUMNS_IN_A>(into_padded_array(&seed_for_A));

let mut message_representative = [0; MESSAGE_REPRESENTATIVE_SIZE];
derive_message_representative(
Expand Down Expand Up @@ -492,9 +490,8 @@ pub(crate) fn verify_internal<
signature.signer_response,
(2 << GAMMA1_EXPONENT) - BETA,
) {
let A_as_ntt = samplex4::matrix_A::<SIMDUnit, Shake128X4, ROWS_IN_A, COLUMNS_IN_A>(
into_padded_array(&seed_for_A),
);
let A_as_ntt =
samplex4::matrix_A::<SIMDUnit, ROWS_IN_A, COLUMNS_IN_A>(into_padded_array(&seed_for_A));

let mut verification_key_hash = [0; BYTES_FOR_VERIFICATION_KEY_HASH];
Shake256::shake256::<BYTES_FOR_VERIFICATION_KEY_HASH>(
Expand Down
21 changes: 13 additions & 8 deletions libcrux-ml-dsa/src/sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ fn rejection_sample_less_than_field_modulus<SIMDUnit: Operations>(
done
}

#[inline(always)]
pub(crate) fn sample_four_ring_elements<SIMDUnit: Operations, Shake128: shake128::XofX4>(
pub(crate) fn sample_four_ring_elements<SIMDUnit: Operations>(
mut seed0: [u8; 34],
domain_separator0: u16,
domain_separator1: u16,
Expand All @@ -44,6 +43,8 @@ pub(crate) fn sample_four_ring_elements<SIMDUnit: Operations, Shake128: shake128
PolynomialRingElement<SIMDUnit>,
PolynomialRingElement<SIMDUnit>,
) {
use crate::hash_functions::shake128::XofX4;

// Prepare the seeds
seed0[32] = domain_separator0 as u8;
seed0[33] = (domain_separator0 >> 8) as u8;
Expand All @@ -60,7 +61,12 @@ pub(crate) fn sample_four_ring_elements<SIMDUnit: Operations, Shake128: shake128
seed3[32] = domain_separator3 as u8;
seed3[33] = (domain_separator3 >> 8) as u8;

let mut state = Shake128::init_absorb(&seed0, &seed1, &seed2, &seed3);
// FIXME: We use the portable implementation here, since the
// compiler has an easier time optimizing it, compared to the AVX2
// version, which actually results in faster code (except for key
// generation), even in the AVX2 instantiation of ML-DSA.
let mut state =
crate::hash_functions::portable::Shake128X4::init_absorb(&seed0, &seed1, &seed2, &seed3);

let mut randomness0 = [0u8; shake128::FIVE_BLOCKS_SIZE];
let mut randomness1 = [0u8; shake128::FIVE_BLOCKS_SIZE];
Expand Down Expand Up @@ -483,10 +489,10 @@ mod tests {

// This is just a wrapper around sample_four_ring_elements, for testing
// purposes.
fn sample_ring_element_uniform<SIMDUnit: Operations, Shake128: shake128::XofX4>(
fn sample_ring_element_uniform<SIMDUnit: Operations>(
seed: [u8; 34],
) -> PolynomialRingElement<SIMDUnit> {
let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
((seed[33] as u16) << 8) | (seed[32] as u16),
0,
Expand Down Expand Up @@ -554,7 +560,7 @@ mod tests {
];

assert_eq!(
sample_ring_element_uniform::<SIMDUnit, Shake128>(seed).to_i32_array(),
sample_ring_element_uniform::<SIMDUnit>(seed).to_i32_array(),
expected_coefficients
);

Expand All @@ -568,8 +574,7 @@ mod tests {
0xB1, 0x83, 0x9B, 0x86, 0x06, 0xF5, 0x94, 0x8B, 0x9D, 0x72, 0xA9, 0x56, 0xDC, 0xF1,
0x01, 0x16, 0xDA, 0x9E, 0x01, 0x00,
];
let actual_coefficients =
sample_ring_element_uniform::<SIMDUnit, Shake128>(seed).to_i32_array();
let actual_coefficients = sample_ring_element_uniform::<SIMDUnit>(seed).to_i32_array();

assert_eq!(actual_coefficients[0], 1_165_602);
assert_eq!(
Expand Down
70 changes: 31 additions & 39 deletions libcrux-ml-dsa/src/samplex4.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::{
hash_functions::{shake128, shake256},
hash_functions::shake256,
polynomial::PolynomialRingElement,
sample::{sample_four_error_ring_elements, sample_four_ring_elements},
simd::traits::Operations,
Expand Down Expand Up @@ -30,7 +30,6 @@ fn update_matrix<SIMDUnit: Operations, const ROWS_IN_A: usize, const COLUMNS_IN_
#[inline(always)]
pub(crate) fn matrix_A_4_by_4<
SIMDUnit: Operations,
Shake128X4: shake128::XofX4,
const ROWS_IN_A: usize,
const COLUMNS_IN_A: usize,
>(
Expand All @@ -39,7 +38,7 @@ pub(crate) fn matrix_A_4_by_4<
let mut A: Matrix<SIMDUnit, ROWS_IN_A, COLUMNS_IN_A> =
[[PolynomialRingElement::<SIMDUnit>::ZERO(); COLUMNS_IN_A]; ROWS_IN_A];

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(0, 0),
generate_domain_separator(0, 1),
Expand All @@ -51,7 +50,7 @@ pub(crate) fn matrix_A_4_by_4<
update_matrix(&mut A, 0, 2, four_ring_elements.2);
update_matrix(&mut A, 0, 3, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(1, 0),
generate_domain_separator(1, 1),
Expand All @@ -63,7 +62,7 @@ pub(crate) fn matrix_A_4_by_4<
update_matrix(&mut A, 1, 2, four_ring_elements.2);
update_matrix(&mut A, 1, 3, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(2, 0),
generate_domain_separator(2, 1),
Expand All @@ -75,7 +74,7 @@ pub(crate) fn matrix_A_4_by_4<
update_matrix(&mut A, 2, 2, four_ring_elements.2);
update_matrix(&mut A, 2, 3, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(3, 0),
generate_domain_separator(3, 1),
Expand All @@ -94,15 +93,14 @@ pub(crate) fn matrix_A_4_by_4<
#[inline(always)]
pub(crate) fn matrix_A_6_by_5<
SIMDUnit: Operations,
Shake128X4: shake128::XofX4,
const ROWS_IN_A: usize,
const COLUMNS_IN_A: usize,
>(
seed: [u8; 34],
) -> [[PolynomialRingElement<SIMDUnit>; COLUMNS_IN_A]; ROWS_IN_A] {
let mut A = [[PolynomialRingElement::<SIMDUnit>::ZERO(); COLUMNS_IN_A]; ROWS_IN_A];

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(0, 0),
generate_domain_separator(0, 1),
Expand All @@ -114,7 +112,7 @@ pub(crate) fn matrix_A_6_by_5<
update_matrix(&mut A, 0, 2, four_ring_elements.2);
update_matrix(&mut A, 0, 3, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(0, 4),
generate_domain_separator(1, 0),
Expand All @@ -126,7 +124,7 @@ pub(crate) fn matrix_A_6_by_5<
update_matrix(&mut A, 1, 1, four_ring_elements.2);
update_matrix(&mut A, 1, 2, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(1, 3),
generate_domain_separator(1, 4),
Expand All @@ -138,7 +136,7 @@ pub(crate) fn matrix_A_6_by_5<
update_matrix(&mut A, 2, 0, four_ring_elements.2);
update_matrix(&mut A, 2, 1, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(2, 2),
generate_domain_separator(2, 3),
Expand All @@ -150,7 +148,7 @@ pub(crate) fn matrix_A_6_by_5<
update_matrix(&mut A, 2, 4, four_ring_elements.2);
update_matrix(&mut A, 3, 0, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(3, 1),
generate_domain_separator(3, 2),
Expand All @@ -162,7 +160,7 @@ pub(crate) fn matrix_A_6_by_5<
update_matrix(&mut A, 3, 3, four_ring_elements.2);
update_matrix(&mut A, 3, 4, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(4, 0),
generate_domain_separator(4, 1),
Expand All @@ -174,7 +172,7 @@ pub(crate) fn matrix_A_6_by_5<
update_matrix(&mut A, 4, 2, four_ring_elements.2);
update_matrix(&mut A, 4, 3, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(4, 4),
generate_domain_separator(5, 0),
Expand All @@ -187,7 +185,7 @@ pub(crate) fn matrix_A_6_by_5<
update_matrix(&mut A, 5, 2, four_ring_elements.3);

// The the last 2 sampled ring elements are discarded here.
let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(5, 3),
generate_domain_separator(5, 4),
Expand All @@ -203,15 +201,14 @@ pub(crate) fn matrix_A_6_by_5<
#[inline(always)]
pub(crate) fn matrix_A_8_by_7<
SIMDUnit: Operations,
Shake128X4: shake128::XofX4,
const ROWS_IN_A: usize,
const COLUMNS_IN_A: usize,
>(
seed: [u8; 34],
) -> [[PolynomialRingElement<SIMDUnit>; COLUMNS_IN_A]; ROWS_IN_A] {
let mut A = [[PolynomialRingElement::<SIMDUnit>::ZERO(); COLUMNS_IN_A]; ROWS_IN_A];

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(0, 0),
generate_domain_separator(0, 1),
Expand All @@ -223,7 +220,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 0, 2, four_ring_elements.2);
update_matrix(&mut A, 0, 3, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(0, 4),
generate_domain_separator(0, 5),
Expand All @@ -235,7 +232,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 0, 6, four_ring_elements.2);
update_matrix(&mut A, 1, 0, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(1, 1),
generate_domain_separator(1, 2),
Expand All @@ -247,7 +244,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 1, 3, four_ring_elements.2);
update_matrix(&mut A, 1, 4, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(1, 5),
generate_domain_separator(1, 6),
Expand All @@ -259,7 +256,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 2, 0, four_ring_elements.2);
update_matrix(&mut A, 2, 1, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(2, 2),
generate_domain_separator(2, 3),
Expand All @@ -271,7 +268,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 2, 4, four_ring_elements.2);
update_matrix(&mut A, 2, 5, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(2, 6),
generate_domain_separator(3, 0),
Expand All @@ -283,7 +280,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 3, 1, four_ring_elements.2);
update_matrix(&mut A, 3, 2, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(3, 3),
generate_domain_separator(3, 4),
Expand All @@ -295,7 +292,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 3, 5, four_ring_elements.2);
update_matrix(&mut A, 3, 6, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(4, 0),
generate_domain_separator(4, 1),
Expand All @@ -307,7 +304,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 4, 2, four_ring_elements.2);
update_matrix(&mut A, 4, 3, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(4, 4),
generate_domain_separator(4, 5),
Expand All @@ -319,7 +316,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 4, 6, four_ring_elements.2);
update_matrix(&mut A, 5, 0, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(5, 1),
generate_domain_separator(5, 2),
Expand All @@ -331,7 +328,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 5, 3, four_ring_elements.2);
update_matrix(&mut A, 5, 4, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(5, 5),
generate_domain_separator(5, 6),
Expand All @@ -343,7 +340,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 6, 0, four_ring_elements.2);
update_matrix(&mut A, 6, 1, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(6, 2),
generate_domain_separator(6, 3),
Expand All @@ -355,7 +352,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 6, 4, four_ring_elements.2);
update_matrix(&mut A, 6, 5, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(6, 6),
generate_domain_separator(7, 0),
Expand All @@ -367,7 +364,7 @@ pub(crate) fn matrix_A_8_by_7<
update_matrix(&mut A, 7, 1, four_ring_elements.2);
update_matrix(&mut A, 7, 2, four_ring_elements.3);

let four_ring_elements = sample_four_ring_elements::<SIMDUnit, Shake128X4>(
let four_ring_elements = sample_four_ring_elements::<SIMDUnit>(
seed,
generate_domain_separator(7, 3),
generate_domain_separator(7, 4),
Expand All @@ -383,18 +380,13 @@ pub(crate) fn matrix_A_8_by_7<
}
#[allow(non_snake_case)]
#[inline(always)]
pub(crate) fn matrix_A<
SIMDUnit: Operations,
Shake128X4: shake128::XofX4,
const ROWS_IN_A: usize,
const COLUMNS_IN_A: usize,
>(
pub(crate) fn matrix_A<SIMDUnit: Operations, const ROWS_IN_A: usize, const COLUMNS_IN_A: usize>(
seed: [u8; 34],
) -> [[PolynomialRingElement<SIMDUnit>; COLUMNS_IN_A]; ROWS_IN_A] {
match (ROWS_IN_A as u8, COLUMNS_IN_A as u8) {
(4, 4) => matrix_A_4_by_4::<SIMDUnit, Shake128X4, ROWS_IN_A, COLUMNS_IN_A>(seed),
(6, 5) => matrix_A_6_by_5::<SIMDUnit, Shake128X4, ROWS_IN_A, COLUMNS_IN_A>(seed),
(8, 7) => matrix_A_8_by_7::<SIMDUnit, Shake128X4, ROWS_IN_A, COLUMNS_IN_A>(seed),
(4, 4) => matrix_A_4_by_4::<SIMDUnit, ROWS_IN_A, COLUMNS_IN_A>(seed),
(6, 5) => matrix_A_6_by_5::<SIMDUnit, ROWS_IN_A, COLUMNS_IN_A>(seed),
(8, 7) => matrix_A_8_by_7::<SIMDUnit, ROWS_IN_A, COLUMNS_IN_A>(seed),
_ => unreachable!(),
}
}
Expand Down

0 comments on commit 77e9464

Please sign in to comment.