Skip to content

Commit

Permalink
aes: refactor ARMv8 expand_key (#367)
Browse files Browse the repository at this point in the history
Changes `expand_key` to an `unsafe fn` that uses `target_feature`.

Removes the TODOs: due to AES-192 this function can't be easily
refactored to use `vinterpretq_u8_u32`.
  • Loading branch information
tarcieri authored Jun 17, 2023
1 parent 8d03900 commit eb309c6
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 18 deletions.
2 changes: 1 addition & 1 deletion aes/src/armv8.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ macro_rules! define_aes_impl {
impl KeyInit for $name_enc {
fn new(key: &Key<Self>) -> Self {
Self {
round_keys: expand_key(key.as_ref()),
round_keys: unsafe { expand_key(key.as_ref()) },
}
}
}
Expand Down
26 changes: 12 additions & 14 deletions aes/src/armv8/expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@ const WORD_SIZE: usize = 4;
/// AES round constants.
const ROUND_CONSTS: [u32; 10] = [0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36];

/// AES key expansion
// TODO(tarcieri): big endian support?
#[inline]
pub(super) fn expand_key<const L: usize, const N: usize>(key: &[u8; L]) -> [uint8x16_t; N] {
/// AES key expansion.
#[target_feature(enable = "aes")]
pub unsafe fn expand_key<const L: usize, const N: usize>(key: &[u8; L]) -> [uint8x16_t; N] {
assert!((L == 16 && N == 11) || (L == 24 && N == 13) || (L == 32 && N == 15));

let mut expanded_keys: [uint8x16_t; N] = unsafe { mem::zeroed() };
let mut expanded_keys: [uint8x16_t; N] = mem::zeroed();

// TODO(tarcieri): construct expanded keys using `vreinterpretq_u8_u32`
let ek_words = unsafe {
slice::from_raw_parts_mut(expanded_keys.as_mut_ptr() as *mut u32, N * BLOCK_WORDS)
};
let columns =
slice::from_raw_parts_mut(expanded_keys.as_mut_ptr() as *mut u32, N * BLOCK_WORDS);

for (i, chunk) in key.chunks_exact(WORD_SIZE).enumerate() {
ek_words[i] = u32::from_ne_bytes(chunk.try_into().unwrap());
columns[i] = u32::from_ne_bytes(chunk.try_into().unwrap());
}

// From "The Rijndael Block Cipher" Section 4.1:
Expand All @@ -38,15 +35,15 @@ pub(super) fn expand_key<const L: usize, const N: usize>(key: &[u8; L]) -> [uint
let nk = L / WORD_SIZE;

for i in nk..(N * BLOCK_WORDS) {
let mut word = ek_words[i - 1];
let mut word = columns[i - 1];

if i % nk == 0 {
word = unsafe { sub_word(word) }.rotate_right(8) ^ ROUND_CONSTS[i / nk - 1];
word = sub_word(word).rotate_right(8) ^ ROUND_CONSTS[i / nk - 1];
} else if nk > 6 && i % nk == 4 {
word = unsafe { sub_word(word) };
word = sub_word(word);
}

ek_words[i] = ek_words[i - nk] ^ word;
columns[i] = columns[i - nk] ^ word;
}

expanded_keys
Expand All @@ -68,6 +65,7 @@ pub(super) unsafe fn inv_expanded_keys<const N: usize>(expanded_keys: &mut [uint
}

/// Sub bytes for a single AES word: used for key expansion.
#[inline]
#[target_feature(enable = "aes")]
unsafe fn sub_word(input: u32) -> u32 {
let input = vreinterpretq_u8_u32(vdupq_n_u32(input));
Expand Down
6 changes: 3 additions & 3 deletions aes/src/armv8/test_expand.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn store_expanded_keys<const N: usize>(input: [uint8x16_t; N]) -> [[u8; 16]; N]

#[test]
fn aes128_key_expansion() {
let ek = expand_key(&AES128_KEY);
let ek = unsafe { expand_key(&AES128_KEY) };
assert_eq!(store_expanded_keys(ek), AES128_EXP_KEYS);
}

Expand All @@ -119,12 +119,12 @@ fn aes128_key_expansion_inv() {

#[test]
fn aes192_key_expansion() {
let ek = expand_key(&AES192_KEY);
let ek = unsafe { expand_key(&AES192_KEY) };
assert_eq!(store_expanded_keys(ek), AES192_EXP_KEYS);
}

#[test]
fn aes256_key_expansion() {
let ek = expand_key(&AES256_KEY);
let ek = unsafe { expand_key(&AES256_KEY) };
assert_eq!(store_expanded_keys(ek), AES256_EXP_KEYS);
}

0 comments on commit eb309c6

Please sign in to comment.