Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-KEM decapsulation key hash check #1873

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions crypto/evp_extra/evp_extra_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2906,15 +2906,33 @@ TEST_P(PerMLKEMTest, InputValidation) {
std::vector<uint8_t> ss(ss_len);

// Encapsulate.
ASSERT_TRUE(
EVP_PKEY_encapsulate(ctx.get(), ct.data(), &ct_len, ss.data(), &ss_len));
ASSERT_TRUE(EVP_PKEY_encapsulate(ctx.get(), ct.data(), &ct_len, ss.data(), &ss_len));

// ---- 3. Test invalid public key ----
// FIPS 203 Section 7.2 Encapsulation key check (Modulus check).
// Invalidate the key by forcing a coefficient out of range.
// Invalidate the key by forcing a coefficient out of range
// (save the original values to reset later).
uint8_t tmp0 = ctx->pkey->pkey.kem_key->public_key[0];
uint8_t tmp1 = ctx->pkey->pkey.kem_key->public_key[1];
ctx->pkey->pkey.kem_key->public_key[0] = 0xff;
ctx->pkey->pkey.kem_key->public_key[1] = 0xff;

ASSERT_FALSE(
EVP_PKEY_encapsulate(ctx.get(), ct.data(), &ct_len, ss.data(), &ss_len));

// Reset the public key and make sure encapsulation/decapsulation succeeds.
ctx->pkey->pkey.kem_key->public_key[0] = tmp0;
ctx->pkey->pkey.kem_key->public_key[1] = tmp1;

std::vector<uint8_t> ss_expected(ss_len); // The shared secret.
ASSERT_TRUE(EVP_PKEY_encapsulate(ctx.get(), ct.data(), &ct_len, ss.data(), &ss_len));
ASSERT_TRUE(EVP_PKEY_decapsulate(ctx.get(), ss_expected.data(), &ss_len, ct.data(), ct_len));
EXPECT_EQ(Bytes(ss_expected), Bytes(ss));

// ---- 4. Test invalid secret key ----
// FIPS 203 Section 7.3 Decapsulation key check (Hash check).
// Invalidate the key by changing the hash of the public key within the secret key.
// The 32-byte hash is stored right before the last 32 bytes of the secret key.
ctx->pkey->pkey.kem_key->secret_key[GetParam().secret_key_len - 64] ^= 1;
ASSERT_FALSE(EVP_PKEY_decapsulate(ctx.get(), ss_expected.data(), &ss_len, ct.data(), ct_len));
}
20 changes: 20 additions & 0 deletions crypto/fipsmodule/ml_kem/ml_kem_ref/kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ static int encapsulation_key_modulus_check(ml_kem_params *params, const uint8_t
return verify(ek_recoded, ek, params->k * BYTE_ENCODE_12_OUT_SIZE);
}

// FIPS 203. Section 7.3 Decapsulation key hash check
// The spec defines the decapsulation key as following:
// dk <-- (dk_pke || ek || H(ek) || z).
// This check takes |ek| out of |dk|, computes H(ek), and verifies that it is
// the same as the H(ek) portion stored in |dk|.
Comment on lines +162 to +163
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this sounds like input check 3. from the spec, but what about the "type checks" in 1. and 2.?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering particularly about check 1 on the ciphertext which is said to have to be performed for every decapsulation.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is the Hash check (as specified in L159). The other two Decaps checks and also the first check for Encaps (Section 7.2) are all length checks on the input arrays. We can safely omit them here because those checks are done in higher level functions. The required lengths for different variants of ML-KEM are hard-coded here:

  • DEFINE_LOCAL_DATA(KEM, KEM_ml_kem_512) {

    The ciphertext length for Decaps is checked here:
  • if (ciphertext_len != kem->ciphertext_len ||

    If a key is generated by aws-lc then it satisfies the length requirements. If a key is generated outside of aws-lc, it has to be imported into an EVP_PKEY object to be used within aws-lc. We provide only these three functions to do that: EVP_PKEY_kem_new_raw_key, EVP_PKEY_kem_new_raw_secret_key, EVP_PKEY_kem_new_raw_public_key. The lengths are checked for example here:
  • if (kem->public_key_len != len_public || kem->secret_key_len != len_secret) {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @dkostic. Maybe you can summarise somewhere that all checks for encapsulate and decapsulate are done as per fips 203 in various places because I think someone coming to this part of the code where it says as in Sec 7.2 or 7.3 will wonder where the other checks are.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think someone coming to this part of the code where it says as in Sec 7.2 or 7.3 will wonder where the other checks are.

That's me, and I was indeed wondering that. I would appreciate it if the comments could be added.

static int decapsulation_key_hash_check(ml_kem_params *params, const uint8_t *dk) {
uint8_t dk_pke_hash_computed[KYBER_SYMBYTES] = {0};

hash_h(dk_pke_hash_computed, &dk[params->indcpa_secret_key_bytes],
params->indcpa_public_key_bytes);
const uint8_t *dk_pke_hash_expected = &dk[params->indcpa_secret_key_bytes +
params->indcpa_public_key_bytes];

return verify(dk_pke_hash_computed, dk_pke_hash_expected, KYBER_SYMBYTES);
}

/*************************************************
* Name: crypto_kem_enc_derand
*
Expand Down Expand Up @@ -248,6 +264,10 @@ int crypto_kem_dec(ml_kem_params *params,
const uint8_t *ct,
const uint8_t *sk)
{
if (decapsulation_key_hash_check(params, sk) != 0) {
return 1;
}

int fail;
uint8_t buf[2*KYBER_SYMBYTES];
/* Will contain key, coins */
Expand Down
Loading