Skip to content

Commit

Permalink
Remove linq (#1949)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wraith2 authored Mar 24, 2023
1 parent 0c0ae37 commit 23ca8e7
Show file tree
Hide file tree
Showing 25 changed files with 499 additions and 224 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using System.Text;
using Azure.Core;
using Azure.Security.KeyVault.Keys.Cryptography;
Expand Down Expand Up @@ -229,14 +228,18 @@ byte[] DecryptEncryptionKey()
}

// Get ciphertext
byte[] cipherText = encryptedColumnEncryptionKey.Skip(currentIndex).Take(cipherTextLength).ToArray();
byte[] cipherText = new byte[cipherTextLength];
Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength);

currentIndex += cipherTextLength;

// Get signature
byte[] signature = encryptedColumnEncryptionKey.Skip(currentIndex).Take(signatureLength).ToArray();
byte[] signature = new byte[signatureLength];
Buffer.BlockCopy(encryptedColumnEncryptionKey, currentIndex, signature, 0, signatureLength);

// Compute the message to validate the signature
byte[] message = encryptedColumnEncryptionKey.Take(encryptedColumnEncryptionKey.Length - signatureLength).ToArray();
byte[] message = new byte[encryptedColumnEncryptionKey.Length - signatureLength];
Buffer.BlockCopy(encryptedColumnEncryptionKey, 0, message, 0, encryptedColumnEncryptionKey.Length - signatureLength);

if (null == message)
{
Expand Down Expand Up @@ -294,7 +297,24 @@ public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string e

// Compute message
// SHA-2-256(version + keyPathLength + ciphertextLength + keyPath + ciphertext)
byte[] message = s_firstVersion.Concat(keyPathLength).Concat(cipherTextLength).Concat(masterKeyPathBytes).Concat(cipherText).ToArray();
int messageLength = s_firstVersion.Length + keyPathLength.Length + cipherTextLength.Length + masterKeyPathBytes.Length + cipherText.Length;
byte[] message = new byte[messageLength];
int position = 0;

Buffer.BlockCopy(s_firstVersion, 0, message, position, s_firstVersion.Length);
position += s_firstVersion.Length;

Buffer.BlockCopy(keyPathLength, 0, message, position, keyPathLength.Length);
position += keyPathLength.Length;

Buffer.BlockCopy(cipherTextLength, 0, message, position, cipherTextLength.Length);
position += cipherTextLength.Length;

Buffer.BlockCopy(masterKeyPathBytes, 0, message, position, masterKeyPathBytes.Length);
position += masterKeyPathBytes.Length;

Buffer.BlockCopy(cipherText, 0, message, position, cipherText.Length);
position += cipherText.Length;

// Sign the message
byte[] signature = KeyCryptographer.SignData(message, masterKeyPath);
Expand All @@ -306,7 +326,11 @@ public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string e

ValidateSignature(masterKeyPath, message, signature);

return message.Concat(signature).ToArray();
byte[] retval = new byte[message.Length + signature.Length];
Buffer.BlockCopy(message, 0, retval, 0, message.Length);
Buffer.BlockCopy(signature, 0, retval, message.Length, signature.Length);

return retval;
}

#endregion
Expand Down Expand Up @@ -345,7 +369,7 @@ internal void ValidateNonEmptyAKVPath(string masterKeyPath, bool isSystemOp)

// Return an error indicating that the AKV url is invalid.
AKVEventSource.Log.TryTraceEvent("Master Key Path could not be validated as it does not end with trusted endpoints: {0}", masterKeyPath);
throw ADP.InvalidAKVUrlTrustedEndpoints(masterKeyPath, string.Join(", ", TrustedEndPoints.ToArray()));
throw ADP.InvalidAKVUrlTrustedEndpoints(masterKeyPath, string.Join(", ", TrustedEndPoints));
}

private void ValidateSignature(string masterKeyPath, byte[] message, byte[] signature)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using System.Collections;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Security.Cryptography;

namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider
Expand All @@ -15,7 +14,7 @@ internal static class Validator
{
internal static void ValidateNotNull(object parameter, string name)
{
if (null == parameter)
if (parameter == null)
{
throw ADP.NullArgument(name);
}
Expand All @@ -31,9 +30,15 @@ internal static void ValidateNotEmpty(IList parameter, string name)

internal static void ValidateNotNullOrWhitespaceForEach(string[] parameters, string name)
{
if (parameters.Any(s => string.IsNullOrWhiteSpace(s)))
if (parameters != null && parameters.Length > 0)
{
throw ADP.NullOrWhitespaceForEach(name);
for (int index = 0; index < parameters.Length; index++)
{
if (string.IsNullOrWhiteSpace(parameters[index]))
{
throw ADP.NullOrWhitespaceForEach(name);
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.Loader;

Expand Down Expand Up @@ -69,7 +69,30 @@ private static Assembly AssemblyResolver(AssemblyName arg)
return fullPath == null ? null : AssemblyLoadContext.Default.LoadFromAssemblyPath(fullPath);
}

private static Type TypeResolver(Assembly arg1, string arg2, bool arg3) => arg1?.ExportedTypes.Single(t => t.FullName == arg2);
private static Type TypeResolver(Assembly arg1, string arg2, bool arg3)
{
IEnumerable<Type> types = arg1?.ExportedTypes;
Type result = null;
if (types != null)
{
foreach (Type type in types)
{
if (type.FullName == arg2)
{
if (result != null)
{
throw new InvalidOperationException("Sequence contains more than one matching element");
}
result = type;
}
}
}
if (result == null)
{
throw new InvalidOperationException("Sequence contains no matching element");
}
return result;
}

/// <summary>
/// Load assemblies on request.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Linq;

namespace Microsoft.Data.SqlClient.SNI
{
Expand Down Expand Up @@ -83,12 +83,16 @@ public override void ReturnPacket(SNIPacket packet)
#if DEBUG
private string GetStackParts()
{
return string.Join(Environment.NewLine,
Environment.StackTrace
.Split(new string[] { Environment.NewLine }, StringSplitOptions.None)
.Skip(3) // trims off the common parts at the top of the stack so you can see what the actual caller was
.Take(7) // trims off most of the bottom of the stack because when running under xunit there's a lot of spam
);
// trims off the common parts at the top of the stack so you can see what the actual caller was
// trims off most of the bottom of the stack because when running under xunit there's a lot of spam
string[] parts = Environment.StackTrace.Split(new string[] { Environment.NewLine }, StringSplitOptions.None);
List<string> take = new List<string>(7);
for (int index = 3; take.Count < 7 && index < parts.Length; index++)
{
take.Add(parts[index]);
}

return string.Join(Environment.NewLine, take.ToArray());
}
#endif
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
Expand Down Expand Up @@ -192,52 +191,77 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re

IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname);
Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve");

IPAddress[] ipv4Addresses = null;
IPAddress[] ipv6Addresses = null;
switch (ipPreference)
{
case SqlConnectionIPAddressPreference.IPv4First:
{
SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses);

SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
{
return response4.ResponsePacket;
}

SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
{
return response6.ResponsePacket;
}

// No responses so throw first error
if (response4 != null && response4.Error != null)
{
throw response4.Error;
}
else if (response6 != null && response6.Error != null)
{
throw response6.Error;
}

break;
}
case SqlConnectionIPAddressPreference.IPv6First:
{
SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel);
SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses);

SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel);
if (response6 != null && response6.ResponsePacket != null)
{
return response6.ResponsePacket;
}

SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel);
SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel);
if (response4 != null && response4.ResponsePacket != null)
{
return response4.ResponsePacket;
}

// No responses so throw first error
if (response6 != null && response6.Error != null)
{
throw response6.Error;
}
else if (response4 != null && response4.Error != null)
{
throw response4.Error;
}

break;
}
default:
{
SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel);
if (response != null && response.ResponsePacket != null)
{
return response.ResponsePacket;
}
else if (response != null && response.Error != null)
{
throw response.Error;
}

break;
}
Expand Down Expand Up @@ -372,5 +396,40 @@ internal static string SendBroadcastUDPRequest()
}
return response.ToString();
}

private static void SplitIPv4AndIPv6(IPAddress[] input, out IPAddress[] ipv4Addresses, out IPAddress[] ipv6Addresses)
{
ipv4Addresses = Array.Empty<IPAddress>();
ipv6Addresses = Array.Empty<IPAddress>();

if (input != null && input.Length > 0)
{
List<IPAddress> v4 = new List<IPAddress>(1);
List<IPAddress> v6 = new List<IPAddress>(0);

for (int index = 0; index < input.Length; index++)
{
switch (input[index].AddressFamily)
{
case AddressFamily.InterNetwork:
v4.Add(input[index]);
break;
case AddressFamily.InterNetworkV6:
v6.Add(input[index]);
break;
}
}

if (v4.Count > 0)
{
ipv4Addresses = v4.ToArray();
}

if (v6.Count > 0)
{
ipv6Addresses = v6.ToArray();
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,25 +43,24 @@ static SqlAuthenticationProviderManager()
public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSection configSection = null)
{
var methodName = "Ctor";
_typeName = GetType().Name;
_providers = new ConcurrentDictionary<SqlAuthenticationMethod, SqlAuthenticationProvider>();
var authenticationsWithAppSpecifiedProvider = new HashSet<SqlAuthenticationMethod>();
_authenticationsWithAppSpecifiedProvider = authenticationsWithAppSpecifiedProvider;

if (configSection == null)
{
_sqlAuthLogger.LogInfo(_typeName, methodName, "Neither SqlClientAuthenticationProviders nor SqlAuthenticationProviders configuration section found.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Neither SqlClientAuthenticationProviders nor SqlAuthenticationProviders configuration section found.");
return;
}

if (!string.IsNullOrEmpty(configSection.ApplicationClientId))
{
_applicationClientId = configSection.ApplicationClientId;
_sqlAuthLogger.LogInfo(_typeName, methodName, "Received user-defined Application Client Id");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Received user-defined Application Client Id");
}
else
{
_sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined Application Client Id found.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined Application Client Id found.");
}

// Create user-defined auth initializer, if any.
Expand All @@ -77,11 +76,11 @@ public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSe
{
throw SQL.CannotCreateSqlAuthInitializer(configSection.InitializerType, e);
}
_sqlAuthLogger.LogInfo(_typeName, methodName, "Created user-defined SqlAuthenticationInitializer.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Created user-defined SqlAuthenticationInitializer.");
}
else
{
_sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined SqlAuthenticationInitializer found.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined SqlAuthenticationInitializer found.");
}

// add user-defined providers, if any.
Expand All @@ -107,12 +106,12 @@ public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSe

_providers[authentication] = provider;
authenticationsWithAppSpecifiedProvider.Add(authentication);
_sqlAuthLogger.LogInfo(_typeName, methodName, string.Format("Added user-defined auth provider: {0} for authentication {1}.", providerSettings?.Type, authentication));
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, string.Format("Added user-defined auth provider: {0} for authentication {1}.", providerSettings?.Type, authentication));
}
}
else
{
_sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined auth providers.");
_sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined auth providers.");
}
}

Expand Down
Loading

0 comments on commit 23ca8e7

Please sign in to comment.