From 5f61ce35553834569d277274c0e4490bd0426935 Mon Sep 17 00:00:00 2001 From: Bogdan Opanchuk Date: Tue, 9 Aug 2022 16:59:49 -0700 Subject: [PATCH] Get rid of eager computation in mul_shift_vartime() in places where it is shift-dependent It is expected to be vartime on the value of `shift`, and calculating some of the paths eagerly when they are not required leads to panics. --- k256/src/arithmetic/scalar/wide32.rs | 176 ++++++++++++++------------- k256/src/arithmetic/scalar/wide64.rs | 78 ++++++------ 2 files changed, 129 insertions(+), 125 deletions(-) diff --git a/k256/src/arithmetic/scalar/wide32.rs b/k256/src/arithmetic/scalar/wide32.rs index 838860ae..4af8cdc1 100644 --- a/k256/src/arithmetic/scalar/wide32.rs +++ b/k256/src/arithmetic/scalar/wide32.rs @@ -132,96 +132,100 @@ impl WideScalar { pub(crate) fn mul_shift_vartime(a: &Scalar, b: &Scalar, shift: usize) -> Scalar { debug_assert!(shift >= 256); - fn ifelse(c: bool, x: u32, y: u32) -> u32 { - if c { - x - } else { - y - } - } - let l = Self::mul_wide(a, b).0.to_words(); let shiftlimbs = shift >> 5; let shiftlow = shift & 0x1F; let shifthigh = 32 - shiftlow; - let r0 = ifelse( - shift < 512, - (l[shiftlimbs] >> shiftlow) - | ifelse( - shift < 480 && shiftlow != 0, - l[1 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); - - let r1 = ifelse( - shift < 480, - (l[1 + shiftlimbs] >> shiftlow) - | ifelse( - shift < 448 && shiftlow != 0, - l[2 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); - - let r2 = ifelse( - shift < 448, - (l[2 + shiftlimbs] >> shiftlow) - | ifelse( - shift < 416 && shiftlow != 0, - l[3 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); - - let r3 = ifelse( - shift < 416, - (l[3 + shiftlimbs] >> shiftlow) - | ifelse( - shift < 384 && shiftlow != 0, - l[4 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); - - let r4 = ifelse( - shift < 384, - (l[4 + shiftlimbs] >> shiftlow) - | ifelse( - shift < 352 && shiftlow != 0, - l[5 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); - - let r5 = ifelse( - shift < 352, - (l[5 + shiftlimbs] >> shiftlow) - | ifelse( - shift < 320 && shiftlow != 0, - l[6 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); - - let r6 = ifelse( - shift < 320, - (l[6 + shiftlimbs] >> shiftlow) - | ifelse( - shift < 288 && shiftlow != 0, - l[7 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); - - let r7 = ifelse(shift < 288, l[7 + shiftlimbs] >> shiftlow, 0); + + let r0 = if shift < 512 { + let lo = l[shiftlimbs] >> shiftlow; + let hi = if shift < 480 && shiftlow != 0 { + l[1 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; + + let r1 = if shift < 480 { + let lo = l[1 + shiftlimbs] >> shiftlow; + let hi = if shift < 448 && shiftlow != 0 { + l[2 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; + + let r2 = if shift < 448 { + let lo = l[2 + shiftlimbs] >> shiftlow; + let hi = if shift < 416 && shiftlow != 0 { + l[3 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; + + let r3 = if shift < 416 { + let lo = l[3 + shiftlimbs] >> shiftlow; + let hi = if shift < 384 && shiftlow != 0 { + l[4 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; + + let r4 = if shift < 384 { + let lo = l[4 + shiftlimbs] >> shiftlow; + let hi = if shift < 352 && shiftlow != 0 { + l[5 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; + + let r5 = if shift < 352 { + let lo = l[5 + shiftlimbs] >> shiftlow; + let hi = if shift < 320 && shiftlow != 0 { + l[6 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; + + let r6 = if shift < 320 { + let lo = l[6 + shiftlimbs] >> shiftlow; + let hi = if shift < 288 && shiftlow != 0 { + l[7 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; + + let r7 = if shift < 288 { + l[7 + shiftlimbs] >> shiftlow + } else { + 0 + }; let res = Scalar(U256::from_words([r0, r1, r2, r3, r4, r5, r6, r7])); diff --git a/k256/src/arithmetic/scalar/wide64.rs b/k256/src/arithmetic/scalar/wide64.rs index ef78bff7..d88032a7 100644 --- a/k256/src/arithmetic/scalar/wide64.rs +++ b/k256/src/arithmetic/scalar/wide64.rs @@ -64,52 +64,52 @@ impl WideScalar { pub(crate) fn mul_shift_vartime(a: &Scalar, b: &Scalar, shift: usize) -> Scalar { debug_assert!(shift >= 256); - fn ifelse(c: bool, x: u64, y: u64) -> u64 { - if c { - x - } else { - y - } - } - let l = Self::mul_wide(a, b).0.to_words(); let shiftlimbs = shift >> 6; let shiftlow = shift & 0x3F; let shifthigh = 64 - shiftlow; - let r0 = ifelse( - shift < 512, - (l[shiftlimbs] >> shiftlow) - | ifelse( - shift < 448 && shiftlow != 0, - l[1 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); - let r1 = ifelse( - shift < 448, - (l[1 + shiftlimbs] >> shiftlow) - | ifelse( - shift < 384 && shiftlow != 0, - l[2 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); + let r0 = if shift < 512 { + let lo = l[shiftlimbs] >> shiftlow; + let hi = if shift < 448 && shiftlow != 0 { + l[1 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; + + let r1 = if shift < 448 { + let lo = l[1 + shiftlimbs] >> shiftlow; + let hi = if shift < 384 && shiftlow != 0 { + l[2 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; - let r2 = ifelse( - shift < 384, - (l[2 + shiftlimbs] >> shiftlow) - | ifelse( - shift < 320 && shiftlow != 0, - l[3 + shiftlimbs] << shifthigh, - 0, - ), - 0, - ); + let r2 = if shift < 384 { + let lo = l[2 + shiftlimbs] >> shiftlow; + let hi = if shift < 320 && shiftlow != 0 { + l[3 + shiftlimbs] << shifthigh + } else { + 0 + }; + hi | lo + } else { + 0 + }; - let r3 = ifelse(shift < 320, l[3 + shiftlimbs] >> shiftlow, 0); + let r3 = if shift < 320 { + l[3 + shiftlimbs] >> shiftlow + } else { + 0 + }; let res = Scalar(U256::from_words([r0, r1, r2, r3]));