Skip to content

Commit

Permalink
Test Improvements for ML-KEM (#1947)
Browse files Browse the repository at this point in the history
* test improvements for ML-KEM

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* update length type from int to size_t

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* fix windows dll + compilation issues

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* fix windows tests for ACVP vectors

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* fix build failure in vector_kem

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* remove const qualifier from prng_op_stream

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* add macros instead of hardcoding & declasify values before use

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* add ML-KEM rejection tests in seperate function

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* add ciphertext corruption test for kem rejection

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

* add conditional compilation for ML-KEM tests

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>

---------

Signed-off-by: Abhinav Saxena <abhinav.saxena@thalesgroup.com>
  • Loading branch information
abhinav-thales authored Nov 13, 2024
1 parent 2ee908d commit 507d030
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 9 deletions.
9 changes: 9 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ endif()
add_executable(test_kem test_kem.c)
target_link_libraries(test_kem PRIVATE ${TEST_DEPS})

if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS)
# workaround for Windows .dll
if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING)
target_link_options(test_kem PRIVATE -Wl,--allow-multiple-definition)
else()
target_link_options(test_kem PRIVATE "/FORCE:MULTIPLE")
endif()
endif()

add_executable(test_kem_mem test_kem_mem.c)
target_link_libraries(test_kem_mem PRIVATE ${TEST_DEPS})

Expand Down
3 changes: 0 additions & 3 deletions tests/test_acvp_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
ml_dsa_ver = "ACVP_Vectors/ML-DSA-sigVer-FIPS204/internalProjection.json"

@helpers.filtered_test
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows")
@pytest.mark.parametrize('kem_name', helpers.available_kems_by_name())
def test_acvp_vec_kem_keygen(kem_name):
if not(helpers.is_kem_enabled_by_name(kem_name)): pytest.skip('Not enabled')
Expand All @@ -45,7 +44,6 @@ def test_acvp_vec_kem_keygen(kem_name):
assert(variantFound == True)

@helpers.filtered_test
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows")
@pytest.mark.parametrize('kem_name', helpers.available_kems_by_name())
def test_acvp_vec_kem_encdec_aft(kem_name):

Expand Down Expand Up @@ -76,7 +74,6 @@ def test_acvp_vec_kem_encdec_aft(kem_name):
assert(variantFound == True)

@helpers.filtered_test
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows")
@pytest.mark.parametrize('kem_name', helpers.available_kems_by_name())
def test_acvp_vec_kem_encdec_val(kem_name):

Expand Down
101 changes: 100 additions & 1 deletion tests/test_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#if defined(_WIN32)
#include <string.h>
#define strcasecmp _stricmp
#else
#include <strings.h>
#endif

#include <oqs/oqs.h>

#include <oqs/sha3.h>
#if OQS_USE_PTHREADS
#include <pthread.h>
#endif
Expand All @@ -20,6 +25,10 @@
#define OQS_TEST_CT_DECLASSIFY(addr, len)
#endif

#ifdef OQS_ENABLE_KEM_ML_KEM
#define MLKEM_SECRET_LEN 32
#endif

#include "system_info.c"

/* Displays hexadecimal strings */
Expand All @@ -31,6 +40,89 @@ static void OQS_print_hex_string(const char *label, const uint8_t *str, size_t l
printf("\n");
}

#ifdef OQS_ENABLE_KEM_ML_KEM
/* mlkem rejection key testcase */
static bool mlkem_rej_testcase(OQS_KEM *kem, uint8_t *ciphertext, uint8_t *secret_key) {
// sanity checks
if ((kem == NULL) || (ciphertext == NULL) || (secret_key == NULL)) {
fprintf(stderr, "ERROR: inputs NULL!\n");
return false;
}
// Only run tests for ML-KEM
if (!(strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_512) == 0 ||
strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_768) == 0 ||
strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_1024) == 0)) {
return true;
}
// Buffer to hold z and c. z is always 32 bytes
uint8_t *buff_z_c = NULL;
bool retval = false;
OQS_STATUS rc;
int rv;
size_t length_z_c = 32 + kem->length_ciphertext;
buff_z_c = OQS_MEM_malloc(length_z_c);
if (buff_z_c == NULL) {
fprintf(stderr, "ERROR: OQS_MEM_malloc failed\n");
return false;
}
// Scenario 1: Test rejection key by corrupting the secret key
secret_key[0] += 1;
uint8_t shared_secret_r[MLKEM_SECRET_LEN]; // expected output
uint8_t shared_secret_d[MLKEM_SECRET_LEN]; // calculated output
memcpy(buff_z_c, &secret_key[kem->length_secret_key - 32], 32);
memcpy(&buff_z_c[MLKEM_SECRET_LEN], ciphertext, kem->length_ciphertext);
// Calculate expected secret in case of corrupted cipher: shake256(z || c)
OQS_SHA3_shake256(shared_secret_r, MLKEM_SECRET_LEN, buff_z_c, length_z_c);
rc = OQS_KEM_decaps(kem, shared_secret_d, ciphertext, secret_key);
OQS_TEST_CT_DECLASSIFY(&rc, sizeof rc);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "ERROR: OQS_KEM_decaps failed for rejection testcase scenario 1\n");
goto cleanup;
}
OQS_TEST_CT_DECLASSIFY(shared_secret_d, MLKEM_SECRET_LEN);
OQS_TEST_CT_DECLASSIFY(shared_secret_r, MLKEM_SECRET_LEN);
rv = memcmp(shared_secret_d, shared_secret_r, MLKEM_SECRET_LEN);
if (rv != 0) {
fprintf(stderr, "ERROR: shared secrets are not equal for rejection key in decapsulation scenario 1\n");
OQS_print_hex_string("shared_secret_d", shared_secret_d, MLKEM_SECRET_LEN);
OQS_print_hex_string("shared_secret_r", shared_secret_r, MLKEM_SECRET_LEN);
goto cleanup;
}
secret_key[0] -= 1; // Restore private key
memset(buff_z_c, 0, length_z_c); // Reset buffer

