diff --git a/esp-hal-common/src/lib.rs b/esp-hal-common/src/lib.rs index 1393a0066c0..19658372e10 100644 --- a/esp-hal-common/src/lib.rs +++ b/esp-hal-common/src/lib.rs @@ -104,6 +104,8 @@ pub mod reset; #[cfg(rng)] pub mod rng; pub mod rom; +#[cfg(rsa)] +pub mod rsa; #[cfg(any(lp_clkrst, rtc_cntl))] pub mod rtc_cntl; #[cfg(sha)] diff --git a/esp-hal-common/src/rsa/esp32.rs b/esp-hal-common/src/rsa/esp32.rs new file mode 100644 index 00000000000..b8c74a1676d --- /dev/null +++ b/esp-hal-common/src/rsa/esp32.rs @@ -0,0 +1,211 @@ +use core::{ + convert::Infallible, + marker::PhantomData, + ptr::{copy_nonoverlapping, write_bytes}, +}; + +use crate::rsa::{ + implement_op, + Multi, + Rsa, + RsaMode, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, +}; + +impl<'d> Rsa<'d> { + /// After the RSA Accelerator is released from reset, the memory blocks + /// needs to be initialized, only after that peripheral should be used. + /// This function would return without an error if the memory is initialized + pub fn ready(&mut self) -> nb::Result<(), Infallible> { + if self.rsa.clean.read().clean().bit_is_clear() { + return Err(nb::Error::WouldBlock); + } + Ok(()) + } + + pub(super) fn write_multi_mode(&mut self, mode: u32) { + Self::write_to_register(&mut self.rsa.mult_mode, mode as u32); + } + + pub(super) fn write_modexp_mode(&mut self, mode: u32) { + Self::write_to_register(&mut self.rsa.modexp_mode, mode); + } + + pub(super) fn write_modexp_start(&mut self) { + self.rsa.modexp_start.write(|w| w.modexp_start().set_bit()); + } + + pub(super) fn write_multi_start(&mut self) { + self.rsa.mult_start.write(|w| w.mult_start().set_bit()); + } + + pub(super) fn clear_interrupt(&mut self) { + self.rsa.interrupt.write(|w| w.interrupt().set_bit()); + } + + pub(super) fn is_idle(&mut self) -> bool { + self.rsa.interrupt.read().bits() == 1 + } + + unsafe fn write_multi_operand_a(&mut self, operand_a: &[u8; N]) { + copy_nonoverlapping( + operand_a.as_ptr(), + self.rsa.x_mem.as_mut_ptr() as *mut u8, + N, + ); + write_bytes(self.rsa.x_mem.as_mut_ptr().add(N), 0, N); + } + + unsafe fn write_multi_operand_b(&mut self, operand_b: &[u8; N]) { + write_bytes(self.rsa.z_mem.as_mut_ptr(), 0, N); + copy_nonoverlapping( + operand_b.as_ptr(), + self.rsa.z_mem.as_mut_ptr().add(N) as *mut u8, + N, + ); + } +} + +pub mod operand_sizes { + //! Marker types for the operand sizes + use paste::paste; + + use super::{implement_op, Multi, RsaMode}; + + implement_op!( + (512, multi), + (1024, multi), + (1536, multi), + (2048, multi), + (2560), + (3072), + (3584), + (4096) + ); +} + +impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T> +where + T: RsaMode, +{ + /// Creates an Instance of `RsaMultiplication`. + /// `m_prime` could be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`, for more information check 24.3.2 in the + /// + pub fn new(rsa: &'a mut Rsa<'d>, modulus: &T::InputType, m_prime: u32) -> Self { + Self::set_mode(rsa); + unsafe { + rsa.write_modulus(modulus); + } + rsa.write_mprime(m_prime); + + Self { + rsa, + phantom: PhantomData, + } + } + + fn set_mode(rsa: &mut Rsa) { + rsa.write_multi_mode((N / 64 - 1) as u32) + } + + /// Starts the first step of modular multiplication operation. `r` could be + /// calculated using `2 ^ ( bitlength * 2 ) mod modulus`, + /// for more information check 24.3.2 in the + /// + pub fn start_step1(&mut self, operand_a: &T::InputType, r: &T::InputType) { + unsafe { + self.rsa.write_operand_a(operand_a); + self.rsa.write_r(r); + } + self.set_start(); + } + + /// Starts the second step of modular multiplication operation. + /// This is a non blocking function that returns without an error if + /// operation is completed successfully. `start_step1` must be called + /// before calling this function. + pub fn start_step2(&mut self, operand_b: &T::InputType) -> nb::Result<(), Infallible> { + if !self.rsa.is_idle() { + return Err(nb::Error::WouldBlock); + } + self.rsa.clear_interrupt(); + unsafe { + self.rsa.write_operand_a(operand_b); + } + self.set_start(); + Ok(()) + } + + fn set_start(&mut self) { + self.rsa.write_multi_start(); + } +} + +impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T> +where + T: RsaMode, +{ + /// Creates an Instance of `RsaModularExponentiation`. + /// `m_prime` could be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`, for more information check 24.3.2 in the + /// + pub fn new( + rsa: &'a mut Rsa<'d>, + exponent: &T::InputType, + modulus: &T::InputType, + m_prime: u32, + ) -> Self { + Self::set_mode(rsa); + unsafe { + rsa.write_operand_b(exponent); + rsa.write_modulus(modulus); + } + rsa.write_mprime(m_prime); + Self { + rsa, + phantom: PhantomData, + } + } + + pub(super) fn set_mode(rsa: &mut Rsa) { + rsa.write_modexp_mode((N / 64 - 1) as u32) + } + + pub(super) fn set_start(&mut self) { + self.rsa.write_modexp_start(); + } +} + +impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T> +where + T: RsaMode, +{ + /// Creates an Instance of `RsaMultiplication`. + pub fn new(rsa: &'a mut Rsa<'d>) -> Self { + Self::set_mode(rsa); + Self { + rsa, + phantom: PhantomData, + } + } + + /// Starts the multiplication operation. + pub fn start_multiplication(&mut self, operand_a: &T::InputType, operand_b: &T::InputType) { + unsafe { + self.rsa.write_multi_operand_a(operand_a); + self.rsa.write_multi_operand_b(operand_b); + } + self.set_start(); + } + + pub(super) fn set_mode(rsa: &mut Rsa) { + rsa.write_multi_mode((N / 32 - 1 + 8) as u32) + } + + pub(super) fn set_start(&mut self) { + self.rsa.write_multi_start(); + } +} diff --git a/esp-hal-common/src/rsa/esp32cX.rs b/esp-hal-common/src/rsa/esp32cX.rs new file mode 100644 index 00000000000..058d4dff40f --- /dev/null +++ b/esp-hal-common/src/rsa/esp32cX.rs @@ -0,0 +1,347 @@ +use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping}; + +use crate::rsa::{ + implement_op, + Multi, + Rsa, + RsaMode, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, +}; + +impl<'d> Rsa<'d> { + /// After the RSA Accelerator is released from reset, the memory blocks + /// needs to be initialized, only after that peripheral should be used. + /// This function would return without an error if the memory is initialized + pub fn ready(&mut self) -> nb::Result<(), Infallible> { + if self.rsa.query_clean.read().query_clean().bit_is_clear() { + return Err(nb::Error::WouldBlock); + } + Ok(()) + } + + /// Enables/disables rsa interrupt, when enabled rsa perpheral would + /// generate an interrupt when a operation is finished. + pub fn enable_disable_interrupt(&mut self, enable: bool) { + match enable { + true => self.rsa.int_ena.write(|w| w.int_ena().set_bit()), + false => self.rsa.int_ena.write(|w| w.int_ena().clear_bit()), + } + } + + fn write_mode(&mut self, mode: u32) { + Self::write_to_register(&mut self.rsa.mode, mode as u32); + } + + /// Enables/disables search acceleration, when enabled it would increases + /// the performance of modular exponentiation by discarding the + /// exponent's bits before the most significant set bit. Note: this might + /// affect the security, for more info refer 18.3.4 of + pub fn enable_disable_search_acceleration(&mut self, enable: bool) { + match enable { + true => self + .rsa + .search_enable + .write(|w| w.search_enable().set_bit()), + false => self + .rsa + .search_enable + .write(|w| w.search_enable().clear_bit()), + } + } + + pub(super) fn is_search_enabled(&mut self) -> bool { + self.rsa.search_enable.read().search_enable().bit_is_set() + } + + pub(super) fn write_search_position(&mut self, search_position: u32) { + Self::write_to_register(&mut self.rsa.search_pos, search_position); + } + + /// Enables/disables constant time acceleration, when enabled it would + /// increases the performance of modular exponentiation by simplifying + /// the calculation concerning the 0 bits of the exponent i.e. lesser the + /// hamming weight, greater the performance. Note : this might affect + /// the security, for more info refer 18.3.4 of + pub fn enable_disable_constant_time_acceleration(&mut self, enable: bool) { + match enable { + true => self + .rsa + .constant_time + .write(|w| w.constant_time().clear_bit()), + false => self + .rsa + .constant_time + .write(|w| w.constant_time().set_bit()), + } + } + + pub(super) fn write_modexp_start(&mut self) { + self.rsa + .set_start_modexp + .write(|w| w.set_start_modexp().set_bit()); + } + + pub(super) fn write_multi_start(&mut self) { + self.rsa + .set_start_mult + .write(|w| w.set_start_mult().set_bit()); + } + + fn write_modmulti_start(&mut self) { + self.rsa + .set_start_modmult + .write(|w| w.set_start_modmult().set_bit()); + } + + pub(super) fn clear_interrupt(&mut self) { + self.rsa.int_clr.write(|w| w.clear_interrupt().set_bit()); + } + + pub(super) fn is_idle(&mut self) -> bool { + self.rsa.query_idle.read().query_idle().bit_is_set() + } + + unsafe fn write_multi_operand_b(&mut self, operand_b: &[u8; N]) { + copy_nonoverlapping( + operand_b.as_ptr(), + self.rsa.z_mem.as_mut_ptr().add(N) as *mut u8, + N, + ); + } +} + +pub mod operand_sizes { + //! Marker types for the operand sizes + use paste::paste; + + use super::{implement_op, Multi, RsaMode}; + + implement_op!( + (32, multi), + (64, multi), + (96, multi), + (128, multi), + (160, multi), + (192, multi), + (224, multi), + (256, multi), + (288, multi), + (320, multi), + (352, multi), + (384, multi), + (416, multi), + (448, multi), + (480, multi), + (512, multi), + (544, multi), + (576, multi), + (608, multi), + (640, multi), + (672, multi), + (704, multi), + (736, multi), + (768, multi), + (800, multi), + (832, multi), + (864, multi), + (896, multi), + (928, multi), + (960, multi), + (992, multi), + (1024, multi), + (1056, multi), + (1088, multi), + (1120, multi), + (1152, multi), + (1184, multi), + (1216, multi), + (1248, multi), + (1280, multi), + (1312, multi), + (1344, multi), + (1376, multi), + (1408, multi), + (1440, multi), + (1472, multi), + (1504, multi), + (1536, multi), + (1568), + (1600), + (1632), + (1664), + (1696), + (1728), + (1760), + (1792), + (1824), + (1856), + (1888), + (1920), + (1952), + (1984), + (2016), + (2048), + (2080), + (2112), + (2144), + (2176), + (2208), + (2240), + (2272), + (2304), + (2336), + (2368), + (2400), + (2432), + (2464), + (2496), + (2528), + (2560), + (2592), + (2624), + (2656), + (2688), + (2720), + (2752), + (2784), + (2816), + (2848), + (2880), + (2912), + (2944), + (2976), + (3008), + (3040), + (3072) + ); +} + +impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T> +where + T: RsaMode, +{ + /// Creates an Instance of `RsaModularExponentiation`. + /// `m_prime` could be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`, for more information check 19.3.1 in the + /// + pub fn new( + rsa: &'a mut Rsa<'d>, + exponent: &T::InputType, + modulus: &T::InputType, + m_prime: u32, + ) -> Self { + Self::set_mode(rsa); + unsafe { + rsa.write_operand_b(exponent); + rsa.write_modulus(modulus); + } + rsa.write_mprime(m_prime); + if rsa.is_search_enabled() { + rsa.write_search_position(Self::find_search_pos(exponent)); + } + Self { + rsa, + phantom: PhantomData, + } + } + + fn find_search_pos(exponent: &T::InputType) -> u32 { + for (i, byte) in exponent.iter().rev().enumerate() { + if *byte == 0 { + continue; + } + return (exponent.len() * 8) as u32 - (byte.leading_zeros() + i as u32 * 8) - 1; + } + 0 + } + + pub(super) fn set_mode(rsa: &mut Rsa) { + rsa.write_mode((N / 4 - 1) as u32) + } + + pub(super) fn set_start(&mut self) { + self.rsa.write_modexp_start(); + } +} + +impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T> +where + T: RsaMode, +{ + fn write_mode(rsa: &mut Rsa) { + rsa.write_mode((N / 4 - 1) as u32) + } + + /// Creates an Instance of `RsaModularMultiplication`. + /// `m_prime` could be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`, for more information check 19.3.1 in the + /// + pub fn new( + rsa: &'a mut Rsa<'d>, + operand_a: &T::InputType, + operand_b: &T::InputType, + modulus: &T::InputType, + m_prime: u32, + ) -> Self { + Self::write_mode(rsa); + rsa.write_mprime(m_prime); + unsafe { + rsa.write_modulus(modulus); + rsa.write_operand_a(operand_a); + rsa.write_operand_b(operand_b); + } + Self { + rsa, + phantom: PhantomData, + } + } + + /// Starts the modular multiplication operation. `r` could be calculated + /// using `2 ^ ( bitlength * 2 ) mod modulus`, for more information + /// check 19.3.1 in the + pub fn start_modular_multiplication(&mut self, r: &T::InputType) { + unsafe { + self.rsa.write_r(r); + } + self.set_start(); + } + + fn set_start(&mut self) { + self.rsa.write_modmulti_start(); + } +} + +impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T> +where + T: RsaMode, +{ + /// Creates an Instance of `RsaMultiplication`. + pub fn new(rsa: &'a mut Rsa<'d>, operand_a: &T::InputType) -> Self { + Self::set_mode(rsa); + unsafe { + rsa.write_operand_a(operand_a); + } + Self { + rsa, + phantom: PhantomData, + } + } + + /// Starts the multiplication operation. + pub fn start_multiplication(&mut self, operand_b: &T::InputType) { + unsafe { + self.rsa.write_multi_operand_b(operand_b); + } + self.set_start(); + } + + pub(super) fn set_mode(rsa: &mut Rsa) { + rsa.write_mode((N / 2 - 1) as u32) + } + + pub(super) fn set_start(&mut self) { + self.rsa.write_multi_start(); + } +} diff --git a/esp-hal-common/src/rsa/esp32sX.rs b/esp-hal-common/src/rsa/esp32sX.rs new file mode 100644 index 00000000000..3589d287a74 --- /dev/null +++ b/esp-hal-common/src/rsa/esp32sX.rs @@ -0,0 +1,386 @@ +use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping}; + +use crate::rsa::{ + implement_op, + Multi, + Rsa, + RsaMode, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, +}; + +impl<'d> Rsa<'d> { + /// After the RSA Accelerator is released from reset, the memory blocks + /// needs to be initialized, only after that peripheral should be used. + /// This function would return without an error if the memory is initialized + pub fn ready(&mut self) -> nb::Result<(), Infallible> { + if self.rsa.clean.read().clean().bit_is_clear() { + return Err(nb::Error::WouldBlock); + } + Ok(()) + } + + /// Enables/disables rsa interrupt, when enabled rsa perpheral would + /// generate an interrupt when a operation is finished. + pub fn enable_disable_interrupt(&mut self, enable: bool) { + match enable { + true => self + .rsa + .interrupt_ena + .write(|w| w.interrupt_ena().set_bit()), + false => self + .rsa + .interrupt_ena + .write(|w| w.interrupt_ena().clear_bit()), + } + } + + fn write_mode(&mut self, mode: u32) { + Self::write_to_register(&mut self.rsa.mode, mode as u32); + } + + /// Enables/disables search acceleration, when enabled it would increases + /// the performance of modular exponentiation by discarding the + /// exponent's bits before the most significant set bit. Note: this might + /// affect the security, for more info refer 18.3.4 of + pub fn enable_disable_search_acceleration(&mut self, enable: bool) { + match enable { + true => self + .rsa + .search_enable + .write(|w| w.search_enable().set_bit()), + false => self + .rsa + .search_enable + .write(|w| w.search_enable().clear_bit()), + } + } + + pub(super) fn is_search_enabled(&mut self) -> bool { + self.rsa.search_enable.read().search_enable().bit_is_set() + } + + pub(super) fn write_search_position(&mut self, search_position: u32) { + Self::write_to_register(&mut self.rsa.search_pos, search_position); + } + + /// Enables/disables constant time acceleration, when enabled it would + /// increases the performance of modular exponentiation by simplifying + /// the calculation concerning the 0 bits of the exponent i.e. lesser the + /// hamming weight, greater the performance. Note : this might affect + /// the security, for more info refer 18.3.4 of + pub fn enable_disable_constant_time_acceleration(&mut self, enable: bool) { + match enable { + true => self + .rsa + .constant_time + .write(|w| w.constant_time().clear_bit()), + false => self + .rsa + .constant_time + .write(|w| w.constant_time().set_bit()), + } + } + + pub(super) fn write_modexp_start(&mut self) { + self.rsa.modexp_start.write(|w| w.modexp_start().set_bit()); + } + + pub(super) fn write_multi_start(&mut self) { + self.rsa.mult_start.write(|w| w.mult_start().set_bit()); + } + + fn write_modmulti_start(&mut self) { + self.rsa + .modmult_start + .write(|w| w.modmult_start().set_bit()); + } + + pub(super) fn clear_interrupt(&mut self) { + self.rsa + .clear_interrupt + .write(|w| w.clear_interrupt().set_bit()); + } + + pub(super) fn is_idle(&mut self) -> bool { + self.rsa.idle.read().idle().bit_is_set() + } + + unsafe fn write_multi_operand_b(&mut self, operand_b: &[u8; N]) { + copy_nonoverlapping( + operand_b.as_ptr(), + self.rsa.z_mem.as_mut_ptr().add(N) as *mut u8, + N, + ); + } +} + +pub mod operand_sizes { + //! Marker types for the operand sizes + use paste::paste; + + use super::{implement_op, Multi, RsaMode}; + + implement_op!( + (32, multi), + (64, multi), + (96, multi), + (128, multi), + (160, multi), + (192, multi), + (224, multi), + (256, multi), + (288, multi), + (320, multi), + (352, multi), + (384, multi), + (416, multi), + (448, multi), + (480, multi), + (512, multi), + (544, multi), + (576, multi), + (608, multi), + (640, multi), + (672, multi), + (704, multi), + (736, multi), + (768, multi), + (800, multi), + (832, multi), + (864, multi), + (896, multi), + (928, multi), + (960, multi), + (992, multi), + (1024, multi), + (1056, multi), + (1088, multi), + (1120, multi), + (1152, multi), + (1184, multi), + (1216, multi), + (1248, multi), + (1280, multi), + (1312, multi), + (1344, multi), + (1376, multi), + (1408, multi), + (1440, multi), + (1472, multi), + (1504, multi), + (1536, multi), + (1568, multi), + (1600, multi), + (1632, multi), + (1664, multi), + (1696, multi), + (1728, multi), + (1760, multi), + (1792, multi), + (1824, multi), + (1856, multi), + (1888, multi), + (1920, multi), + (1952, multi), + (1984, multi), + (2016, multi), + (2048, multi) + ); + + implement_op!( + (2080), + (2112), + (2144), + (2176), + (2208), + (2240), + (2272), + (2304), + (2336), + (2368), + (2400), + (2432), + (2464), + (2496), + (2528), + (2560), + (2592), + (2624), + (2656), + (2688), + (2720), + (2752), + (2784), + (2816), + (2848), + (2880), + (2912), + (2944), + (2976), + (3008), + (3040), + (3072), + (3104), + (3136), + (3168), + (3200), + (3232), + (3264), + (3296), + (3328), + (3360), + (3392), + (3424), + (3456), + (3488), + (3520), + (3552), + (3584), + (3616), + (3648), + (3680), + (3712), + (3744), + (3776), + (3808), + (3840), + (3872), + (3904), + (3936), + (3968), + (4000), + (4032), + (4064), + (4096) + ); +} + +impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T> +where + T: RsaMode, +{ + /// Creates an Instance of `RsaModularExponentiation`. + /// `m_prime` could be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`, for more information check 19.3.1 in the + /// + pub fn new( + rsa: &'a mut Rsa<'d>, + exponent: &T::InputType, + modulus: &T::InputType, + m_prime: u32, + ) -> Self { + Self::set_mode(rsa); + unsafe { + rsa.write_operand_b(exponent); + rsa.write_modulus(modulus); + } + rsa.write_mprime(m_prime); + if rsa.is_search_enabled() { + rsa.write_search_position(Self::find_search_pos(exponent)); + } + Self { + rsa, + phantom: PhantomData, + } + } + + fn find_search_pos(exponent: &T::InputType) -> u32 { + for (i, byte) in exponent.iter().rev().enumerate() { + if *byte == 0 { + continue; + } + return (exponent.len() * 8) as u32 - (byte.leading_zeros() + i as u32 * 8) - 1; + } + 0 + } + + pub(super) fn set_mode(rsa: &mut Rsa) { + rsa.write_mode((N / 4 - 1) as u32) + } + + pub(super) fn set_start(&mut self) { + self.rsa.write_modexp_start(); + } +} + +impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T> +where + T: RsaMode, +{ + /// Creates an Instance of `RsaModularMultiplication`. + /// `m_prime` could be calculated using `-(modular multiplicative inverse of + /// modulus) mod 2^32`, for more information check 19.3.1 in the + /// + pub fn new( + rsa: &'a mut Rsa<'d>, + operand_a: &T::InputType, + operand_b: &T::InputType, + modulus: &T::InputType, + m_prime: u32, + ) -> Self { + Self::write_mode(rsa); + rsa.write_mprime(m_prime); + unsafe { + rsa.write_modulus(modulus); + rsa.write_operand_a(operand_a); + rsa.write_operand_b(operand_b); + } + Self { + rsa, + phantom: PhantomData, + } + } + + fn write_mode(rsa: &mut Rsa) { + rsa.write_mode((N / 4 - 1) as u32) + } + + /// Starts the modular multiplication operation. `r` could be calculated + /// using `2 ^ ( bitlength * 2 ) mod modulus`, for more information + /// check 19.3.1 in the + pub fn start_modular_multiplication(&mut self, r: &T::InputType) { + unsafe { + self.rsa.write_r(r); + } + self.set_start(); + } + + fn set_start(&mut self) { + self.rsa.write_modmulti_start(); + } +} + +impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T> +where + T: RsaMode, +{ + /// Creates an Instance of `RsaMultiplication`. + pub fn new(rsa: &'a mut Rsa<'d>, operand_a: &T::InputType) -> Self { + Self::set_mode(rsa); + unsafe { + rsa.write_operand_a(operand_a); + } + Self { + rsa, + phantom: PhantomData, + } + } + + /// Starts the multiplication operation. + pub fn start_multiplication(&mut self, operand_b: &T::InputType) { + unsafe { + self.rsa.write_multi_operand_b(operand_b); + } + self.set_start(); + } + + pub(super) fn set_mode(rsa: &mut Rsa) { + rsa.write_mode((N / 2 - 1) as u32) + } + + pub(super) fn set_start(&mut self) { + self.rsa.write_multi_start(); + } +} diff --git a/esp-hal-common/src/rsa/mod.rs b/esp-hal-common/src/rsa/mod.rs new file mode 100644 index 00000000000..b0267373835 --- /dev/null +++ b/esp-hal-common/src/rsa/mod.rs @@ -0,0 +1,238 @@ +//! RSA Accelerator support. +//! +//! This module provides functions and structs for multi precision arithmetic +//! operations used in RSA asym-metric cipher algorithms +//! +//! ### Features +//! The RSA peripheral supports maximum operand of the following sizes for each +//! individual chips: +//! +//! | Feature | ESP32| ESP32-C3| ESP32-C6| ESP32-S2| ESP32-S3| +//! |---------------------- |------|---------|---------|---------|---------| +//! |modular exponentiation |4096 |3072 |3072 |4096 |4096 | +//! |modular multiplication |4096 |3072 |3072 |4096 |4096 | +//! |multiplication |2048 |1536 |1536 |2048 |2048 | + +use core::{convert::Infallible, marker::PhantomData, ptr::copy_nonoverlapping}; + +use crate::{ + peripheral::{Peripheral, PeripheralRef}, + peripherals::{ + generic::{Reg, RegisterSpec, Resettable, Writable}, + RSA, + }, + system::{Peripheral as PeripheralEnable, PeripheralClockControl}, +}; + +#[cfg_attr(esp32s2, path = "esp32sX.rs")] +#[cfg_attr(esp32s3, path = "esp32sX.rs")] +#[cfg_attr(esp32c3, path = "esp32cX.rs")] +#[cfg_attr(esp32c6, path = "esp32cX.rs")] +#[cfg_attr(esp32, path = "esp32.rs")] +mod rsa_spec_impl; + +pub use rsa_spec_impl::operand_sizes; + +/// RSA peripheral container +pub struct Rsa<'d> { + rsa: PeripheralRef<'d, RSA>, +} + +impl<'d> Rsa<'d> { + pub fn new( + rsa: impl Peripheral

