diff --git a/cmov/src/lib.rs b/cmov/src/lib.rs index 0491fdfa..48b682cd 100644 --- a/cmov/src/lib.rs +++ b/cmov/src/lib.rs @@ -134,3 +134,25 @@ impl CmovEq for u128 { tmp.cmoveq(&1, input, output); } } + +impl CmovEq for [T] { + fn cmoveq(&self, rhs: &Self, input: Condition, output: &mut Condition) { + let mut tmp = 1u8; + self.cmovne(rhs, 0u8, &mut tmp); + tmp.cmoveq(&1, input, output); + } + + fn cmovne(&self, rhs: &Self, input: Condition, output: &mut Condition) { + // Short-circuit the comparison if the slices are of different lengths, and set the output + // condition to the input condition. + if self.len() != rhs.len() { + *output = input; + return; + } + + // Compare each byte. + for (a, b) in self.iter().zip(rhs.iter()) { + a.cmovne(b, input, output); + } + } +} diff --git a/cmov/tests/lib.rs b/cmov/tests/lib.rs index 294b5828..13c0f284 100644 --- a/cmov/tests/lib.rs +++ b/cmov/tests/lib.rs @@ -332,3 +332,41 @@ mod u128 { assert_eq!(o, 55u8); } } + +mod slices { + use cmov::CmovEq; + + #[test] + fn cmoveq_works() { + let mut o = 0u8; + + // Same slices. + [1u8, 2, 3].cmoveq(&[1, 2, 3], 43, &mut o); + assert_eq!(o, 43); + + // Different lengths. + [1u8, 2, 3].cmoveq(&[1, 2], 44, &mut o); + assert_ne!(o, 44); + + // Different contents. + [1u8, 2, 3].cmoveq(&[1, 2, 4], 45, &mut o); + assert_ne!(o, 45); + } + + #[test] + fn cmovne_works() { + let mut o = 0u8; + + // Same slices. + [1u8, 2, 3].cmovne(&[1, 2, 3], 43, &mut o); + assert_ne!(o, 43); + + // Different lengths. + [1u8, 2, 3].cmovne(&[1, 2], 44, &mut o); + assert_eq!(o, 44); + + // Different contents. + [1u8, 2, 3].cmovne(&[1, 2, 4], 45, &mut o); + assert_eq!(o, 45); + } +}