// Scenario 2: Test rejection key by corrupting the ciphertext
ciphertext[0] += 1;
memcpy(buff_z_c, &secret_key[kem->length_secret_key - 32], 32);
memcpy(&buff_z_c[MLKEM_SECRET_LEN], ciphertext, kem->length_ciphertext);

// Calculate expected secret in case of corrupted cipher: shake256(z || c)
OQS_SHA3_shake256(shared_secret_r, MLKEM_SECRET_LEN, buff_z_c, length_z_c);
rc = OQS_KEM_decaps(kem, shared_secret_d, ciphertext, secret_key);
OQS_TEST_CT_DECLASSIFY(&rc, sizeof rc);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "ERROR: OQS_KEM_decaps failed for rejection testcase scenario 2\n");
goto cleanup;
}
OQS_TEST_CT_DECLASSIFY(shared_secret_d, MLKEM_SECRET_LEN);
OQS_TEST_CT_DECLASSIFY(shared_secret_r, MLKEM_SECRET_LEN);
rv = memcmp(shared_secret_d, shared_secret_r, MLKEM_SECRET_LEN);
if (rv != 0) {
fprintf(stderr, "ERROR: shared secrets are not equal for rejection key in decapsulation scenario 2\n");
OQS_print_hex_string("shared_secret_d", shared_secret_d, MLKEM_SECRET_LEN);
OQS_print_hex_string("shared_secret_r", shared_secret_r, MLKEM_SECRET_LEN);
goto cleanup;
}
ciphertext[0] -= 1; // Restore ciphertext
retval = true;
cleanup:
if (buff_z_c) {
OQS_MEM_secure_free(buff_z_c, length_z_c);
}
return retval;
}
#endif //OQS_ENABLE_KEM_ML_KEM

typedef struct magic_s {
uint8_t val[31];
} magic_t;
Expand Down Expand Up @@ -127,6 +219,13 @@ static OQS_STATUS kem_test_correctness(const char *method_name) {
printf("shared secrets are equal\n");
}

#ifdef OQS_ENABLE_KEM_ML_KEM
/* check mlkem rejection testcases. returns true for all other kem algos */
if (false == mlkem_rej_testcase(kem, ciphertext, secret_key)) {
goto err;
}
#endif

// test invalid encapsulation (call should either fail or result in invalid shared secret)
OQS_randombytes(ciphertext, kem->length_ciphertext);
OQS_TEST_CT_DECLASSIFY(ciphertext, kem->length_ciphertext);
Expand Down
10 changes: 5 additions & 5 deletions tests/vectors_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name,
ret = OQS_SUCCESS;
} else {
ret = OQS_ERROR;
fprintf(stderr, "[vectors_kem] %s ERROR (AFT): ciphertext or shared secret doesn't match!\n", method_name);
fprintf(stderr, "[vectors_kem] %s ERROR (AFT): shared secret doesn't match!\n", method_name);
}

goto cleanup;
Expand Down Expand Up @@ -358,11 +358,11 @@ int main(int argc, char **argv) {
}

if (!strcmp(test_name, "keyGen")) {
prng_output_stream = argv[3]; // d || z
prng_output_stream = argv[3]; // d || z : both should be 32 bytes each as per FIPS-203
kg_pk = argv[4];
kg_sk = argv[5];

if (strlen(prng_output_stream) % 2 != 0 ||
if (strlen(prng_output_stream) != 128 ||
strlen(kg_pk) != 2 * kem->length_public_key ||
strlen(kg_sk) != 2 * kem->length_secret_key) {
rc = OQS_ERROR;
Expand All @@ -386,12 +386,12 @@ int main(int argc, char **argv) {

rc = kem_kg_vector(alg_name, prng_output_stream_bytes, kg_pk_bytes, kg_sk_bytes);
} else if (!strcmp(test_name, "encDecAFT")) {
prng_output_stream = argv[3]; // m
prng_output_stream = argv[3]; // m : should be 32 bytes as per FIPS-203
encdec_aft_pk = argv[4];
encdec_aft_k = argv[5];
encdec_aft_c = argv[6];

if (strlen(prng_output_stream) % 2 != 0 ||
if (strlen(prng_output_stream) != 64 ||
strlen(encdec_aft_c) != 2 * kem->length_ciphertext ||
strlen(encdec_aft_k) != 2 * kem->length_shared_secret ||
strlen(encdec_aft_pk) != 2 * kem->length_public_key) {
Expand Down

0 comments on commit 507d030

Please sign in to comment.