Skip to content

Commit

Permalink
[release/8.0-staging] Manually depad RSAES-PKCS1 on Apple OSes
Browse files Browse the repository at this point in the history
Co-authored-by: Jeremy Barton <jbarton@microsoft.com>
  • Loading branch information
github-actions[bot] and bartonjs authored Feb 14, 2024
1 parent fe5e36a commit 91b2946
Show file tree
Hide file tree
Showing 6 changed files with 319 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
Expand Down Expand Up @@ -69,8 +70,8 @@ private static partial int RsaDecryptOaep(
out SafeCFDataHandle pEncryptedOut,
out SafeCFErrorHandle pErrorOut);

[LibraryImport(Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_RsaDecryptPkcs")]
private static partial int RsaDecryptPkcs(
[LibraryImport(Libraries.AppleCryptoNative, EntryPoint = "AppleCryptoNative_RsaDecryptRaw")]
private static partial int RsaDecryptRaw(
SafeSecKeyRefHandle publicKey,
ReadOnlySpan<byte> pbData,
int cbData,
Expand Down Expand Up @@ -166,17 +167,40 @@ internal static byte[] RsaDecrypt(
byte[] data,
RSAEncryptionPadding padding)
{
if (padding == RSAEncryptionPadding.Pkcs1)
{
byte[] padded = ExecuteTransform(
data,
(ReadOnlySpan<byte> source, out SafeCFDataHandle decrypted, out SafeCFErrorHandle error) =>
RsaDecryptRaw(privateKey, source, source.Length, out decrypted, out error));

byte[] depad = CryptoPool.Rent(padded.Length);
OperationStatus status = RsaPaddingProcessor.DepadPkcs1Encryption(padded, depad, out int written);
byte[]? ret = null;

if (status == OperationStatus.Done)
{
ret = depad.AsSpan(0, written).ToArray();
}

// Clear the whole thing, especially on failure.
CryptoPool.Return(depad);
CryptographicOperations.ZeroMemory(padded);

if (ret is null)
{
throw new CryptographicException(SR.Cryptography_InvalidPadding);
}

return ret;
}

Debug.Assert(padding.Mode == RSAEncryptionPaddingMode.Oaep);

return ExecuteTransform(
data,
(ReadOnlySpan<byte> source, out SafeCFDataHandle decrypted, out SafeCFErrorHandle error) =>
{
if (padding == RSAEncryptionPadding.Pkcs1)
{
return RsaDecryptPkcs(privateKey, source, source.Length, out decrypted, out error);
}

Debug.Assert(padding.Mode == RSAEncryptionPaddingMode.Oaep);

return RsaDecryptOaep(
privateKey,
source,
Expand All @@ -195,14 +219,63 @@ internal static bool TryRsaDecrypt(
out int bytesWritten)
{
Debug.Assert(padding.Mode == RSAEncryptionPaddingMode.Pkcs1 || padding.Mode == RSAEncryptionPaddingMode.Oaep);

if (padding.Mode == RSAEncryptionPaddingMode.Pkcs1)
{
byte[] padded = CryptoPool.Rent(source.Length);
byte[] depad = CryptoPool.Rent(source.Length);

bool processed = TryExecuteTransform(
source,
padded,
out int paddedLength,
(ReadOnlySpan<byte> innerSource, out SafeCFDataHandle outputHandle, out SafeCFErrorHandle errorHandle) =>
RsaDecryptRaw(privateKey, innerSource, innerSource.Length, out outputHandle, out errorHandle));

Debug.Assert(
processed,
"TryExecuteTransform should always return true for a large enough buffer.");

OperationStatus status = OperationStatus.InvalidData;
int depaddedLength = 0;

if (processed)
{
status = RsaPaddingProcessor.DepadPkcs1Encryption(
new ReadOnlySpan<byte>(padded, 0, paddedLength),
depad,
out depaddedLength);
}

CryptoPool.Return(padded);

if (status == OperationStatus.Done)
{
if (depaddedLength <= destination.Length)
{
depad.AsSpan(0, depaddedLength).CopyTo(destination);
CryptoPool.Return(depad);
bytesWritten = depaddedLength;
return true;
}

CryptoPool.Return(depad);
bytesWritten = 0;
return false;
}

CryptoPool.Return(depad);
Debug.Assert(status == OperationStatus.InvalidData);
throw new CryptographicException(SR.Cryptography_InvalidPadding);
}

return TryExecuteTransform(
source,
destination,
out bytesWritten,
delegate (ReadOnlySpan<byte> innerSource, out SafeCFDataHandle outputHandle, out SafeCFErrorHandle errorHandle)
{
return padding.Mode == RSAEncryptionPaddingMode.Pkcs1 ?
RsaDecryptPkcs(privateKey, innerSource, innerSource.Length, out outputHandle, out errorHandle) :
return
RsaDecryptOaep(privateKey, innerSource, innerSource.Length, PalAlgorithmFromAlgorithmName(padding.OaepHashAlgorithm), out outputHandle, out errorHandle);
});
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Buffers;
using System.Buffers.Binary;
using System.Collections.Concurrent;
using System.Diagnostics;
Expand Down Expand Up @@ -142,6 +143,109 @@ internal static void PadPkcs1Encryption(
source.CopyTo(mInEM);
}

internal static OperationStatus DepadPkcs1Encryption(
ReadOnlySpan<byte> source,
Span<byte> destination,
out int bytesWritten)
{
int primitive = DepadPkcs1Encryption(source);
int primitiveSign = SignStretch(primitive);

// Primitive is a positive length, or ~length to indicate
// an error, so flip ~length to length if the high bit is set.
int len = Choose(primitiveSign, ~primitive, primitive);
int spaceRemain = destination.Length - len;
int spaceRemainSign = SignStretch(spaceRemain);

// len = clampHigh(len, destination.Length);
len = Choose(spaceRemainSign, destination.Length, len);

// ret = spaceRemain < 0 ? DestinationTooSmall : Done
int ret = Choose(
spaceRemainSign,
(int)OperationStatus.DestinationTooSmall,
(int)OperationStatus.Done);

// ret = primitive < 0 ? InvalidData : ret;
ret = Choose(primitiveSign, (int)OperationStatus.InvalidData, ret);

// Write some number of bytes, regardless of the final return.
source[^len..].CopyTo(destination);

// bytesWritten = ret == Done ? len : 0;
bytesWritten = Choose(CheckZero(ret), len, 0);
return (OperationStatus)ret;
}

private static int DepadPkcs1Encryption(ReadOnlySpan<byte> source)
{
Debug.Assert(source.Length > 11);
ReadOnlySpan<byte> afterPadding = source.Slice(10);
ReadOnlySpan<byte> noZeros = source.Slice(2, 8);

// Find the first zero in noZeros, or -1 for no zeros.
int zeroPos = BlindFindFirstZero(noZeros);

// If zeroPos is negative, valid is -1, otherwise 0.
int valid = SignStretch(zeroPos);

// If there are no zeros in afterPadding then zeroPos is negative,
// so negating the sign stretch is 0, which makes hasPos 0.
// If there -was- a zero, sign stretching is 0, so negating it makes hasPos -1.
zeroPos = BlindFindFirstZero(afterPadding);
int hasLen = ~SignStretch(zeroPos);
valid &= hasLen;

// Check that the first two bytes are { 00 02 }
valid &= CheckZero(source[0] | (source[1] ^ 0x02));

int lenIfGood = afterPadding.Length - zeroPos - 1;
// If there were no zeros, use the full after-min-padding segment.
int lenIfBad = ~Choose(hasLen, lenIfGood, source.Length - 11);

Debug.Assert(lenIfBad < 0);
return Choose(valid, lenIfGood, lenIfBad);
}

private static int BlindFindFirstZero(ReadOnlySpan<byte> source)
{
// Any vectorization of this routine needs to use non-early termination,
// and instructions that do not vary their completion time on the input.

int pos = -1;

for (int i = source.Length - 1; i >= 0; i--)
{
// pos = source[i] == 0 ? i : pos;
int local = CheckZero(source[i]);
pos = Choose(local, i, pos);
}

return pos;
}

private static int SignStretch(int value)
{
return value >> 31;
}

private static int Choose(int selector, int yes, int no)
{
Debug.Assert((selector | (selector - 1)) == -1);
return (selector & yes) | (~selector & no);
}

private static int CheckZero(int value)
{
// For zero, ~value and value-1 are both all bits set (negative).
// For positive values, ~value is negative and value-1 is positive.
// For negative values except MinValue, ~value is positive and value-1 is negative.
// For MinValue, ~value is positive and value-1 is also positive.
// All together, the only thing that has negative & negative is 0, so stretch the sign bit.
int mask = ~value & (value - 1);
return SignStretch(mask);
}

internal static void PadPkcs1Signature(
HashAlgorithmName hashAlgorithmName,
ReadOnlySpan<byte> source,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Diagnostics;
using System.Numerics;
using Test.Cryptography;
using Microsoft.DotNet.XUnitExtensions;
using Xunit;
Expand Down Expand Up @@ -736,6 +738,119 @@ public void Decrypt_Pkcs1_ErrorsForInvalidPadding(byte[] data)
}
}

[Fact]
public void Decrypt_Pkcs1_BadPadding()
{
if ((PlatformDetection.IsWindows && !PlatformDetection.IsWindows10Version2004OrGreater))
{
return;
}

RSAParameters keyParams = TestData.RSA2048Params;
BigInteger e = new BigInteger(keyParams.Exponent, true, true);
BigInteger n = new BigInteger(keyParams.Modulus, true, true);
byte[] buf = new byte[keyParams.Modulus.Length];
byte[] c = new byte[buf.Length];

buf[1] = 2;
buf.AsSpan(2).Fill(1);

ref byte afterMinPadding = ref buf[10];
ref byte lastByte = ref buf[^1];
afterMinPadding = 0;

using (RSA rsa = RSAFactory.Create(keyParams))
{
RawEncrypt(buf, e, n, c);
// Assert.NoThrow, check that manual padding is coherent
Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1);

// All RSA encryption schemes start with 00, so pick any other number.
//
// If buf > modulus then encrypt should fail, so this
// is the largest legal-but-invalid value to test.
buf[0] = keyParams.Modulus[0];
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Check again with a zero length payload
(afterMinPadding, lastByte) = (lastByte, afterMinPadding);
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Back to valid padding
buf[0] = 0;
(afterMinPadding, lastByte) = (lastByte, afterMinPadding);
RawEncrypt(buf, e, n, c);
Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1);

// This is (sort of) legal for PKCS1 signatures, but not decryption.
buf[1] = 1;
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// No RSA PKCS1 padding scheme starts with 00 FF.
buf[1] = 255;
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Check again with a zero length payload
(afterMinPadding, lastByte) = (lastByte, afterMinPadding);
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Back to valid padding
buf[1] = 2;
(afterMinPadding, lastByte) = (lastByte, afterMinPadding);
RawEncrypt(buf, e, n, c);
Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1);

// Try a zero in every possible required padding position
for (int i = 2; i < 10; i++)
{
buf[i] = 0;

RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// It used to be 1, now it's 2, still not zero.
buf[i] = 2;
}

// Back to valid padding
RawEncrypt(buf, e, n, c);
Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1);

// Make it such that
// "there is no octet with hexadecimal value 0x00 to separate PS from M"
// (RFC 3447 sec 7.2.2, rule 3, third clause)
buf.AsSpan(10).Fill(3);
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));

// Every possible problem, for good measure.
buf[0] = 2;
buf[1] = 0;
buf[4] = 0;
RawEncrypt(buf, e, n, c);
Assert.ThrowsAny<CryptographicException>(() => Decrypt(rsa, c, RSAEncryptionPadding.Pkcs1));
}

static void RawEncrypt(ReadOnlySpan<byte> source, BigInteger e, BigInteger n, Span<byte> destination)
{
BigInteger m = new BigInteger(source, true, true);
BigInteger c = BigInteger.ModPow(m, e, n);
int shift = destination.Length - c.GetByteCount(true);
destination.Slice(0, shift).Clear();
bool wrote = c.TryWriteBytes(destination.Slice(shift), out int written, true, true);

if (!wrote || written + shift != destination.Length)
{
throw new UnreachableException();
}
}
}

public static IEnumerable<object[]> OaepPaddingModes
{
get
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ static const Entry s_cryptoAppleNative[] =
DllImportEntry(AppleCryptoNative_RsaGenerateKey)
DllImportEntry(AppleCryptoNative_RsaDecryptOaep)
DllImportEntry(AppleCryptoNative_RsaDecryptPkcs)
DllImportEntry(AppleCryptoNative_RsaDecryptRaw)
DllImportEntry(AppleCryptoNative_RsaEncryptOaep)
DllImportEntry(AppleCryptoNative_RsaEncryptPkcs)
DllImportEntry(AppleCryptoNative_RsaSignaturePrimitive)
Expand Down
Loading

0 comments on commit 91b2946

Please sign in to comment.