+ 'd, + peripheral_clock_control: &mut PeripheralClockControl, + ) -> Self { + crate::into_ref!(rsa); + let mut ret = Self { rsa }; + ret.init(peripheral_clock_control); + ret + } + + fn init(&mut self, peripheral_clock_control: &mut PeripheralClockControl) { + peripheral_clock_control.enable(PeripheralEnable::Rsa); + } + + unsafe fn write_operand_b(&mut self, operand_b: &[u8; N]) { + copy_nonoverlapping( + operand_b.as_ptr(), + self.rsa.y_mem.as_mut_ptr() as *mut u8, + N, + ); + } + + unsafe fn write_modulus(&mut self, modulus: &[u8; N]) { + copy_nonoverlapping(modulus.as_ptr(), self.rsa.m_mem.as_mut_ptr() as *mut u8, N); + } + + fn write_mprime(&mut self, m_prime: u32) { + Self::write_to_register(&mut self.rsa.m_prime, m_prime); + } + + unsafe fn write_operand_a(&mut self, operand_a: &[u8; N]) { + copy_nonoverlapping( + operand_a.as_ptr(), + self.rsa.x_mem.as_mut_ptr() as *mut u8, + N, + ); + } + + unsafe fn write_r(&mut self, r: &[u8; N]) { + copy_nonoverlapping(r.as_ptr(), self.rsa.z_mem.as_mut_ptr() as *mut u8, N); + } + + unsafe fn read_out(&mut self, outbuf: &mut [u8; N]) { + copy_nonoverlapping(self.rsa.z_mem.as_ptr() as *const u8, outbuf.as_mut_ptr(), N); + } + + fn write_to_register(reg: &mut Reg, data: u32) + where + T: RegisterSpec + Resettable + Writable, + { + reg.write(|w| unsafe { w.bits(data) }); + } +} + +mod sealed { + pub trait RsaMode { + type InputType; + } + pub trait Multi: RsaMode { + type OutputType; + } +} + +pub(self) use sealed::*; + +macro_rules! implement_op { + + (($x:literal, multi)) => { + paste! {pub struct [];} + paste! { + impl Multi for [] { + type OutputType = [u8; $x*2 / 8]; + }} + paste!{ + impl RsaMode for [] { + type InputType = [u8; $x / 8]; + }} + }; + + (($x:literal)) => { + paste! {pub struct [];} + paste!{ + impl RsaMode for [] { + type InputType = [u8; $x / 8]; + }} + }; + + ($x:tt, $($y:tt),+) => { + implement_op!($x); + implement_op!($($y),+); + }; +} + +pub(self) use implement_op; + +/// Support for RSA peripheral's modular exponentiation feature that could be +/// used to find the `(base ^ exponent) mod modulus`. +/// +/// Each operand is a little endian byte array of the same size +pub struct RsaModularExponentiation<'a, 'd, T: RsaMode> { + rsa: &'a mut Rsa<'d>, + phantom: PhantomData, +} + +impl<'a, 'd, T: RsaMode, const N: usize> RsaModularExponentiation<'a, 'd, T> +where + T: RsaMode, +{ + /// starts the modular exponentiation operation. `r` could be calculated + /// using `2 ^ ( bitlength * 2 ) mod modulus`, for more information + /// check 24.3.2 in the + pub fn start_exponentiation(&mut self, base: &T::InputType, r: &T::InputType) { + unsafe { + self.rsa.write_operand_a(base); + self.rsa.write_r(r); + } + self.set_start(); + } + + /// reads the result to the given buffer. + /// This is a non blocking function that returns without an error if + /// operation is completed successfully. `start_exponentiation` must be + /// called before calling this function. + pub fn read_results(&mut self, outbuf: &mut T::InputType) -> nb::Result<(), Infallible> { + if !self.rsa.is_idle() { + return Err(nb::Error::WouldBlock); + } + unsafe { + self.rsa.read_out(outbuf); + } + self.rsa.clear_interrupt(); + Ok(()) + } +} + +/// Support for RSA peripheral's modular multiplication feature that could be +/// used to find the `(operand a * operand b) mod modulus`. +/// +/// Each operand is a little endian byte array of the same size +pub struct RsaModularMultiplication<'a, 'd, T: RsaMode> { + rsa: &'a mut Rsa<'d>, + phantom: PhantomData, +} + +impl<'a, 'd, T: RsaMode, const N: usize> RsaModularMultiplication<'a, 'd, T> +where + T: RsaMode, +{ + /// Reads the result to the given buffer. + /// This is a non blocking function that returns without an error if + /// operation is completed successfully. + pub fn read_results(&mut self, outbuf: &mut T::InputType) -> nb::Result<(), Infallible> { + if !self.rsa.is_idle() { + return Err(nb::Error::WouldBlock); + } + unsafe { + self.rsa.read_out(outbuf); + } + self.rsa.clear_interrupt(); + Ok(()) + } +} + +/// Support for RSA peripheral's large number multiplication feature that could +/// be used to find the `operand a * operand b`. +/// +/// Each operand is a little endian byte array of the same size +pub struct RsaMultiplication<'a, 'd, T: RsaMode + Multi> { + rsa: &'a mut Rsa<'d>, + phantom: PhantomData, +} + +impl<'a, 'd, T: RsaMode + Multi, const N: usize> RsaMultiplication<'a, 'd, T> +where + T: RsaMode, +{ + /// Reads the result to the given buffer. + /// This is a non blocking function that returns without an error if + /// operation is completed successfully. `start_multiplication` must be + /// called before calling this function. + pub fn read_results<'b, const O: usize>( + &mut self, + outbuf: &mut T::OutputType, + ) -> nb::Result<(), Infallible> + where + T: Multi, + { + if !self.rsa.is_idle() { + return Err(nb::Error::WouldBlock); + } + unsafe { + self.rsa.read_out(outbuf); + } + self.rsa.clear_interrupt(); + Ok(()) + } +} diff --git a/esp-hal-common/src/system.rs b/esp-hal-common/src/system.rs old mode 100644 new mode 100755 index 7b4908b970f..c688457b4e5 --- a/esp-hal-common/src/system.rs +++ b/esp-hal-common/src/system.rs @@ -67,6 +67,8 @@ pub enum Peripheral { Uart1, #[cfg(uart2)] Uart2, + #[cfg(rsa)] + Rsa, } /// Controls the enablement of peripheral clocks. @@ -257,6 +259,17 @@ impl PeripheralClockControl { perip_clk_en0.modify(|_, w| w.uart2_clk_en().set_bit()); perip_rst_en0.modify(|_, w| w.uart2_rst().clear_bit()); } + #[cfg(esp32)] + Peripheral::Rsa => { + peri_clk_en.modify(|r, w| unsafe { w.bits(r.bits() | 1 << 2) }); + peri_rst_en.modify(|r, w| unsafe { w.bits(r.bits() & !(1 << 2)) }); + } + #[cfg(any(esp32c3, esp32s2, esp32s3))] + Peripheral::Rsa => { + perip_clk_en1.modify(|_, w| w.crypto_rsa_clk_en().set_bit()); + perip_rst_en1.modify(|_, w| w.crypto_rsa_rst().clear_bit()); + system.rsa_pd_ctrl.modify(|_, w| w.rsa_mem_pd().clear_bit()); + } } } } @@ -398,6 +411,11 @@ impl PeripheralClockControl { .uart1_conf .modify(|_, w| w.uart1_rst_en().clear_bit()); } + Peripheral::Rsa => { + system.rsa_conf.modify(|_, w| w.rsa_clk_en().set_bit()); + system.rsa_conf.modify(|_, w| w.rsa_rst_en().clear_bit()); + system.rsa_pd_ctrl.modify(|_, w| w.rsa_mem_pd().clear_bit()); + } } } } diff --git a/esp32-hal/Cargo.toml b/esp32-hal/Cargo.toml index 5d8908f5047..3c717fcf0a3 100644 --- a/esp32-hal/Cargo.toml +++ b/esp32-hal/Cargo.toml @@ -44,6 +44,7 @@ sha2 = { version = "0.10.6", default-features = false} smart-leds = "0.3.0" ssd1306 = "0.7.1" static_cell = "1.0.0" +crypto-bigint = {version = "0.5.0-pre.3",default-features = false} [features] default = ["rt", "vectored", "xtal40mhz"] diff --git a/esp32-hal/examples/rsa.rs b/esp32-hal/examples/rsa.rs new file mode 100644 index 00000000000..e8b132af645 --- /dev/null +++ b/esp32-hal/examples/rsa.rs @@ -0,0 +1,174 @@ +#![no_std] +#![no_main] + +use crypto_bigint::{ + modular::runtime_mod::{DynResidue, DynResidueParams}, + Encoding, + Uint, + U1024, + U512, +}; +use esp32_hal::{ + clock::ClockControl, + peripherals::Peripherals, + prelude::*, + rsa::{ + operand_sizes, + Rsa, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + timer::TimerGroup, + xtensa_lx, + Rtc, +}; +use esp_backtrace as _; +use esp_println::println; +use nb::block; + +const BIGNUM_1: U512 = Uint::from_be_hex( + "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\ +7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878", +); +const BIGNUM_2: U512 = Uint::from_be_hex( + "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\ +6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51", +); +const BIGNUM_3: U512 = Uint::from_be_hex( + "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\ +67aeafa9ba25feb8712f188cb139b7d9b9af1c361", +); + +const fn compute_r(modulus: &U512) -> U512 { + let mut d = [0_u32; U512::LIMBS * 2 + 1]; + d[d.len() - 1] = 1; + let d = Uint::from_words(d); + d.const_rem(&modulus.resize()).0.resize() +} + +const fn compute_mprime(modulus: &U512) -> u32 { + let m_inv = modulus.inv_mod2k(32).to_words()[0]; + (-1 * m_inv as i64 % 4294967296) as u32 +} + +#[entry] +fn main() -> ! { + let peripherals = Peripherals::take(); + let mut system = peripherals.DPORT.split(); + let clocks = ClockControl::boot_defaults(system.clock_control).freeze(); + + // Disable the RTC and TIMG watchdog timers + let mut rtc = Rtc::new(peripherals.RTC_CNTL); + let timer_group0 = TimerGroup::new( + peripherals.TIMG0, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt0 = timer_group0.wdt; + let timer_group1 = TimerGroup::new( + peripherals.TIMG1, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt1 = timer_group1.wdt; + + rtc.rwdt.disable(); + wdt0.disable(); + wdt1.disable(); + + let rsa = peripherals.RSA; + let mut rsa = Rsa::new(rsa, &mut system.peripheral_clock_control); + block!(rsa.ready()).unwrap(); + mod_exp_example(&mut rsa); + mod_multi_example(&mut rsa); + multiplication_example(&mut rsa); + loop {} +} + +fn mod_multi_example(rsa: &mut Rsa) { + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_multi = RsaModularMultiplication::::new( + rsa, + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let pre_hw_modmul = xtensa_lx::timer::get_cycle_count(); + mod_multi.start_step1(&BIGNUM_1.to_le_bytes(), &r); + block!(mod_multi.start_step2(&BIGNUM_2.to_le_bytes())).unwrap(); + block!(mod_multi.read_results(&mut outbuf)).unwrap(); + let post_hw_modmul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw modular multiplication", + post_hw_modmul - pre_hw_modmul + ); + + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue_num1 = DynResidue::new(&BIGNUM_1, residue_params); + let residue_num2 = DynResidue::new(&BIGNUM_2, residue_params); + let pre_sw_exp = xtensa_lx::timer::get_cycle_count(); + let sw_out = residue_num1.mul(&residue_num2); + let post_sw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw modular multiplication", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular multiplication done"); +} + +fn mod_exp_example(rsa: &mut Rsa) { + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_exp = RsaModularExponentiation::::new( + rsa, + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let base = &BIGNUM_1.to_le_bytes(); + let pre_hw_exp = xtensa_lx::timer::get_cycle_count(); + mod_exp.start_exponentiation(base, &r); + block!(mod_exp.read_results(&mut outbuf)).unwrap(); + let post_hw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw modular exponentiation", + post_hw_exp - pre_hw_exp + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue = DynResidue::new(&BIGNUM_1, residue_params); + let pre_sw_exp = xtensa_lx::timer::get_cycle_count(); + let sw_out = residue.pow(&BIGNUM_2); + let post_sw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw modular exponentiation", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular exponentiation done"); +} + +fn multiplication_example(rsa: &mut Rsa) { + let mut out = [0_u8; U1024::BYTES]; + let mut rsamulti = RsaMultiplication::::new(rsa); + let operand_a = &BIGNUM_1.to_le_bytes(); + let operand_b = &BIGNUM_2.to_le_bytes(); + let pre_hw_mul = xtensa_lx::timer::get_cycle_count(); + rsamulti.start_multiplication(&operand_a, &operand_b); + block!(rsamulti.read_results(&mut out)).unwrap(); + let post_hw_mul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw multiplication", + post_hw_mul - pre_hw_mul + ); + let pre_sw_mul = xtensa_lx::timer::get_cycle_count(); + let sw_out = BIGNUM_1.mul_wide(&BIGNUM_2); + let post_sw_mul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw multiplication", + post_sw_mul - pre_sw_mul + ); + assert_eq!(U1024::from_le_bytes(out), sw_out.1.concat(&sw_out.0)); + println!("multiplication done"); +} diff --git a/esp32c3-hal/Cargo.toml b/esp32c3-hal/Cargo.toml index 78f36b1d43a..6ef107aca1f 100644 --- a/esp32c3-hal/Cargo.toml +++ b/esp32c3-hal/Cargo.toml @@ -46,6 +46,7 @@ sha2 = { version = "0.10.6", default-features = false} smart-leds = "0.3.0" ssd1306 = "0.7.1" static_cell = "1.0.0" +crypto-bigint = {version = "0.5.0-pre.3",default-features = false} [features] default = ["rt", "vectored", "esp-hal-common/rv-zero-rtc-bss"] diff --git a/esp32c3-hal/examples/rsa.rs b/esp32c3-hal/examples/rsa.rs new file mode 100644 index 00000000000..0a27e6470cd --- /dev/null +++ b/esp32c3-hal/examples/rsa.rs @@ -0,0 +1,177 @@ +#![no_std] +#![no_main] + +use crypto_bigint::{ + modular::runtime_mod::{DynResidue, DynResidueParams}, + Encoding, + Uint, + U1024, + U512, +}; +use esp32c3_hal::{ + clock::ClockControl, + peripherals::Peripherals, + prelude::*, + rsa::{ + operand_sizes, + Rsa, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + systimer::SystemTimer, + timer::TimerGroup, + Rtc, +}; +use esp_backtrace as _; +use esp_println::println; +use nb::block; + +const BIGNUM_1: U512 = Uint::from_be_hex( + "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\ +7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878", +); +const BIGNUM_2: U512 = Uint::from_be_hex( + "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\ +6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51", +); +const BIGNUM_3: U512 = Uint::from_be_hex( + "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\ +67aeafa9ba25feb8712f188cb139b7d9b9af1c361", +); + +const fn compute_r(modulus: &U512) -> U512 { + let mut d = [0_u32; U512::LIMBS * 2 + 1]; + d[d.len() - 1] = 1; + let d = Uint::from_words(d); + d.const_rem(&modulus.resize()).0.resize() +} + +const fn compute_mprime(modulus: &U512) -> u32 { + let m_inv = modulus.inv_mod2k(32).to_words()[0]; + (-1 * m_inv as i64 % 4294967296) as u32 +} + +#[entry] +fn main() -> ! { + let peripherals = Peripherals::take(); + let mut system = peripherals.SYSTEM.split(); + let clocks = ClockControl::boot_defaults(system.clock_control).freeze(); + + // Disable the RTC and TIMG watchdog timers + let mut rtc = Rtc::new(peripherals.RTC_CNTL); + let timer_group0 = TimerGroup::new( + peripherals.TIMG0, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt0 = timer_group0.wdt; + let timer_group1 = TimerGroup::new( + peripherals.TIMG1, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt1 = timer_group1.wdt; + + rtc.swd.disable(); + rtc.rwdt.disable(); + wdt0.disable(); + wdt1.disable(); + + let mut rsa = Rsa::new(peripherals.RSA, &mut system.peripheral_clock_control); + + block!(rsa.ready()).unwrap(); + mod_exp_example(&mut rsa); + mod_multi_example(&mut rsa); + multiplication_example(&mut rsa); + loop {} +} + +fn mod_multi_example(rsa: &mut Rsa) { + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_multi = RsaModularMultiplication::::new( + rsa, + &BIGNUM_1.to_le_bytes(), + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let pre_hw_modmul = SystemTimer::now(); + mod_multi.start_modular_multiplication(&r); + block!(mod_multi.read_results(&mut outbuf)).unwrap(); + let post_hw_modmul = SystemTimer::now(); + println!( + "it took {} cycles for hw modular multiplication", + post_hw_modmul - pre_hw_modmul + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue_num1 = DynResidue::new(&BIGNUM_1, residue_params); + let residue_num2 = DynResidue::new(&BIGNUM_2, residue_params); + let pre_sw_exp = SystemTimer::now(); + let sw_out = residue_num1.mul(&residue_num2); + let post_sw_exp = SystemTimer::now(); + println!( + "it took {} cycles for sw modular multiplication", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular multiplication done"); +} + +fn mod_exp_example(rsa: &mut Rsa) { + rsa.enable_disable_constant_time_acceleration(true); + rsa.enable_disable_search_acceleration(true); + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_exp = RsaModularExponentiation::::new( + rsa, + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let base = &BIGNUM_1.to_le_bytes(); + let pre_hw_exp = SystemTimer::now(); + mod_exp.start_exponentiation(&base, &r); + block!(mod_exp.read_results(&mut outbuf)).unwrap(); + let post_hw_exp = SystemTimer::now(); + println!( + "it took {} cycles for hw modular exponentiation", + post_hw_exp - pre_hw_exp + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue = DynResidue::new(&BIGNUM_1, residue_params); + let pre_sw_exp = SystemTimer::now(); + let sw_out = residue.pow(&BIGNUM_2); + let post_sw_exp = SystemTimer::now(); + println!( + "it took {} cycles for sw modular exponentiation", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular exponentiation done"); +} + +fn multiplication_example(rsa: &mut Rsa) { + let mut out = [0_u8; U1024::BYTES]; + let operand_a = &BIGNUM_1.to_le_bytes(); + let operand_b = &BIGNUM_2.to_le_bytes(); + let mut rsamulti = RsaMultiplication::::new(rsa, &operand_a); + let pre_hw_mul = SystemTimer::now(); + rsamulti.start_multiplication(&operand_b); + block!(rsamulti.read_results(&mut out)).unwrap(); + let post_hw_mul = SystemTimer::now(); + println!( + "it took {} cycles for hw multiplication", + post_hw_mul - pre_hw_mul + ); + let pre_sw_mul = SystemTimer::now(); + let sw_out = BIGNUM_1.mul_wide(&BIGNUM_2); + let post_sw_mul = SystemTimer::now(); + println!( + "it took {} cycles for sw multiplication", + post_sw_mul - pre_sw_mul + ); + assert_eq!(U1024::from_le_bytes(out), sw_out.1.concat(&sw_out.0)); + println!("multiplication done"); +} diff --git a/esp32c6-hal/Cargo.toml b/esp32c6-hal/Cargo.toml index 2fba5cf93d5..78319d02412 100644 --- a/esp32c6-hal/Cargo.toml +++ b/esp32c6-hal/Cargo.toml @@ -47,6 +47,8 @@ sha2 = { version = "0.10.6", default-features = false} smart-leds = "0.3.0" ssd1306 = "0.7.1" static_cell = "1.0.0" +crypto-bigint = {version = "0.5.0-pre.3",default-features = false} + [features] default = ["rt", "vectored", "esp-hal-common/rv-zero-rtc-bss"] diff --git a/esp32c6-hal/examples/rsa.rs b/esp32c6-hal/examples/rsa.rs new file mode 100644 index 00000000000..bf3f38ffae9 --- /dev/null +++ b/esp32c6-hal/examples/rsa.rs @@ -0,0 +1,185 @@ +#![no_std] +#![no_main] + +use crypto_bigint::{ + modular::runtime_mod::{DynResidue, DynResidueParams}, + Encoding, + Uint, + U1024, + U512, +}; +use esp32c6_hal::{ + clock::ClockControl, + peripherals::Peripherals, + prelude::*, + rsa::{ + operand_sizes, + Rsa, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + systimer::SystemTimer, + timer::TimerGroup, + Rtc, +}; +use esp_backtrace as _; +use esp_println::println; +use nb::block; + +const BIGNUM_1: U512 = Uint::from_be_hex( + "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\ +7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878", +); +const BIGNUM_2: U512 = Uint::from_be_hex( + "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\ +6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51", +); +const BIGNUM_3: U512 = Uint::from_be_hex( + "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\ +67aeafa9ba25feb8712f188cb139b7d9b9af1c361", +); + +const fn compute_r(modulus: &U512) -> U512 { + let mut d = [0_u32; U512::LIMBS * 2 + 1]; + d[d.len() - 1] = 1; + let d = Uint::from_words(d); + d.const_rem(&modulus.resize()).0.resize() +} + +const fn compute_mprime(modulus: &U512) -> u32 { + let m_inv = modulus.inv_mod2k(32).to_words()[0]; + (-1 * m_inv as i64 % 4294967296) as u32 +} + +#[entry] +fn main() -> ! { + let peripherals = Peripherals::take(); + let mut system = peripherals.PCR.split(); + let clocks = ClockControl::boot_defaults(system.clock_control).freeze(); + + // Disable the watchdog timers. For the ESP32-C6, this includes the Super WDT, + // and the TIMG WDTs. + // Disable the watchdog timers. For the ESP32-C6, this includes the Super WDT, + // and the TIMG WDTs. + let mut rtc = Rtc::new(peripherals.LP_CLKRST); + let timer_group0 = TimerGroup::new( + peripherals.TIMG0, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt0 = timer_group0.wdt; + let timer_group1 = TimerGroup::new( + peripherals.TIMG1, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt1 = timer_group1.wdt; + + rtc.swd.disable(); + rtc.rwdt.disable(); + wdt0.disable(); + wdt1.disable(); + + rtc.swd.disable(); + rtc.rwdt.disable(); + wdt0.disable(); + wdt1.disable(); + + let mut rsa = Rsa::new(peripherals.RSA, &mut system.peripheral_clock_control); + + block!(rsa.ready()).unwrap(); + mod_exp_example(&mut rsa); + mod_multi_example(&mut rsa); + multiplication_example(&mut rsa); + loop {} +} + +fn mod_multi_example(rsa: &mut Rsa) { + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_multi = RsaModularMultiplication::::new( + rsa, + &BIGNUM_1.to_le_bytes(), + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let pre_hw_modmul = SystemTimer::now(); + mod_multi.start_modular_multiplication(&r); + block!(mod_multi.read_results(&mut outbuf)).unwrap(); + let post_hw_modmul = SystemTimer::now(); + println!( + "it took {} cycles for hw modular multiplication", + post_hw_modmul - pre_hw_modmul + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue_num1 = DynResidue::new(&BIGNUM_1, residue_params); + let residue_num2 = DynResidue::new(&BIGNUM_2, residue_params); + let pre_sw_exp = SystemTimer::now(); + let sw_out = residue_num1.mul(&residue_num2); + let post_sw_exp = SystemTimer::now(); + println!( + "it took {} cycles for sw modular multiplication", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular multiplication done"); +} + +fn mod_exp_example(rsa: &mut Rsa) { + rsa.enable_disable_constant_time_acceleration(true); + rsa.enable_disable_search_acceleration(true); + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_exp = RsaModularExponentiation::::new( + rsa, + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let base = &BIGNUM_1.to_le_bytes(); + let pre_hw_exp = SystemTimer::now(); + mod_exp.start_exponentiation(&base, &r); + block!(mod_exp.read_results(&mut outbuf)).unwrap(); + let post_hw_exp = SystemTimer::now(); + println!( + "it took {} cycles for hw modular exponentiation", + post_hw_exp - pre_hw_exp + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue = DynResidue::new(&BIGNUM_1, residue_params); + let pre_sw_exp = SystemTimer::now(); + let sw_out = residue.pow(&BIGNUM_2); + let post_sw_exp = SystemTimer::now(); + println!( + "it took {} cycles for sw modular exponentiation", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular exponentiation done"); +} + +fn multiplication_example(rsa: &mut Rsa) { + let mut out = [0_u8; U1024::BYTES]; + let operand_a = &BIGNUM_1.to_le_bytes(); + let operand_b = &BIGNUM_2.to_le_bytes(); + let mut rsamulti = RsaMultiplication::::new(rsa, &operand_a); + let pre_hw_mul = SystemTimer::now(); + rsamulti.start_multiplication(&operand_b); + block!(rsamulti.read_results(&mut out)).unwrap(); + let post_hw_mul = SystemTimer::now(); + println!( + "it took {} cycles for hw multiplication", + post_hw_mul - pre_hw_mul + ); + let pre_sw_mul = SystemTimer::now(); + let sw_out = BIGNUM_1.mul_wide(&BIGNUM_2); + let post_sw_mul = SystemTimer::now(); + println!( + "it took {} cycles for sw multiplication", + post_sw_mul - pre_sw_mul + ); + assert_eq!(U1024::from_le_bytes(out), sw_out.1.concat(&sw_out.0)); + println!("multiplication done"); +} diff --git a/esp32s2-hal/Cargo.toml b/esp32s2-hal/Cargo.toml index 8840c6ed2a1..d6cabad0503 100644 --- a/esp32s2-hal/Cargo.toml +++ b/esp32s2-hal/Cargo.toml @@ -47,6 +47,7 @@ ssd1306 = "0.7.1" static_cell = "1.0.0" usb-device = { version = "0.2.9" } usbd-serial = "0.1.1" +crypto-bigint = {version = "0.5.0-pre.3",default-features = false} [features] default = ["rt", "vectored"] diff --git a/esp32s2-hal/examples/rsa.rs b/esp32s2-hal/examples/rsa.rs new file mode 100644 index 00000000000..b37b51deace --- /dev/null +++ b/esp32s2-hal/examples/rsa.rs @@ -0,0 +1,176 @@ +#![no_std] +#![no_main] + +use crypto_bigint::{ + modular::runtime_mod::{DynResidue, DynResidueParams}, + Encoding, + Uint, + U1024, + U512, +}; +use esp32s2_hal::{ + clock::ClockControl, + peripherals::Peripherals, + prelude::*, + rsa::{ + operand_sizes, + Rsa, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + timer::TimerGroup, + xtensa_lx, + Rtc, +}; +use esp_backtrace as _; +use esp_println::println; +use nb::block; + +const BIGNUM_1: U512 = Uint::from_be_hex( + "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\ + 7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878", +); +const BIGNUM_2: U512 = Uint::from_be_hex( + "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\ + 6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51", +); +const BIGNUM_3: U512 = Uint::from_be_hex( + "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\ + 67aeafa9ba25feb8712f188cb139b7d9b9af1c361", +); + +const fn compute_r(modulus: &U512) -> U512 { + let mut d = [0_u32; U512::LIMBS * 2 + 1]; + d[d.len() - 1] = 1; + let d = Uint::from_words(d); + d.const_rem(&modulus.resize()).0.resize() +} + +const fn compute_mprime(modulus: &U512) -> u32 { + let m_inv = modulus.inv_mod2k(32).to_words()[0]; + (-1 * m_inv as i64 % 4294967296) as u32 +} + +#[entry] +fn main() -> ! { + let peripherals = Peripherals::take(); + let mut system = peripherals.SYSTEM.split(); + let clocks = ClockControl::boot_defaults(system.clock_control).freeze(); + + // Disable the RTC and TIMG watchdog timers + let mut rtc = Rtc::new(peripherals.RTC_CNTL); + let timer_group0 = TimerGroup::new( + peripherals.TIMG0, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt0 = timer_group0.wdt; + let timer_group1 = TimerGroup::new( + peripherals.TIMG1, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt1 = timer_group1.wdt; + + rtc.rwdt.disable(); + wdt0.disable(); + wdt1.disable(); + let mut rsa = Rsa::new(peripherals.RSA, &mut system.peripheral_clock_control); + + block!(rsa.ready()).unwrap(); + mod_exp_example(&mut rsa); + mod_multi_example(&mut rsa); + multiplication_example(&mut rsa); + + loop {} +} + +fn mod_multi_example(rsa: &mut Rsa) { + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_multi = RsaModularMultiplication::::new( + rsa, + &BIGNUM_1.to_le_bytes(), + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let pre_hw_modmul = xtensa_lx::timer::get_cycle_count(); + mod_multi.start_modular_multiplication(&r); + block!(mod_multi.read_results(&mut outbuf)).unwrap(); + let post_hw_modmul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw modular multiplication", + post_hw_modmul - pre_hw_modmul + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue_num1 = DynResidue::new(&BIGNUM_1, residue_params); + let residue_num2 = DynResidue::new(&BIGNUM_2, residue_params); + let pre_sw_exp = xtensa_lx::timer::get_cycle_count(); + let sw_out = residue_num1.mul(&residue_num2); + let post_sw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw modular multiplication", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular multiplication done"); +} + +fn mod_exp_example(rsa: &mut Rsa) { + rsa.enable_disable_constant_time_acceleration(true); + rsa.enable_disable_search_acceleration(true); + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_exp = RsaModularExponentiation::::new( + rsa, + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let base = &BIGNUM_1.to_le_bytes(); + let pre_hw_exp = xtensa_lx::timer::get_cycle_count(); + mod_exp.start_exponentiation(&base, &r); + block!(mod_exp.read_results(&mut outbuf)).unwrap(); + let post_hw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw modular exponentiation", + post_hw_exp - pre_hw_exp + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue = DynResidue::new(&BIGNUM_1, residue_params); + let pre_sw_exp = xtensa_lx::timer::get_cycle_count(); + let sw_out = residue.pow(&BIGNUM_2); + let post_sw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw modular exponentiation", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular exponentiation done"); +} + +fn multiplication_example(rsa: &mut Rsa) { + let mut out = [0_u8; U1024::BYTES]; + let operand_a = &BIGNUM_1.to_le_bytes(); + let operand_b = &BIGNUM_2.to_le_bytes(); + let mut rsamulti = RsaMultiplication::::new(rsa, &operand_a); + let pre_hw_mul = xtensa_lx::timer::get_cycle_count(); + rsamulti.start_multiplication(&operand_b); + block!(rsamulti.read_results(&mut out)).unwrap(); + let post_hw_mul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw multiplication", + post_hw_mul - pre_hw_mul + ); + let pre_sw_mul = xtensa_lx::timer::get_cycle_count(); + let sw_out = BIGNUM_1.mul_wide(&BIGNUM_2); + let post_sw_mul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw multiplication", + post_sw_mul - pre_sw_mul + ); + assert_eq!(U1024::from_le_bytes(out), sw_out.1.concat(&sw_out.0)); + println!("multiplication done"); +} diff --git a/esp32s3-hal/Cargo.toml b/esp32s3-hal/Cargo.toml index d62499258b4..acacf53dc75 100644 --- a/esp32s3-hal/Cargo.toml +++ b/esp32s3-hal/Cargo.toml @@ -49,6 +49,7 @@ ssd1306 = "0.7.1" static_cell = "1.0.0" usb-device = { version = "0.2.9" } usbd-serial = "0.1.1" +crypto-bigint = {version = "0.5.0-pre.3",default-features = false} [features] default = ["rt", "vectored"] diff --git a/esp32s3-hal/examples/rsa.rs b/esp32s3-hal/examples/rsa.rs new file mode 100644 index 00000000000..d74ad8203ba --- /dev/null +++ b/esp32s3-hal/examples/rsa.rs @@ -0,0 +1,176 @@ +#![no_std] +#![no_main] + +use crypto_bigint::{ + modular::runtime_mod::{DynResidue, DynResidueParams}, + Encoding, + Uint, + U1024, + U512, +}; +use esp32s3_hal::{ + clock::ClockControl, + peripherals::Peripherals, + prelude::*, + rsa::{ + operand_sizes, + Rsa, + RsaModularExponentiation, + RsaModularMultiplication, + RsaMultiplication, + }, + timer::TimerGroup, + xtensa_lx, + Rtc, +}; +use esp_backtrace as _; +use esp_println::println; +use nb::block; + +const BIGNUM_1: U512 = Uint::from_be_hex( + "c7f61058f96db3bd87dbab08ab03b4f7f2f864eac249144adea6a65f97803b719d8ca980b7b3c0389c1c7c6\ + 7dc353c5e0ec11f5fc8ce7f6073796cc8f73fa878", +); +const BIGNUM_2: U512 = Uint::from_be_hex( + "1763db3344e97be15d04de4868badb12a38046bb793f7630d87cf100aa1c759afac15a01f3c4c83ec2d2f66\ + 6bd22f71c3c1f075ec0e2cb0cb29994d091b73f51", +); +const BIGNUM_3: U512 = Uint::from_be_hex( + "6b6bb3d2b6cbeb45a769eaa0384e611e1b89b0c9b45a045aca1c5fd6e8785b38df7118cf5dd45b9b63d293b\ + 67aeafa9ba25feb8712f188cb139b7d9b9af1c361", +); + +const fn compute_r(modulus: &U512) -> U512 { + let mut d = [0_u32; U512::LIMBS * 2 + 1]; + d[d.len() - 1] = 1; + let d = Uint::from_words(d); + d.const_rem(&modulus.resize()).0.resize() +} + +const fn compute_mprime(modulus: &U512) -> u32 { + let m_inv = modulus.inv_mod2k(32).to_words()[0]; + (-1 * m_inv as i64 % 4294967296) as u32 +} + +#[entry] +fn main() -> ! { + let peripherals = Peripherals::take(); + let mut system = peripherals.SYSTEM.split(); + let clocks = ClockControl::boot_defaults(system.clock_control).freeze(); + + // Disable the RTC and TIMG watchdog timers + let mut rtc = Rtc::new(peripherals.RTC_CNTL); + let timer_group0 = TimerGroup::new( + peripherals.TIMG0, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt0 = timer_group0.wdt; + let timer_group1 = TimerGroup::new( + peripherals.TIMG1, + &clocks, + &mut system.peripheral_clock_control, + ); + let mut wdt1 = timer_group1.wdt; + + rtc.rwdt.disable(); + wdt0.disable(); + wdt1.disable(); + let mut rsa = Rsa::new(peripherals.RSA, &mut system.peripheral_clock_control); + + block!(rsa.ready()).unwrap(); + mod_exp_example(&mut rsa); + mod_multi_example(&mut rsa); + multiplication_example(&mut rsa); + + loop {} +} + +fn mod_multi_example(rsa: &mut Rsa) { + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_multi = RsaModularMultiplication::::new( + rsa, + &BIGNUM_1.to_le_bytes(), + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let pre_hw_modmul = xtensa_lx::timer::get_cycle_count(); + mod_multi.start_modular_multiplication(&r); + block!(mod_multi.read_results(&mut outbuf)).unwrap(); + let post_hw_modmul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw modular multiplication", + post_hw_modmul - pre_hw_modmul + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue_num1 = DynResidue::new(&BIGNUM_1, residue_params); + let residue_num2 = DynResidue::new(&BIGNUM_2, residue_params); + let pre_sw_exp = xtensa_lx::timer::get_cycle_count(); + let sw_out = residue_num1.mul(&residue_num2); + let post_sw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw modular multiplication", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular multiplication done"); +} + +fn mod_exp_example(rsa: &mut Rsa) { + rsa.enable_disable_constant_time_acceleration(true); + rsa.enable_disable_search_acceleration(true); + let mut outbuf = [0_u8; U512::BYTES]; + let mut mod_exp = RsaModularExponentiation::::new( + rsa, + &BIGNUM_2.to_le_bytes(), + &BIGNUM_3.to_le_bytes(), + compute_mprime(&BIGNUM_3), + ); + let r = compute_r(&BIGNUM_3).to_le_bytes(); + let base = &BIGNUM_1.to_le_bytes(); + let pre_hw_exp = xtensa_lx::timer::get_cycle_count(); + mod_exp.start_exponentiation(&base, &r); + block!(mod_exp.read_results(&mut outbuf)).unwrap(); + let post_hw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw modular exponentiation", + post_hw_exp - pre_hw_exp + ); + let residue_params = DynResidueParams::new(&BIGNUM_3); + let residue = DynResidue::new(&BIGNUM_1, residue_params); + let pre_sw_exp = xtensa_lx::timer::get_cycle_count(); + let sw_out = residue.pow(&BIGNUM_2); + let post_sw_exp = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw modular exponentiation", + post_sw_exp - pre_sw_exp + ); + assert_eq!(U512::from_le_bytes(outbuf), sw_out.retrieve()); + println!("modular exponentiation done"); +} + +fn multiplication_example(rsa: &mut Rsa) { + let mut out = [0_u8; U1024::BYTES]; + let operand_a = &BIGNUM_1.to_le_bytes(); + let operand_b = &BIGNUM_2.to_le_bytes(); + let mut rsamulti = RsaMultiplication::::new(rsa, &operand_a); + let pre_hw_mul = xtensa_lx::timer::get_cycle_count(); + rsamulti.start_multiplication(&operand_b); + block!(rsamulti.read_results(&mut out)).unwrap(); + let post_hw_mul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for hw multiplication", + post_hw_mul - pre_hw_mul + ); + let pre_sw_mul = xtensa_lx::timer::get_cycle_count(); + let sw_out = BIGNUM_1.mul_wide(&BIGNUM_2); + let post_sw_mul = xtensa_lx::timer::get_cycle_count(); + println!( + "it took {} cycles for sw multiplication", + post_sw_mul - pre_sw_mul + ); + assert_eq!(U1024::from_le_bytes(out), sw_out.1.concat(&sw_out.0)); + println!("multiplication done"); +}