From 507d03009cecc30a4aed78279d08ab7f8de813f8 Mon Sep 17 00:00:00 2001 From: Abhinav Saxena Date: Wed, 13 Nov 2024 18:00:57 +0530 Subject: [PATCH] Test Improvements for ML-KEM (#1947) * test improvements for ML-KEM Signed-off-by: Abhinav Saxena * update length type from int to size_t Signed-off-by: Abhinav Saxena * fix windows dll + compilation issues Signed-off-by: Abhinav Saxena * fix windows tests for ACVP vectors Signed-off-by: Abhinav Saxena * fix build failure in vector_kem Signed-off-by: Abhinav Saxena * remove const qualifier from prng_op_stream Signed-off-by: Abhinav Saxena * add macros instead of hardcoding & declasify values before use Signed-off-by: Abhinav Saxena * add ML-KEM rejection tests in seperate function Signed-off-by: Abhinav Saxena * add ciphertext corruption test for kem rejection Signed-off-by: Abhinav Saxena * add conditional compilation for ML-KEM tests Signed-off-by: Abhinav Saxena --------- Signed-off-by: Abhinav Saxena --- tests/CMakeLists.txt | 9 ++++ tests/test_acvp_vectors.py | 3 -- tests/test_kem.c | 101 ++++++++++++++++++++++++++++++++++++- tests/vectors_kem.c | 10 ++-- 4 files changed, 114 insertions(+), 9 deletions(-) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 2eb6ef6b3..6d08516a8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -82,6 +82,15 @@ endif() add_executable(test_kem test_kem.c) target_link_libraries(test_kem PRIVATE ${TEST_DEPS}) +if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS) + # workaround for Windows .dll + if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING) + target_link_options(test_kem PRIVATE -Wl,--allow-multiple-definition) + else() + target_link_options(test_kem PRIVATE "/FORCE:MULTIPLE") + endif() +endif() + add_executable(test_kem_mem test_kem_mem.c) target_link_libraries(test_kem_mem PRIVATE ${TEST_DEPS}) diff --git a/tests/test_acvp_vectors.py b/tests/test_acvp_vectors.py index ddd64003c..a0504cabc 100644 --- a/tests/test_acvp_vectors.py +++ b/tests/test_acvp_vectors.py @@ -18,7 +18,6 @@ ml_dsa_ver = "ACVP_Vectors/ML-DSA-sigVer-FIPS204/internalProjection.json" @helpers.filtered_test -@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows") @pytest.mark.parametrize('kem_name', helpers.available_kems_by_name()) def test_acvp_vec_kem_keygen(kem_name): if not(helpers.is_kem_enabled_by_name(kem_name)): pytest.skip('Not enabled') @@ -45,7 +44,6 @@ def test_acvp_vec_kem_keygen(kem_name): assert(variantFound == True) @helpers.filtered_test -@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows") @pytest.mark.parametrize('kem_name', helpers.available_kems_by_name()) def test_acvp_vec_kem_encdec_aft(kem_name): @@ -76,7 +74,6 @@ def test_acvp_vec_kem_encdec_aft(kem_name): assert(variantFound == True) @helpers.filtered_test -@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows") @pytest.mark.parametrize('kem_name', helpers.available_kems_by_name()) def test_acvp_vec_kem_encdec_val(kem_name): diff --git a/tests/test_kem.c b/tests/test_kem.c index a4076776a..75204b1d3 100644 --- a/tests/test_kem.c +++ b/tests/test_kem.c @@ -3,10 +3,15 @@ #include #include #include +#if defined(_WIN32) #include +#define strcasecmp _stricmp +#else +#include +#endif #include - +#include #if OQS_USE_PTHREADS #include #endif @@ -20,6 +25,10 @@ #define OQS_TEST_CT_DECLASSIFY(addr, len) #endif +#ifdef OQS_ENABLE_KEM_ML_KEM +#define MLKEM_SECRET_LEN 32 +#endif + #include "system_info.c" /* Displays hexadecimal strings */ @@ -31,6 +40,89 @@ static void OQS_print_hex_string(const char *label, const uint8_t *str, size_t l printf("\n"); } +#ifdef OQS_ENABLE_KEM_ML_KEM +/* mlkem rejection key testcase */ +static bool mlkem_rej_testcase(OQS_KEM *kem, uint8_t *ciphertext, uint8_t *secret_key) { + // sanity checks + if ((kem == NULL) || (ciphertext == NULL) || (secret_key == NULL)) { + fprintf(stderr, "ERROR: inputs NULL!\n"); + return false; + } + // Only run tests for ML-KEM + if (!(strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_512) == 0 || + strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_768) == 0 || + strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_1024) == 0)) { + return true; + } + // Buffer to hold z and c. z is always 32 bytes + uint8_t *buff_z_c = NULL; + bool retval = false; + OQS_STATUS rc; + int rv; + size_t length_z_c = 32 + kem->length_ciphertext; + buff_z_c = OQS_MEM_malloc(length_z_c); + if (buff_z_c == NULL) { + fprintf(stderr, "ERROR: OQS_MEM_malloc failed\n"); + return false; + } + // Scenario 1: Test rejection key by corrupting the secret key + secret_key[0] += 1; + uint8_t shared_secret_r[MLKEM_SECRET_LEN]; // expected output + uint8_t shared_secret_d[MLKEM_SECRET_LEN]; // calculated output + memcpy(buff_z_c, &secret_key[kem->length_secret_key - 32], 32); + memcpy(&buff_z_c[MLKEM_SECRET_LEN], ciphertext, kem->length_ciphertext); + // Calculate expected secret in case of corrupted cipher: shake256(z || c) + OQS_SHA3_shake256(shared_secret_r, MLKEM_SECRET_LEN, buff_z_c, length_z_c); + rc = OQS_KEM_decaps(kem, shared_secret_d, ciphertext, secret_key); + OQS_TEST_CT_DECLASSIFY(&rc, sizeof rc); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "ERROR: OQS_KEM_decaps failed for rejection testcase scenario 1\n"); + goto cleanup; + } + OQS_TEST_CT_DECLASSIFY(shared_secret_d, MLKEM_SECRET_LEN); + OQS_TEST_CT_DECLASSIFY(shared_secret_r, MLKEM_SECRET_LEN); + rv = memcmp(shared_secret_d, shared_secret_r, MLKEM_SECRET_LEN); + if (rv != 0) { + fprintf(stderr, "ERROR: shared secrets are not equal for rejection key in decapsulation scenario 1\n"); + OQS_print_hex_string("shared_secret_d", shared_secret_d, MLKEM_SECRET_LEN); + OQS_print_hex_string("shared_secret_r", shared_secret_r, MLKEM_SECRET_LEN); + goto cleanup; + } + secret_key[0] -= 1; // Restore private key + memset(buff_z_c, 0, length_z_c); // Reset buffer + + // Scenario 2: Test rejection key by corrupting the ciphertext + ciphertext[0] += 1; + memcpy(buff_z_c, &secret_key[kem->length_secret_key - 32], 32); + memcpy(&buff_z_c[MLKEM_SECRET_LEN], ciphertext, kem->length_ciphertext); + + // Calculate expected secret in case of corrupted cipher: shake256(z || c) + OQS_SHA3_shake256(shared_secret_r, MLKEM_SECRET_LEN, buff_z_c, length_z_c); + rc = OQS_KEM_decaps(kem, shared_secret_d, ciphertext, secret_key); + OQS_TEST_CT_DECLASSIFY(&rc, sizeof rc); + if (rc != OQS_SUCCESS) { + fprintf(stderr, "ERROR: OQS_KEM_decaps failed for rejection testcase scenario 2\n"); + goto cleanup; + } + OQS_TEST_CT_DECLASSIFY(shared_secret_d, MLKEM_SECRET_LEN); + OQS_TEST_CT_DECLASSIFY(shared_secret_r, MLKEM_SECRET_LEN); + rv = memcmp(shared_secret_d, shared_secret_r, MLKEM_SECRET_LEN); + if (rv != 0) { + fprintf(stderr, "ERROR: shared secrets are not equal for rejection key in decapsulation scenario 2\n"); + OQS_print_hex_string("shared_secret_d", shared_secret_d, MLKEM_SECRET_LEN); + OQS_print_hex_string("shared_secret_r", shared_secret_r, MLKEM_SECRET_LEN); + goto cleanup; + } + ciphertext[0] -= 1; // Restore ciphertext + retval = true; +cleanup: + if (buff_z_c) { + OQS_MEM_secure_free(buff_z_c, length_z_c); + } + return retval; +} +#endif //OQS_ENABLE_KEM_ML_KEM + typedef struct magic_s { uint8_t val[31]; } magic_t; @@ -127,6 +219,13 @@ static OQS_STATUS kem_test_correctness(const char *method_name) { printf("shared secrets are equal\n"); } +#ifdef OQS_ENABLE_KEM_ML_KEM + /* check mlkem rejection testcases. returns true for all other kem algos */ + if (false == mlkem_rej_testcase(kem, ciphertext, secret_key)) { + goto err; + } +#endif + // test invalid encapsulation (call should either fail or result in invalid shared secret) OQS_randombytes(ciphertext, kem->length_ciphertext); OQS_TEST_CT_DECLASSIFY(ciphertext, kem->length_ciphertext); diff --git a/tests/vectors_kem.c b/tests/vectors_kem.c index 4dc8ae63b..1037be726 100644 --- a/tests/vectors_kem.c +++ b/tests/vectors_kem.c @@ -285,7 +285,7 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name, ret = OQS_SUCCESS; } else { ret = OQS_ERROR; - fprintf(stderr, "[vectors_kem] %s ERROR (AFT): ciphertext or shared secret doesn't match!\n", method_name); + fprintf(stderr, "[vectors_kem] %s ERROR (AFT): shared secret doesn't match!\n", method_name); } goto cleanup; @@ -358,11 +358,11 @@ int main(int argc, char **argv) { } if (!strcmp(test_name, "keyGen")) { - prng_output_stream = argv[3]; // d || z + prng_output_stream = argv[3]; // d || z : both should be 32 bytes each as per FIPS-203 kg_pk = argv[4]; kg_sk = argv[5]; - if (strlen(prng_output_stream) % 2 != 0 || + if (strlen(prng_output_stream) != 128 || strlen(kg_pk) != 2 * kem->length_public_key || strlen(kg_sk) != 2 * kem->length_secret_key) { rc = OQS_ERROR; @@ -386,12 +386,12 @@ int main(int argc, char **argv) { rc = kem_kg_vector(alg_name, prng_output_stream_bytes, kg_pk_bytes, kg_sk_bytes); } else if (!strcmp(test_name, "encDecAFT")) { - prng_output_stream = argv[3]; // m + prng_output_stream = argv[3]; // m : should be 32 bytes as per FIPS-203 encdec_aft_pk = argv[4]; encdec_aft_k = argv[5]; encdec_aft_c = argv[6]; - if (strlen(prng_output_stream) % 2 != 0 || + if (strlen(prng_output_stream) != 64 || strlen(encdec_aft_c) != 2 * kem->length_ciphertext || strlen(encdec_aft_k) != 2 * kem->length_shared_secret || strlen(encdec_aft_pk) != 2 * kem->length_public_key) {