diff --git a/mbedtls/src/pk/mod.rs b/mbedtls/src/pk/mod.rs index a96719471..9532f22d5 100644 --- a/mbedtls/src/pk/mod.rs +++ b/mbedtls/src/pk/mod.rs @@ -138,6 +138,34 @@ const CUSTOM_PK_INFO: pk_info_t = { } }; +/// RSA components for constructing a private key. +pub enum RsaPrivateComponents<'a> { + WithPrimes { + /// Private 1st prime + p: &'a Mpi, + /// Private 2nd prime + q: &'a Mpi, + /// Public exponent + e: &'a Mpi, + }, + WithPrivateExponent { + /// Public modulus + n: &'a Mpi, + /// Private exponent + d: &'a Mpi, + /// Public exponent + e: &'a Mpi, + }, +} + +/// RSA components for constructing a public key. +pub struct RsaPublicComponents<'a> { + /// Public modulus + pub n: &'a Mpi, + /// Public exponent + pub e: &'a Mpi, +} + // If this changes then certificate.rs unsafe code in public_key needs to also // change. define!( @@ -200,7 +228,7 @@ define!( // - Only const access to context: eckey_check_pair, eckey_get_bitlen, // eckey_can_do, eckey_check_pair // -// - Const acccess / copies context to a stack based variable eckey_verify_wrap, +// - Const access / copies context to a stack based variable eckey_verify_wrap, // eckey_sign_wrap: ../../../mbedtls-sys/vendor/crypto/library/pk_wrap.c:251 // creates a stack ecdsa variable and uses ctx to initialize it. ctx is passed // as 'key', a const pointer to mbedtls_ecdsa_from_keypair( &ecdsa, ctx ) @@ -348,7 +376,7 @@ Please use `private_from_ec_components_with_rng` instead." /// /// This function will return an error if: /// - /// * Fails to genearte `EcPoint` from given EcGroup in `curve`. + /// * Fails to generate `EcPoint` from given EcGroup in `curve`. /// * The underlying C `mbedtls_pk_setup` function fails to set up the `Pk` context. /// * The `EcPoint::mul` function fails to generate the public key point. pub fn private_from_ec_components_with_rng(mut curve: EcGroup, private_key: Mpi, rng: &mut F) -> Result { @@ -376,6 +404,39 @@ Please use `private_from_ec_components_with_rng` instead." Ok(ret) } + /// Construct a private key from RSA components. + pub fn private_from_rsa_components(components: RsaPrivateComponents<'_>) -> Result { + let mut ret = Self::init(); + let (n, p, q, d, e) = match components { + RsaPrivateComponents::WithPrimes { p, q, e } => (None, Some(p), Some(q), None, Some(e)), + RsaPrivateComponents::WithPrivateExponent { n, d, e } => (Some(n), None, None, Some(d), Some(e)), + }; + let to_ptr = |mpi: Option<&Mpi>| match mpi { + None => ptr::null(), + Some(mpi) => mpi.handle(), + }; + unsafe { + pk_setup(&mut ret.inner, pk_info_from_type(Type::Rsa.into())).into_result()?; + let ctx = ret.inner.pk_ctx as *mut rsa_context; + rsa_import(ctx, to_ptr(n), to_ptr(p), to_ptr(q), to_ptr(d), to_ptr(e)).into_result()?; + rsa_complete(ctx).into_result()?; + } + Ok(ret) + } + + /// Construct a public key from RSA components. + pub fn public_from_rsa_components(components: RsaPublicComponents<'_>) -> Result { + let mut ret = Self::init(); + let RsaPublicComponents { n, e } = components; + unsafe { + pk_setup(&mut ret.inner, pk_info_from_type(Type::Rsa.into())).into_result()?; + let ctx = ret.inner.pk_ctx as *mut rsa_context; + rsa_import(ctx, n.handle(), ptr::null(), ptr::null(), ptr::null(), e.handle()).into_result()?; + rsa_complete(ctx).into_result()?; + } + Ok(ret) + } + pub fn public_custom_algo(algo_id: &[u64], pk: &[u8]) -> Result { let mut ret = Self::init(); unsafe { @@ -1564,4 +1625,37 @@ iy6KC991zzvaWY/Ys+q/84Afqa+0qJKQnPuy/7F5GkVdQA/lfbhi assert_eq!(l.unwrap(), LEN); } } + + #[test] + fn private_from_rsa_components_sanity() { + let mut pk = Pk::generate_rsa(&mut crate::test_support::rand::test_rng(), 2048, 0x10001).unwrap(); + let components = RsaPrivateComponents::WithPrimes { + p: &pk.rsa_private_prime1().unwrap(), + q: &pk.rsa_private_prime2().unwrap(), + e: &Mpi::new(pk.rsa_public_exponent().unwrap() as _).unwrap(), + }; + let mut pk2 = Pk::private_from_rsa_components(components).unwrap(); + assert_eq!(pk.write_private_der_vec().unwrap(), pk2.write_private_der_vec().unwrap()); + + let components = RsaPrivateComponents::WithPrivateExponent { + n: &pk.rsa_public_modulus().unwrap(), + d: &pk.rsa_private_exponent().unwrap(), + e: &Mpi::new(pk.rsa_public_exponent().unwrap() as _).unwrap(), + }; + let mut pk3 = Pk::private_from_rsa_components(components).unwrap(); + assert_eq!(pk.write_private_der_vec().unwrap(), pk3.write_private_der_vec().unwrap()); + } + + #[test] + fn public_from_rsa_components_sanity() { + let mut pk = Pk::generate_rsa(&mut crate::test_support::rand::test_rng(), 2048, 0x10001).unwrap(); + let mut pk = Pk::from_public_key(&pk.write_public_der_vec().unwrap()).unwrap(); + + let components = RsaPublicComponents { + n: &pk.rsa_public_modulus().unwrap(), + e: &Mpi::new(pk.rsa_public_exponent().unwrap() as _).unwrap(), + }; + let mut pk2 = Pk::public_from_rsa_components(components).unwrap(); + assert_eq!(pk.write_public_der_vec().unwrap(), pk2.write_public_der_vec().unwrap()); + } }