diff --git a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs index 2a7a4267b6..d06063f8c0 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs +++ b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/SqlColumnEncryptionAzureKeyVaultProvider.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.Linq; using System.Text; using Azure.Core; using Azure.Security.KeyVault.Keys.Cryptography; @@ -229,14 +228,18 @@ byte[] DecryptEncryptionKey() } // Get ciphertext - byte[] cipherText = encryptedColumnEncryptionKey.Skip(currentIndex).Take(cipherTextLength).ToArray(); + byte[] cipherText = new byte[cipherTextLength]; + Array.Copy(encryptedColumnEncryptionKey, currentIndex, cipherText, 0, cipherTextLength); + currentIndex += cipherTextLength; // Get signature - byte[] signature = encryptedColumnEncryptionKey.Skip(currentIndex).Take(signatureLength).ToArray(); + byte[] signature = new byte[signatureLength]; + Buffer.BlockCopy(encryptedColumnEncryptionKey, currentIndex, signature, 0, signatureLength); // Compute the message to validate the signature - byte[] message = encryptedColumnEncryptionKey.Take(encryptedColumnEncryptionKey.Length - signatureLength).ToArray(); + byte[] message = new byte[encryptedColumnEncryptionKey.Length - signatureLength]; + Buffer.BlockCopy(encryptedColumnEncryptionKey, 0, message, 0, encryptedColumnEncryptionKey.Length - signatureLength); if (null == message) { @@ -294,7 +297,24 @@ public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string e // Compute message // SHA-2-256(version + keyPathLength + ciphertextLength + keyPath + ciphertext) - byte[] message = s_firstVersion.Concat(keyPathLength).Concat(cipherTextLength).Concat(masterKeyPathBytes).Concat(cipherText).ToArray(); + int messageLength = s_firstVersion.Length + keyPathLength.Length + cipherTextLength.Length + masterKeyPathBytes.Length + cipherText.Length; + byte[] message = new byte[messageLength]; + int position = 0; + + Buffer.BlockCopy(s_firstVersion, 0, message, position, s_firstVersion.Length); + position += s_firstVersion.Length; + + Buffer.BlockCopy(keyPathLength, 0, message, position, keyPathLength.Length); + position += keyPathLength.Length; + + Buffer.BlockCopy(cipherTextLength, 0, message, position, cipherTextLength.Length); + position += cipherTextLength.Length; + + Buffer.BlockCopy(masterKeyPathBytes, 0, message, position, masterKeyPathBytes.Length); + position += masterKeyPathBytes.Length; + + Buffer.BlockCopy(cipherText, 0, message, position, cipherText.Length); + position += cipherText.Length; // Sign the message byte[] signature = KeyCryptographer.SignData(message, masterKeyPath); @@ -306,7 +326,11 @@ public override byte[] EncryptColumnEncryptionKey(string masterKeyPath, string e ValidateSignature(masterKeyPath, message, signature); - return message.Concat(signature).ToArray(); + byte[] retval = new byte[message.Length + signature.Length]; + Buffer.BlockCopy(message, 0, retval, 0, message.Length); + Buffer.BlockCopy(signature, 0, retval, message.Length, signature.Length); + + return retval; } #endregion @@ -345,7 +369,7 @@ internal void ValidateNonEmptyAKVPath(string masterKeyPath, bool isSystemOp) // Return an error indicating that the AKV url is invalid. AKVEventSource.Log.TryTraceEvent("Master Key Path could not be validated as it does not end with trusted endpoints: {0}", masterKeyPath); - throw ADP.InvalidAKVUrlTrustedEndpoints(masterKeyPath, string.Join(", ", TrustedEndPoints.ToArray())); + throw ADP.InvalidAKVUrlTrustedEndpoints(masterKeyPath, string.Join(", ", TrustedEndPoints)); } private void ValidateSignature(string masterKeyPath, byte[] message, byte[] signature) diff --git a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Utils.cs b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Utils.cs index a5d74ed0a7..2909422c1a 100644 --- a/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Utils.cs +++ b/src/Microsoft.Data.SqlClient/add-ons/AzureKeyVaultProvider/Utils.cs @@ -6,7 +6,6 @@ using System.Collections; using System.Collections.Generic; using System.Globalization; -using System.Linq; using System.Security.Cryptography; namespace Microsoft.Data.SqlClient.AlwaysEncrypted.AzureKeyVaultProvider @@ -15,7 +14,7 @@ internal static class Validator { internal static void ValidateNotNull(object parameter, string name) { - if (null == parameter) + if (parameter == null) { throw ADP.NullArgument(name); } @@ -31,9 +30,15 @@ internal static void ValidateNotEmpty(IList parameter, string name) internal static void ValidateNotNullOrWhitespaceForEach(string[] parameters, string name) { - if (parameters.Any(s => string.IsNullOrWhiteSpace(s))) + if (parameters != null && parameters.Length > 0) { - throw ADP.NullOrWhitespaceForEach(name); + for (int index = 0; index < parameters.Length; index++) + { + if (string.IsNullOrWhiteSpace(parameters[index])) + { + throw ADP.NullOrWhitespaceForEach(name); + } + } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryLogicManager.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryLogicManager.NetCoreApp.cs index c7af3c9534..fc0e4c80a3 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryLogicManager.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryLogicManager.NetCoreApp.cs @@ -3,8 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.IO; -using System.Linq; using System.Reflection; using System.Runtime.Loader; @@ -69,7 +69,30 @@ private static Assembly AssemblyResolver(AssemblyName arg) return fullPath == null ? null : AssemblyLoadContext.Default.LoadFromAssemblyPath(fullPath); } - private static Type TypeResolver(Assembly arg1, string arg2, bool arg3) => arg1?.ExportedTypes.Single(t => t.FullName == arg2); + private static Type TypeResolver(Assembly arg1, string arg2, bool arg3) + { + IEnumerable types = arg1?.ExportedTypes; + Type result = null; + if (types != null) + { + foreach (Type type in types) + { + if (type.FullName == arg2) + { + if (result != null) + { + throw new InvalidOperationException("Sequence contains more than one matching element"); + } + result = type; + } + } + } + if (result == null) + { + throw new InvalidOperationException("Sequence contains no matching element"); + } + return result; + } /// /// Load assemblies on request. diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPhysicalHandle.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPhysicalHandle.cs index 94f37d0c6a..3677abe6cd 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPhysicalHandle.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIPhysicalHandle.cs @@ -3,9 +3,9 @@ // See the LICENSE file in the project root for more information. using System; +using System.Collections.Generic; using System.Diagnostics; using System.Threading; -using System.Linq; namespace Microsoft.Data.SqlClient.SNI { @@ -83,12 +83,16 @@ public override void ReturnPacket(SNIPacket packet) #if DEBUG private string GetStackParts() { - return string.Join(Environment.NewLine, - Environment.StackTrace - .Split(new string[] { Environment.NewLine }, StringSplitOptions.None) - .Skip(3) // trims off the common parts at the top of the stack so you can see what the actual caller was - .Take(7) // trims off most of the bottom of the stack because when running under xunit there's a lot of spam - ); + // trims off the common parts at the top of the stack so you can see what the actual caller was + // trims off most of the bottom of the stack because when running under xunit there's a lot of spam + string[] parts = Environment.StackTrace.Split(new string[] { Environment.NewLine }, StringSplitOptions.None); + List take = new List(7); + for (int index = 3; take.Count < 7 && index < parts.Length; index++) + { + take.Add(parts[index]); + } + + return string.Join(Environment.NewLine, take.ToArray()); } #endif } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs index e51175059a..7bef6a6f9e 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SSRP.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Net; using System.Net.Sockets; using System.Text; @@ -192,42 +191,63 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re IPAddress[] ipAddresses = SNICommon.GetDnsIpAddresses(browserHostname); Debug.Assert(ipAddresses.Length > 0, "DNS should throw if zero addresses resolve"); - + IPAddress[] ipv4Addresses = null; + IPAddress[] ipv6Addresses = null; switch (ipPreference) { case SqlConnectionIPAddressPreference.IPv4First: { - SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel); + SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses); + + SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel); if (response4 != null && response4.ResponsePacket != null) + { return response4.ResponsePacket; + } - SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel); + SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel); if (response6 != null && response6.ResponsePacket != null) + { return response6.ResponsePacket; + } // No responses so throw first error if (response4 != null && response4.Error != null) + { throw response4.Error; + } else if (response6 != null && response6.Error != null) + { throw response6.Error; + } break; } case SqlConnectionIPAddressPreference.IPv6First: { - SsrpResult response6 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetworkV6).ToArray(), port, requestPacket, allIPsInParallel); + SplitIPv4AndIPv6(ipAddresses, out ipv4Addresses, out ipv6Addresses); + + SsrpResult response6 = SendUDPRequest(ipv6Addresses, port, requestPacket, allIPsInParallel); if (response6 != null && response6.ResponsePacket != null) + { return response6.ResponsePacket; + } - SsrpResult response4 = SendUDPRequest(ipAddresses.Where(i => i.AddressFamily == AddressFamily.InterNetwork).ToArray(), port, requestPacket, allIPsInParallel); + SsrpResult response4 = SendUDPRequest(ipv4Addresses, port, requestPacket, allIPsInParallel); if (response4 != null && response4.ResponsePacket != null) + { return response4.ResponsePacket; + } // No responses so throw first error if (response6 != null && response6.Error != null) + { throw response6.Error; + } else if (response4 != null && response4.Error != null) + { throw response4.Error; + } break; } @@ -235,9 +255,13 @@ private static byte[] SendUDPRequest(string browserHostname, int port, byte[] re { SsrpResult response = SendUDPRequest(ipAddresses, port, requestPacket, true); // allIPsInParallel); if (response != null && response.ResponsePacket != null) + { return response.ResponsePacket; + } else if (response != null && response.Error != null) + { throw response.Error; + } break; } @@ -372,5 +396,40 @@ internal static string SendBroadcastUDPRequest() } return response.ToString(); } + + private static void SplitIPv4AndIPv6(IPAddress[] input, out IPAddress[] ipv4Addresses, out IPAddress[] ipv6Addresses) + { + ipv4Addresses = Array.Empty(); + ipv6Addresses = Array.Empty(); + + if (input != null && input.Length > 0) + { + List v4 = new List(1); + List v6 = new List(0); + + for (int index = 0; index < input.Length; index++) + { + switch (input[index].AddressFamily) + { + case AddressFamily.InterNetwork: + v4.Add(input[index]); + break; + case AddressFamily.InterNetworkV6: + v6.Add(input[index]); + break; + } + } + + if (v4.Count > 0) + { + ipv4Addresses = v4.ToArray(); + } + + if (v6.Count > 0) + { + ipv6Addresses = v6.ToArray(); + } + } + } } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.NetCoreApp.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.NetCoreApp.cs index b204c8df81..094114f357 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.NetCoreApp.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.NetCoreApp.cs @@ -43,25 +43,24 @@ static SqlAuthenticationProviderManager() public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSection configSection = null) { var methodName = "Ctor"; - _typeName = GetType().Name; _providers = new ConcurrentDictionary(); var authenticationsWithAppSpecifiedProvider = new HashSet(); _authenticationsWithAppSpecifiedProvider = authenticationsWithAppSpecifiedProvider; if (configSection == null) { - _sqlAuthLogger.LogInfo(_typeName, methodName, "Neither SqlClientAuthenticationProviders nor SqlAuthenticationProviders configuration section found."); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Neither SqlClientAuthenticationProviders nor SqlAuthenticationProviders configuration section found."); return; } if (!string.IsNullOrEmpty(configSection.ApplicationClientId)) { _applicationClientId = configSection.ApplicationClientId; - _sqlAuthLogger.LogInfo(_typeName, methodName, "Received user-defined Application Client Id"); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Received user-defined Application Client Id"); } else { - _sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined Application Client Id found."); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined Application Client Id found."); } // Create user-defined auth initializer, if any. @@ -77,11 +76,11 @@ public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSe { throw SQL.CannotCreateSqlAuthInitializer(configSection.InitializerType, e); } - _sqlAuthLogger.LogInfo(_typeName, methodName, "Created user-defined SqlAuthenticationInitializer."); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "Created user-defined SqlAuthenticationInitializer."); } else { - _sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined SqlAuthenticationInitializer found."); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined SqlAuthenticationInitializer found."); } // add user-defined providers, if any. @@ -107,12 +106,12 @@ public SqlAuthenticationProviderManager(SqlAuthenticationProviderConfigurationSe _providers[authentication] = provider; authenticationsWithAppSpecifiedProvider.Add(authentication); - _sqlAuthLogger.LogInfo(_typeName, methodName, string.Format("Added user-defined auth provider: {0} for authentication {1}.", providerSettings?.Type, authentication)); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, string.Format("Added user-defined auth provider: {0} for authentication {1}.", providerSettings?.Type, authentication)); } } else { - _sqlAuthLogger.LogInfo(_typeName, methodName, "No user-defined auth providers."); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, "No user-defined auth providers."); } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs index 4c101d30df..401fc23466 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs @@ -2,10 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Linq; namespace Microsoft.Data.SqlClient { @@ -23,7 +21,6 @@ internal partial class SqlAuthenticationProviderManager private const string ActiveDirectoryMSI = "active directory msi"; private const string ActiveDirectoryDefault = "active directory default"; - private readonly string _typeName; private readonly IReadOnlyCollection _authenticationsWithAppSpecifiedProvider; private readonly ConcurrentDictionary _providers; private readonly SqlClientLogger _sqlAuthLogger = new SqlClientLogger(); @@ -55,10 +52,9 @@ private static void SetDefaultAuthProviders(SqlAuthenticationProviderManager ins /// public SqlAuthenticationProviderManager() { - _typeName = GetType().Name; _providers = new ConcurrentDictionary(); _authenticationsWithAppSpecifiedProvider = new HashSet(); - _sqlAuthLogger.LogInfo(_typeName, "Ctor", "No SqlAuthProviders configuration section found."); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), "Ctor", "No SqlAuthProviders configuration section found."); } /// @@ -85,9 +81,16 @@ public bool SetProvider(SqlAuthenticationMethod authenticationMethod, SqlAuthent throw SQL.UnsupportedAuthenticationByProvider(authenticationMethod.ToString(), provider.GetType().Name); } var methodName = "SetProvider"; - if (_authenticationsWithAppSpecifiedProvider.Contains(authenticationMethod)) + if (_authenticationsWithAppSpecifiedProvider.Count > 0) { - _sqlAuthLogger.LogError(_typeName, methodName, $"Failed to add provider {GetProviderType(provider)} because a user-defined provider with type {GetProviderType(_providers[authenticationMethod])} already existed for authentication {authenticationMethod}."); + foreach (SqlAuthenticationMethod candidateMethod in _authenticationsWithAppSpecifiedProvider) + { + if (candidateMethod == authenticationMethod) + { + _sqlAuthLogger.LogError(nameof(SqlAuthenticationProviderManager), methodName, $"Failed to add provider {GetProviderType(provider)} because a user-defined provider with type {GetProviderType(_providers[authenticationMethod])} already existed for authentication {authenticationMethod}."); + break; + } + } } _providers.AddOrUpdate(authenticationMethod, provider, (key, oldProvider) => { @@ -99,7 +102,7 @@ public bool SetProvider(SqlAuthenticationMethod authenticationMethod, SqlAuthent { provider.BeforeLoad(authenticationMethod); } - _sqlAuthLogger.LogInfo(_typeName, methodName, $"Added auth provider {GetProviderType(provider)}, overriding existed provider {GetProviderType(oldProvider)} for authentication {authenticationMethod}."); + _sqlAuthLogger.LogInfo(nameof(SqlAuthenticationProviderManager), methodName, $"Added auth provider {GetProviderType(provider)}, overriding existed provider {GetProviderType(oldProvider)} for authentication {authenticationMethod}."); return provider; }); return true; diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs index b576ef3510..c3ebf3d4d7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -12,7 +12,6 @@ using System.Data.SqlTypes; using System.Diagnostics; using System.IO; -using System.Linq; using System.Runtime.CompilerServices; using System.Text; using System.Threading; @@ -2966,7 +2965,11 @@ internal bool TryGetColumnEncryptionKeyStoreProvider(string providerName, out Sq /// Combined list of provider names internal List GetColumnEncryptionCustomKeyStoreProvidersNames() { - return _customColumnEncryptionKeyStoreProviders.Keys.ToList(); + if (_customColumnEncryptionKeyStoreProviders.Count > 0) + { + return new List(_customColumnEncryptionKeyStoreProviders.Keys); + } + return new List(0); } // If the user part is quoted, remove first and last brackets and then unquote any right square @@ -3973,7 +3976,8 @@ private SqlDataReader TryFetchInputParameterEncryptionInfo(int timeout, inputParameterEncryptionNeeded = true; } - _sqlRPCParameterEncryptionReqArray = describeParameterEncryptionRpcOriginalRpcMap.Keys.ToArray(); + _sqlRPCParameterEncryptionReqArray = new _SqlRPC[describeParameterEncryptionRpcOriginalRpcMap.Count]; + describeParameterEncryptionRpcOriginalRpcMap.Keys.CopyTo(_sqlRPCParameterEncryptionReqArray, 0); Debug.Assert(_sqlRPCParameterEncryptionReqArray.Length > 0, "There should be at-least 1 describe parameter encryption rpc request."); Debug.Assert(_sqlRPCParameterEncryptionReqArray.Length <= _SqlRPCBatchArray.Length, diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index 81bfc9084e..b101a21ea2 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -6,7 +6,6 @@ using System.Collections; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Collections.ObjectModel; using System.ComponentModel; using System.Data; using System.Data.Common; @@ -14,7 +13,6 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; -using System.Linq; using System.Reflection; using System.Security; using System.Threading; @@ -266,7 +264,11 @@ internal bool TryGetColumnEncryptionKeyStoreProvider(string providerName, out Sq /// Combined list of provider names internal static List GetColumnEncryptionSystemKeyStoreProvidersNames() { - return s_systemColumnEncryptionKeyStoreProviders.Keys.ToList(); + if (s_systemColumnEncryptionKeyStoreProviders.Count > 0) + { + return new List(s_systemColumnEncryptionKeyStoreProviders.Keys); + } + return new List(0); } /// @@ -279,13 +281,13 @@ internal List GetColumnEncryptionCustomKeyStoreProvidersNames() if (_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0) { - return _customColumnEncryptionKeyStoreProviders.Keys.ToList(); + return new List(_customColumnEncryptionKeyStoreProviders.Keys); } if (s_globalCustomColumnEncryptionKeyStoreProviders is not null) { - return s_globalCustomColumnEncryptionKeyStoreProviders.Keys.ToList(); + return new List(s_globalCustomColumnEncryptionKeyStoreProviders.Keys); } - return new List(); + return new List(0); } /// diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs index 4aebe4b518..2e8cd441ce 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -8,7 +8,6 @@ using System.Data; using System.Diagnostics; using System.Globalization; -using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; @@ -1509,7 +1508,15 @@ internal static Exception InvalidKeySize(string algorithmName, int actualKeyleng internal static Exception InvalidEncryptionType(string algorithmName, SqlClientEncryptionType encryptionType, params SqlClientEncryptionType[] validEncryptionTypes) { const string valueSeparator = @", "; - return ADP.Argument(StringsHelper.GetString(Strings.TCE_InvalidEncryptionType, algorithmName, encryptionType.ToString(), string.Join(valueSeparator, validEncryptionTypes.Select((validEncryptionType => @"'" + validEncryptionType + @"'")))), TdsEnums.TCE_PARAM_ENCRYPTIONTYPE); + return ADP.Argument( + StringsHelper.GetString( + Strings.TCE_InvalidEncryptionType, + algorithmName, + encryptionType.ToString(), + string.Join(valueSeparator, Map(validEncryptionTypes, static validEncryptionType => $"'{validEncryptionType:G}'")) + ), + TdsEnums.TCE_PARAM_ENCRYPTIONTYPE + ); } internal static Exception InvalidCipherTextSize(int actualSize, int minimumSize) @@ -1577,8 +1584,8 @@ internal static Exception ColumnMasterKeySignatureVerificationFailed(string cmkP internal static Exception InvalidKeyStoreProviderName(string providerName, List systemProviders, List customProviders) { const string valueSeparator = @", "; - string systemProviderStr = string.Join(valueSeparator, systemProviders.Select(provider => $"'{provider}'")); - string customProviderStr = string.Join(valueSeparator, customProviders.Select(provider => $"'{provider}'")); + string systemProviderStr = string.Join(valueSeparator, Map(systemProviders, static provider => $"'{provider}'")); + string customProviderStr = string.Join(valueSeparator, Map(customProviders, static provider => $"'{provider}'")); return ADP.Argument(StringsHelper.GetString(Strings.TCE_InvalidKeyStoreProviderName, providerName, systemProviderStr, customProviderStr)); } @@ -1797,8 +1804,8 @@ internal static Exception UnsupportedNormalizationVersion(byte version) internal static Exception UnrecognizedKeyStoreProviderName(string providerName, List systemProviders, List customProviders) { const string valueSeparator = @", "; - string systemProviderStr = string.Join(valueSeparator, systemProviders.Select(provider => @"'" + provider + @"'")); - string customProviderStr = string.Join(valueSeparator, customProviders.Select(provider => @"'" + provider + @"'")); + string systemProviderStr = string.Join(valueSeparator, Map(systemProviders, static provider => @"'" + provider + @"'")); + string customProviderStr = string.Join(valueSeparator, Map(customProviders, static provider => @"'" + provider + @"'")); return ADP.Argument(StringsHelper.GetString(Strings.TCE_UnrecognizedKeyStoreProviderName, providerName, systemProviderStr, customProviderStr)); } @@ -1941,6 +1948,24 @@ internal static string GetSNIErrorMessage(int sniError) internal const int SqlDependencyServerTimeout = 5 * 24 * 3600; // 5 days - used to compute default TTL of the dependency internal const string SqlNotificationServiceDefault = "SqlQueryNotificationService"; internal const string SqlNotificationStoredProcedureDefault = "SqlQueryNotificationStoredProcedure"; + + private static IEnumerable Map(IEnumerable source, Func selector) + { + if (source == null) + { + throw new ArgumentNullException(nameof(source)); + } + + if (selector == null) + { + throw new ArgumentNullException(nameof(selector)); + } + + foreach (T element in source) + { + yield return selector(element); + } + } } sealed internal class SQLMessage diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Common/DBConnectionString.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Common/DBConnectionString.cs index 57391584d2..261e91b8dd 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Common/DBConnectionString.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Common/DBConnectionString.cs @@ -6,11 +6,9 @@ namespace Microsoft.Data.Common { using System; - using System.Collections; using System.Collections.Generic; using System.Data; using System.Diagnostics; - using System.Linq; using System.Text; using Microsoft.Data.SqlClient; @@ -311,7 +309,12 @@ private void ValidateCombinedSet(DBConnectionString componentSet, DBConnectionSt // Component==Allow, Combined==Allow // All values in the Combined Set should also be in the Component Set // Combined - Component == null - Debug.Assert(combinedSet._restrictionValues.Except(componentSet._restrictionValues).Count() == 0, "Combined set allows values not allowed by component set"); +#if DEBUG + HashSet combined = new HashSet(combinedSet._restrictionValues); + HashSet component = new HashSet(componentSet._restrictionValues); + combined.ExceptWith(component); + Debug.Assert(combined.Count == 0, "Combined set allows values not allowed by component set"); +#endif } else if (combinedSet._behavior == KeyRestrictionBehavior.PreventUsage) { @@ -330,14 +333,24 @@ private void ValidateCombinedSet(DBConnectionString componentSet, DBConnectionSt // Component==PreventUsage, Combined==Allow // There shouldn't be any of the values from the Component Set in the Combined Set // Intersect(Component, Combined) == null - Debug.Assert(combinedSet._restrictionValues.Intersect(componentSet._restrictionValues).Count() == 0, "Combined values allows values prevented by component set"); +#if DEBUG + HashSet combined = new HashSet(combinedSet._restrictionValues); + HashSet component = new HashSet(componentSet._restrictionValues); + combined.IntersectWith(component); + Debug.Assert(combined.Count == 0, "Combined values allows values prevented by component set"); +#endif } else if (combinedSet._behavior == KeyRestrictionBehavior.PreventUsage) { // Component==PreventUsage, Combined==PreventUsage // All values in the Component Set should also be in the Combined Set // Component - Combined == null - Debug.Assert(componentSet._restrictionValues.Except(combinedSet._restrictionValues).Count() == 0, "Combined values does not prevent all of the values prevented by the component set"); +#if DEBUG + HashSet combined = new HashSet(combinedSet._restrictionValues); + HashSet component = new HashSet(componentSet._restrictionValues); + component.IntersectWith(combined); + Debug.Assert(component.Count == 0, "Combined values does not prevent all of the values prevented by the component set"); +#endif } else { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs index d0ab1507d9..d0757807be 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlAuthenticationProviderManager.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Collections.Concurrent; -using System.Linq; using System.Configuration; namespace Microsoft.Data.SqlClient @@ -166,10 +165,17 @@ public bool SetProvider(SqlAuthenticationMethod authenticationMethod, SqlAuthent } var methodName = "SetProvider"; - if (_authenticationsWithAppSpecifiedProvider.Contains(authenticationMethod)) + + if (_authenticationsWithAppSpecifiedProvider.Count > 0) { - _sqlAuthLogger.LogError(_typeName, methodName, $"Failed to add provider {GetProviderType(provider)} because a user-defined provider with type {GetProviderType(_providers[authenticationMethod])} already existed for authentication {authenticationMethod}."); - return false; + foreach (SqlAuthenticationMethod candidateMethod in _authenticationsWithAppSpecifiedProvider) + { + if (candidateMethod == authenticationMethod) + { + _sqlAuthLogger.LogError(_typeName, methodName, $"Failed to add provider {GetProviderType(provider)} because a user-defined provider with type {GetProviderType(_providers[authenticationMethod])} already existed for authentication {authenticationMethod}."); + break; + } + } } _providers.AddOrUpdate(authenticationMethod, provider, (key, oldProvider) => { diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs index fdfa0772ea..5aa96c540a 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlCommand.cs @@ -11,7 +11,6 @@ using System.Data.SqlTypes; using System.Diagnostics; using System.IO; -using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.ConstrainedExecution; using System.Security.Permissions; @@ -3320,7 +3319,11 @@ internal bool TryGetColumnEncryptionKeyStoreProvider(string providerName, out Sq /// Combined list of provider names internal List GetColumnEncryptionCustomKeyStoreProvidersNames() { - return _customColumnEncryptionKeyStoreProviders.Keys.ToList(); + if (_customColumnEncryptionKeyStoreProviders.Count > 0) + { + return new List(_customColumnEncryptionKeyStoreProviders.Keys); + } + return new List(0); } // If the user part is quoted, remove first and last brackets and then unquote any right square @@ -4480,7 +4483,8 @@ private SqlDataReader TryFetchInputParameterEncryptionInfo(int timeout, inputParameterEncryptionNeeded = true; } - _sqlRPCParameterEncryptionReqArray = describeParameterEncryptionRpcOriginalRpcMap.Keys.ToArray(); + _sqlRPCParameterEncryptionReqArray = new _SqlRPC[describeParameterEncryptionRpcOriginalRpcMap.Count]; + describeParameterEncryptionRpcOriginalRpcMap.Keys.CopyTo(_sqlRPCParameterEncryptionReqArray, 0); Debug.Assert(_sqlRPCParameterEncryptionReqArray.Length > 0, "There should be at-least 1 describe parameter encryption rpc request."); Debug.Assert(_sqlRPCParameterEncryptionReqArray.Length <= _SqlRPCBatchArray.Length, diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs index 04acfcb612..4d79f1d690 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -13,7 +13,6 @@ using System.Diagnostics.CodeAnalysis; using System.Globalization; using System.IO; -using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; using System.Runtime.ConstrainedExecution; @@ -255,7 +254,11 @@ internal bool TryGetColumnEncryptionKeyStoreProvider(string providerName, out Sq /// Combined list of provider names internal static List GetColumnEncryptionSystemKeyStoreProvidersNames() { - return s_systemColumnEncryptionKeyStoreProviders.Keys.ToList(); + if (s_systemColumnEncryptionKeyStoreProviders.Count > 0) + { + return new List(s_systemColumnEncryptionKeyStoreProviders.Keys); + } + return new List(0); } /// @@ -268,13 +271,13 @@ internal List GetColumnEncryptionCustomKeyStoreProvidersNames() if (_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0) { - return _customColumnEncryptionKeyStoreProviders.Keys.ToList(); + return new List(_customColumnEncryptionKeyStoreProviders.Keys); } if (s_globalCustomColumnEncryptionKeyStoreProviders is not null) { - return s_globalCustomColumnEncryptionKeyStoreProviders.Keys.ToList(); + return new List(s_globalCustomColumnEncryptionKeyStoreProviders.Keys); } - return new List(); + return new List(0); } private SqlDebugContext _sdc; // SQL Debugging support diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs index 621d47b20f..aa75618a7d 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlUtil.cs @@ -7,7 +7,6 @@ using System.Data; using System.Diagnostics; using System.Globalization; -using System.Linq; using System.Reflection; using System.Runtime.CompilerServices; using System.Runtime.ExceptionServices; @@ -1639,11 +1638,15 @@ static internal Exception InvalidKeySize(string algorithmName, int actualKeyleng static internal Exception InvalidEncryptionType(string algorithmName, SqlClientEncryptionType encryptionType, params SqlClientEncryptionType[] validEncryptionTypes) { const string valueSeparator = @", "; - return ADP.Argument(StringsHelper.GetString( - Strings.TCE_InvalidEncryptionType, - algorithmName, - encryptionType.ToString(), - string.Join(valueSeparator, validEncryptionTypes.Select((validEncryptionType => @"'" + validEncryptionType + @"'")))), TdsEnums.TCE_PARAM_ENCRYPTIONTYPE); + return ADP.Argument( + StringsHelper.GetString( + Strings.TCE_InvalidEncryptionType, + algorithmName, + encryptionType.ToString(), + string.Join(valueSeparator, Map(validEncryptionTypes, static validEncryptionType => $"'{validEncryptionType:G}'")) + ), + TdsEnums.TCE_PARAM_ENCRYPTIONTYPE + ); } static internal Exception NullPlainText() @@ -1731,8 +1734,8 @@ static internal Exception ProcEncryptionMetadataMissing(string procedureName) static internal Exception InvalidKeyStoreProviderName(string providerName, List systemProviders, List customProviders) { const string valueSeparator = @", "; - string systemProviderStr = string.Join(valueSeparator, systemProviders.Select(provider => $"'{provider}'")); - string customProviderStr = string.Join(valueSeparator, customProviders.Select(provider => $"'{provider}'")); + string systemProviderStr = string.Join(valueSeparator, Map(systemProviders, static provider => $"'{provider}'")); + string customProviderStr = string.Join(valueSeparator, Map(customProviders, static provider => $"'{provider}'")); return ADP.Argument(StringsHelper.GetString(Strings.TCE_InvalidKeyStoreProviderName, providerName, systemProviderStr, customProviderStr)); } @@ -1956,8 +1959,8 @@ static internal Exception UnsupportedNormalizationVersion(byte version) static internal Exception UnrecognizedKeyStoreProviderName(string providerName, List systemProviders, List customProviders) { const string valueSeparator = @", "; - string systemProviderStr = string.Join(valueSeparator, systemProviders.Select(provider => @"'" + provider + @"'")); - string customProviderStr = string.Join(valueSeparator, customProviders.Select(provider => @"'" + provider + @"'")); + string systemProviderStr = string.Join(valueSeparator, Map(systemProviders, static provider => @"'" + provider + @"'")); + string customProviderStr = string.Join(valueSeparator, Map(customProviders, static provider => @"'" + provider + @"'")); return ADP.Argument(StringsHelper.GetString(Strings.TCE_UnrecognizedKeyStoreProviderName, providerName, systemProviderStr, customProviderStr)); } @@ -2406,6 +2409,24 @@ static internal string GetSNIErrorMessage(int sniError) // constant strings internal const string Transaction = "Transaction"; internal const string Connection = "Connection"; + + private static IEnumerable Map(IEnumerable source, Func selector) + { + if (source == null) + { + throw new ArgumentNullException(nameof(source)); + } + + if (selector == null) + { + throw new ArgumentNullException(nameof(selector)); + } + + foreach (T element in source) + { + yield return selector(element); + } + } } sealed internal class SQLMessage diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs index d8171c63b6..74427d732a 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Runtime.CompilerServices; using System.Runtime.ConstrainedExecution; using System.Runtime.InteropServices; @@ -3333,8 +3332,18 @@ internal void CloneCleanupAltMetaDataSetArray() internal void PushBuffer(byte[] buffer, int read) { - Debug.Assert(!_snapshotInBuffs.Any(b => object.ReferenceEquals(b, buffer))); - +#if DEBUG + if (_snapshotInBuffs != null && _snapshotInBuffs.Count > 0) + { + foreach (PacketData packet in _snapshotInBuffs) + { + if (object.ReferenceEquals(packet.Buffer, buffer)) + { + Debug.Assert(false,"buffer is already present in packet list"); + } + } + } +#endif PacketData packetData = new PacketData(); packetData.Buffer = buffer; packetData.Read = read; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AlwaysEncryptedEnclaveProviderUtils.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AlwaysEncryptedEnclaveProviderUtils.cs index 3263031dde..ff84ab3310 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AlwaysEncryptedEnclaveProviderUtils.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AlwaysEncryptedEnclaveProviderUtils.cs @@ -18,23 +18,21 @@ public EnclavePublicKey(byte[] payload) internal class EnclaveDiffieHellmanInfo { - public int Size { get; private set; } + public int Size => sizeof(int) + sizeof(int) + PublicKey?.Length ?? 0 + PublicKeySignature?.Length ?? 0; public byte[] PublicKey { get; private set; } public byte[] PublicKeySignature { get; private set; } - public EnclaveDiffieHellmanInfo(byte[] payload) + public EnclaveDiffieHellmanInfo(byte[] payload, int offset) { - Size = payload.Length; - - int publicKeySize = BitConverter.ToInt32(payload, 0); - int publicKeySignatureSize = BitConverter.ToInt32(payload, 4); + int publicKeySize = BitConverter.ToInt32(payload, offset + 0); + int publicKeySignatureSize = BitConverter.ToInt32(payload, offset + 4); PublicKey = new byte[publicKeySize]; PublicKeySignature = new byte[publicKeySignatureSize]; - Buffer.BlockCopy(payload, 8, PublicKey, 0, publicKeySize); - Buffer.BlockCopy(payload, 8 + publicKeySize, PublicKeySignature, 0, publicKeySignatureSize); + Buffer.BlockCopy(payload, offset + 8, PublicKey, 0, publicKeySize); + Buffer.BlockCopy(payload, offset + 8 + publicKeySize, PublicKeySignature, 0, publicKeySignatureSize); } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs index 433347ef10..47bd74b18f 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/AzureAttestationBasedEnclaveProvider.cs @@ -4,8 +4,8 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IdentityModel.Tokens.Jwt; -using System.Linq; using System.Runtime.Caching; using System.Security.Claims; using System.Security.Cryptography; @@ -191,14 +191,12 @@ public AzureAttestationInfo(byte[] attestationInfo) offset += sizeof(uint); // Get the enclave public key - byte[] identityBuffer = attestationInfo.Skip(offset).Take(identitySize).ToArray(); + byte[] identityBuffer = EnclaveHelpers.TakeBytesAndAdvance(attestationInfo, ref offset, identitySize); Identity = new EnclavePublicKey(identityBuffer); - offset += identitySize; // Get Azure attestation token - byte[] attestationTokenBuffer = attestationInfo.Skip(offset).Take(attestationTokenSize).ToArray(); - AttestationToken = new AzureAttestationToken(attestationTokenBuffer); - offset += attestationTokenSize; + byte[] attestationTokenBuffer = EnclaveHelpers.TakeBytesAndAdvance(attestationInfo, ref offset, attestationTokenSize); + AttestationToken = new AzureAttestationToken(attestationTokenBuffer); uint secureSessionInfoResponseSize = BitConverter.ToUInt32(attestationInfo, offset); offset += sizeof(uint); @@ -206,10 +204,10 @@ public AzureAttestationInfo(byte[] attestationInfo) SessionId = BitConverter.ToInt64(attestationInfo, offset); offset += sizeof(long); - int secureSessionBufferSize = Convert.ToInt32(secureSessionInfoResponseSize) - sizeof(uint); - byte[] secureSessionBuffer = attestationInfo.Skip(offset).Take(secureSessionBufferSize).ToArray(); - EnclaveDHInfo = new EnclaveDiffieHellmanInfo(secureSessionBuffer); - offset += Convert.ToInt32(EnclaveDHInfo.Size); + EnclaveDHInfo = new EnclaveDiffieHellmanInfo(attestationInfo, offset); + offset += EnclaveDHInfo.Size; + + Debug.Assert(offset == attestationInfo.Length); } catch (Exception exception) { @@ -467,7 +465,7 @@ private void ValidateAttestationClaims(EnclaveType enclaveType, string attestati // Get all the claims from the token Dictionary claims = new Dictionary(); - foreach (Claim claim in token.Claims.ToList()) + foreach (Claim claim in token.Claims) { claims.Add(claim.Type, claim.Value); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryFactory.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryFactory.cs index 0e87c10aa6..859c3ef510 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryFactory.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryFactory.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; namespace Microsoft.Data.SqlClient { @@ -115,10 +114,24 @@ private static bool TransientErrorsCondition(Exception e, IEnumerable retri { foreach (SqlError item in ex.Errors) { - bool retriable; + bool retriable = false; lock (s_syncObject) { - retriable = retriableConditions.Contains(item.Number); + if (retriableConditions is ICollection collection) + { + retriable = collection.Contains(item.Number); + } + else + { + foreach (int candidate in retriableConditions) + { + if (candidate == item.Number) + { + retriable = true; + break; + } + } + } } if (retriable) { diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryLogicLoader.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryLogicLoader.cs index dfc96c566a..eccf21d4d0 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryLogicLoader.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Reliability/SqlConfigurableRetryLogicLoader.cs @@ -3,7 +3,7 @@ // See the LICENSE file in the project root for more information. using System; -using System.Linq; +using System.Collections.Generic; using System.Reflection; using System.Text.RegularExpressions; @@ -72,7 +72,7 @@ private static SqlRetryLogicBaseProvider CreateRetryLogicProvider(string section // Prepare the transient error lists if (!string.IsNullOrEmpty(configSection.TransientErrors)) { - retryOption.TransientErrors = configSection.TransientErrors.Split(',').Select(x => Convert.ToInt32(x)).ToList(); + retryOption.TransientErrors = SplitErrorNumberList(configSection.TransientErrors); } // Prepare the authorized SQL statement just for SqlCommands @@ -215,11 +215,24 @@ private static object CreateInstance(Type type, string retryMethodName, SqlRetry private static object[] PrepareParamValues(ParameterInfo[] parameterInfos, SqlRetryLogicOption option, string retryMethod) { // The retry method must have at least one `SqlRetryLogicOption` - if (parameterInfos.FirstOrDefault(x => x.ParameterType == typeof(SqlRetryLogicOption)) == null) + if (parameterInfos != null && parameterInfos.Length > 0) { - string message = $"Failed to create {nameof(SqlRetryLogicBaseProvider)} object because of invalid {retryMethod}'s parameters." + - $"{Environment.NewLine}The function must have a paramter of type '{nameof(SqlRetryLogicOption)}'."; - throw new InvalidOperationException(message); + bool found = false; + for (int index = 0; index < parameterInfos.Length; index++) + { + if (parameterInfos[index].ParameterType == typeof(SqlRetryLogicOption)) + { + found = true; + break; + } + } + + if (!found) + { + string message = $"Failed to create {nameof(SqlRetryLogicBaseProvider)} object because of invalid {retryMethod}'s parameters." + + $"{Environment.NewLine}The function must have a paramter of type '{nameof(SqlRetryLogicOption)}'."; + throw new InvalidOperationException(message); + } } object[] funcParams = new object[parameterInfos.Length]; @@ -238,9 +251,21 @@ private static object[] PrepareParamValues(ParameterInfo[] parameterInfos, SqlRe // or there isn't another parameter with the same type and without a default value. else if (paramInfo.ParameterType == typeof(SqlRetryLogicOption)) { - if (!paramInfo.HasDefaultValue - || (paramInfo.HasDefaultValue - && parameterInfos.FirstOrDefault(x => x != paramInfo && !x.HasDefaultValue && x.ParameterType == typeof(SqlRetryLogicOption)) == null)) + bool foundOptionsParamWithNoDefaultValue = false; + for (int index = 0; index < parameterInfos.Length; index++) + { + if ( + parameterInfos[index] != paramInfo && + parameterInfos[index].ParameterType == typeof(SqlRetryLogicOption) && + !parameterInfos[index].HasDefaultValue + ) + { + foundOptionsParamWithNoDefaultValue = true; + break; + } + } + + if (!paramInfo.HasDefaultValue || (paramInfo.HasDefaultValue && !foundOptionsParamWithNoDefaultValue)) { funcParams[i] = option; } @@ -273,5 +298,26 @@ private static void OnRetryingEvent(object sender, SqlRetryingEventArgs args) SqlClientEventSource.Log.TryTraceEvent("{0}, Last exception:<{1}>", msg, lastException.Message); SqlClientEventSource.Log.TryAdvancedTraceEvent("{0}, Last exception: {1}", msg, lastException); } + + private static ICollection SplitErrorNumberList(string list) + { + if (!string.IsNullOrEmpty(list)) + { + string[] parts = list.Split(new char[] { ',' }, StringSplitOptions.RemoveEmptyEntries); + if (parts != null && parts.Length > 0) + { + HashSet set = new HashSet(); + for (int index = 0; index < parts.Length; index++) + { + if (int.TryParse(parts[index], System.Globalization.NumberStyles.AllowLeadingWhite | System.Globalization.NumberStyles.AllowTrailingWhite, null, out int value)) + { + set.Add(value); + } + } + return set; + } + } + return new HashSet(); + } } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionStringBuilder.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionStringBuilder.cs index 55849d033e..c9f9fd2912 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionStringBuilder.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlConnectionStringBuilder.cs @@ -13,7 +13,6 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Globalization; -using System.Linq; using System.Reflection; using Microsoft.Data.Common; @@ -640,15 +639,21 @@ private void SetPoolBlockingPeriodValue(PoolBlockingPeriod value) private Exception UnsupportedKeyword(string keyword) { #if !NETFRAMEWORK - if (s_notSupportedKeywords.Contains(keyword, StringComparer.OrdinalIgnoreCase)) + for (int index = 0; index < s_notSupportedKeywords.Length; index++) { - return SQL.UnsupportedKeyword(keyword); + if (string.Equals(keyword, s_notSupportedKeywords[index], StringComparison.OrdinalIgnoreCase)) + { + return SQL.UnsupportedKeyword(keyword); + } } - else if (s_notSupportedNetworkLibraryKeywords.Contains(keyword, StringComparer.OrdinalIgnoreCase)) + + for (int index = 0; index < s_notSupportedNetworkLibraryKeywords.Length; index++) { - return SQL.NetworkLibraryKeywordNotSupported(); + if (string.Equals(keyword, s_notSupportedNetworkLibraryKeywords[index], StringComparison.OrdinalIgnoreCase)) + { + return SQL.NetworkLibraryKeywordNotSupported(); + } } - else #endif return ADP.KeywordNotSupported(keyword); } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlSecurityUtility.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlSecurityUtility.cs index ced6e62755..d59fa1f91a 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlSecurityUtility.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlSecurityUtility.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Reflection; using System.Security.Cryptography; using System.Text; @@ -421,8 +420,21 @@ internal static void ThrowIfKeyPathIsNotTrustedForServer(string serverName, stri if (SqlConnection.ColumnEncryptionTrustedMasterKeyPaths.TryGetValue(serverName, out IList trustedKeyPaths)) { // If the list is null or is empty or if the keyPath doesn't exist in the trusted key paths, then throw an exception. - if (trustedKeyPaths is null || trustedKeyPaths.Count() == 0 || - trustedKeyPaths.Any(trustedKeyPath => trustedKeyPath.Equals(keyPath, StringComparison.InvariantCultureIgnoreCase)) == false) + + bool pathIsKnown = false; + if (trustedKeyPaths != null) + { + foreach (string candidate in trustedKeyPaths) + { + if (string.Equals(keyPath, candidate, StringComparison.InvariantCultureIgnoreCase)) + { + pathIsKnown = true; + break; + } + } + } + + if (!pathIsKnown) { // throw an exception since the key path is not in the trusted key paths list for this server throw SQL.UntrustedKeyPath(keyPath, serverName); diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlSymmetricKeyCache.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlSymmetricKeyCache.cs index 385fd4981c..663116ed59 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlSymmetricKeyCache.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SqlSymmetricKeyCache.cs @@ -3,9 +3,7 @@ // See the LICENSE file in the project root for more information. using System; -using System.Collections.Generic; using System.Diagnostics; -using System.Linq; using System.Runtime.Caching; using System.Text; diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProvider.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProvider.cs index 66bb63c31d..fd180d12d6 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProvider.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProvider.cs @@ -4,9 +4,8 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; -using System.Linq; -using System.Net; using System.Net.Http; using System.Runtime.Serialization.Json; using System.Security.Cryptography.X509Certificates; @@ -135,15 +134,14 @@ public AttestationInfo(byte[] attestationInfo) int enclaveReportSize = BitConverter.ToInt32(attestationInfo, offset); offset += sizeof(uint); - byte[] identityBuffer = attestationInfo.Skip(offset).Take(identitySize).ToArray(); + byte[] identityBuffer = EnclaveHelpers.TakeBytesAndAdvance(attestationInfo, ref offset, identitySize); Identity = new EnclavePublicKey(identityBuffer); - offset += identitySize; - byte[] healthReportBuffer = attestationInfo.Skip(offset).Take(healthReportSize).ToArray(); + byte[] healthReportBuffer = EnclaveHelpers.TakeBytesAndAdvance(attestationInfo, ref offset, healthReportSize); HealthReport = new HealthReport(healthReportBuffer); - offset += healthReportSize; - byte[] enclaveReportBuffer = attestationInfo.Skip(offset).Take(enclaveReportSize).ToArray(); + byte[] enclaveReportBuffer = new byte[enclaveReportSize]; + Buffer.BlockCopy(attestationInfo, offset, enclaveReportBuffer, 0, enclaveReportSize); EnclaveReportPackage = new EnclaveReportPackage(enclaveReportBuffer); offset += EnclaveReportPackage.GetSizeInPayload(); @@ -153,10 +151,10 @@ public AttestationInfo(byte[] attestationInfo) SessionId = BitConverter.ToInt64(attestationInfo, offset); offset += sizeof(long); - int secureSessionBufferSize = Convert.ToInt32(secureSessionInfoResponseSize) - sizeof(uint); - byte[] secureSessionBuffer = attestationInfo.Skip(offset).Take(secureSessionBufferSize).ToArray(); - EnclaveDHInfo = new EnclaveDiffieHellmanInfo(secureSessionBuffer); + EnclaveDHInfo = new EnclaveDiffieHellmanInfo(attestationInfo, offset); offset += Convert.ToInt32(EnclaveDHInfo.Size); + + Debug.Assert(offset == attestationInfo.Length); } } @@ -200,11 +198,11 @@ public EnclaveReportPackage(byte[] payload) Size = payload.Length; int offset = 0; - PackageHeader = new EnclaveReportPackageHeader(payload.Skip(offset).ToArray()); - offset += PackageHeader.GetSizeInPayload(); + byte[] headerBuffer = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, EnclaveReportPackageHeader.SizeInPayload); + PackageHeader = new EnclaveReportPackageHeader(headerBuffer); - Report = new EnclaveReport(payload.Skip(offset).ToArray()); - offset += Report.GetSizeInPayload(); + byte[] reportBuffer = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, EnclaveReport.SizeInPayload); + Report = new EnclaveReport(reportBuffer); // Modules are not used for anything currently, ignore parsing for now // @@ -220,14 +218,13 @@ public EnclaveReportPackage(byte[] payload) // Moving the offset back to the start of the report, // we need the report as a byte buffer for signature verification. - offset = PackageHeader.GetSizeInPayload(); + offset = EnclaveReportPackageHeader.SizeInPayload; + int dataToHashSize = Convert.ToInt32(PackageHeader.SignedStatementSize); - ReportAsBytes = payload.Skip(offset).Take(dataToHashSize).ToArray(); - offset += dataToHashSize; + ReportAsBytes = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, dataToHashSize); int signatureSize = Convert.ToInt32(PackageHeader.SignatureSize); - SignatureBlob = payload.Skip(offset).Take(signatureSize).ToArray(); - offset += signatureSize; + SignatureBlob = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, signatureSize); } public int GetSizeInPayload() @@ -274,17 +271,14 @@ public EnclaveReportPackageHeader(byte[] payload) offset += sizeof(uint); } - public int GetSizeInPayload() - { - return 6 * sizeof(uint); - } + public static int SizeInPayload => 6 * sizeof(uint); } // A managed model of struct VBS_ENCLAVE_REPORT // https://msdn.microsoft.com/en-us/library/windows/desktop/mt844255(v=vs.85).aspx internal class EnclaveReport { - private int Size { get; set; } + private const int EnclaveDataLength = 64; public uint ReportSize { get; set; } @@ -292,14 +286,10 @@ internal class EnclaveReport public byte[] EnclaveData { get; set; } - private const int EnclaveDataLength = 64; - public EnclaveIdentity Identity { get; set; } public EnclaveReport(byte[] payload) { - Size = payload.Length; - int offset = 0; ReportSize = BitConverter.ToUInt32(payload, offset); @@ -308,38 +298,31 @@ public EnclaveReport(byte[] payload) ReportVersion = BitConverter.ToUInt32(payload, offset); offset += sizeof(uint); - EnclaveData = payload.Skip(offset).Take(EnclaveDataLength).ToArray(); - offset += EnclaveDataLength; + EnclaveData = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, EnclaveDataLength); - Identity = new EnclaveIdentity(payload.Skip(offset).ToArray()); - offset += Identity.GetSizeInPayload(); + Identity = new EnclaveIdentity(EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, EnclaveIdentity.SizeInPayload)); } - public int GetSizeInPayload() - { - return sizeof(uint) * 2 + sizeof(byte) * 64 + Identity.GetSizeInPayload(); - } + public static int SizeInPayload => sizeof(uint) * 2 + sizeof(byte) * 64 + EnclaveIdentity.SizeInPayload; } // A managed model of struct ENCLAVE_IDENTITY // https://msdn.microsoft.com/en-us/library/windows/desktop/mt844239(v=vs.85).aspx internal class EnclaveIdentity { - private int Size { get; set; } + private const int ImageEnclaveLongIdLength = 32; - private static readonly int ImageEnclaveLongIdLength = 32; + private const int ImageEnclaveShortIdLength = 16; - private static readonly int ImageEnclaveShortIdLength = 16; + public byte[] OwnerId; - public byte[] OwnerId = new byte[ImageEnclaveLongIdLength]; + public byte[] UniqueId; - public byte[] UniqueId = new byte[ImageEnclaveLongIdLength]; + public byte[] AuthorId; - public byte[] AuthorId = new byte[ImageEnclaveLongIdLength]; + public byte[] FamilyId; - public byte[] FamilyId = new byte[ImageEnclaveShortIdLength]; - - public byte[] ImageId = new byte[ImageEnclaveShortIdLength]; + public byte[] ImageId; public uint EnclaveSvn { get; set; } @@ -357,29 +340,17 @@ public EnclaveIdentity() { } public EnclaveIdentity(byte[] payload) { - Size = payload.Length; - int offset = 0; - int ownerIdLength = ImageEnclaveLongIdLength; - OwnerId = payload.Skip(offset).Take(ownerIdLength).ToArray(); - offset += ownerIdLength; + OwnerId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveLongIdLength); - int uniqueIdLength = ImageEnclaveLongIdLength; - UniqueId = payload.Skip(offset).Take(uniqueIdLength).ToArray(); - offset += uniqueIdLength; + UniqueId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveLongIdLength); - int authorIdLength = ImageEnclaveLongIdLength; - AuthorId = payload.Skip(offset).Take(authorIdLength).ToArray(); - offset += authorIdLength; + AuthorId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveLongIdLength); - int familyIdLength = ImageEnclaveShortIdLength; - FamilyId = payload.Skip(offset).Take(familyIdLength).ToArray(); - offset += familyIdLength; + FamilyId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveShortIdLength); - int imageIdLength = ImageEnclaveShortIdLength; - ImageId = payload.Skip(offset).Take(imageIdLength).ToArray(); - offset += imageIdLength; + ImageId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveShortIdLength); EnclaveSvn = BitConverter.ToUInt32(payload, offset); offset += sizeof(uint); @@ -400,10 +371,8 @@ public EnclaveIdentity(byte[] payload) offset += sizeof(uint); } - public int GetSizeInPayload() - { - return sizeof(byte) * ImageEnclaveLongIdLength * 3 + sizeof(byte) * ImageEnclaveShortIdLength * 2 + sizeof(uint) * 6; - } + public static int SizeInPayload => sizeof(byte) * ImageEnclaveLongIdLength * 3 + sizeof(byte) * ImageEnclaveShortIdLength * 2 + sizeof(uint) * 6; + } // A managed model of struct VBS_ENCLAVE_REPORT_VARDATA_HEADER @@ -424,29 +393,26 @@ public EnclaveReportModuleHeader(byte[] payload) offset += sizeof(uint); } - public int GetSizeInPayload() - { - return 2 * sizeof(uint); - } + public static int SizeInPayload => 2 * sizeof(uint); } // A managed model of struct VBS_ENCLAVE_REPORT_MODULE // https://msdn.microsoft.com/en-us/library/windows/desktop/mt844256(v=vs.85).aspx internal class EnclaveReportModule { - private static readonly int ImageEnclaveLongIdLength = 32; + private const int ImageEnclaveLongIdLength = 32; - private static readonly int ImageEnclaveShortIdLength = 16; + private const int ImageEnclaveShortIdLength = 16; public EnclaveReportModuleHeader Header { get; set; } - public byte[] UniqueId = new byte[ImageEnclaveLongIdLength]; + public byte[] UniqueId; - public byte[] AuthorId = new byte[ImageEnclaveLongIdLength]; + public byte[] AuthorId; - public byte[] FamilyId = new byte[ImageEnclaveShortIdLength]; + public byte[] FamilyId; - public byte[] ImageId = new byte[ImageEnclaveShortIdLength]; + public byte[] ImageId; public uint Svn { get; set; } @@ -456,35 +422,26 @@ public EnclaveReportModule(byte[] payload) { int offset = 0; Header = new EnclaveReportModuleHeader(payload); - offset += Convert.ToInt32(Header.GetSizeInPayload()); + offset += EnclaveReportModuleHeader.SizeInPayload; - int uniqueIdLength = ImageEnclaveLongIdLength; - UniqueId = payload.Skip(offset).Take(uniqueIdLength).ToArray(); - offset += uniqueIdLength; + UniqueId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveLongIdLength); - int authorIdLength = ImageEnclaveLongIdLength; - AuthorId = payload.Skip(offset).Take(authorIdLength).ToArray(); - offset += authorIdLength; + AuthorId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveLongIdLength); - int familyIdLength = ImageEnclaveShortIdLength; - FamilyId = payload.Skip(offset).Take(familyIdLength).ToArray(); - offset += familyIdLength; + FamilyId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveShortIdLength); - int imageIdLength = ImageEnclaveShortIdLength; - ImageId = payload.Skip(offset).Take(familyIdLength).ToArray(); - offset += imageIdLength; + ImageId = EnclaveHelpers.TakeBytesAndAdvance(payload, ref offset, ImageEnclaveShortIdLength); Svn = BitConverter.ToUInt32(payload, offset); offset += sizeof(uint); - int strLen = Convert.ToInt32(Header.ModuleSize) - offset; ModuleName = BitConverter.ToString(payload, offset, 1); offset += sizeof(char) * 1; } public int GetSizeInPayload() { - return Header.GetSizeInPayload() + Convert.ToInt32(Header.ModuleSize); + return EnclaveReportModuleHeader.SizeInPayload + Convert.ToInt32(Header.ModuleSize); } } @@ -499,4 +456,15 @@ internal enum EnclaveIdentityFlags } #endregion + + internal static class EnclaveHelpers + { + public static byte[] TakeBytesAndAdvance(byte[] input, ref int offset, int count) + { + byte[] output = new byte[count]; + Buffer.BlockCopy(input, offset, output, 0, count); + offset += count; + return output; + } + } } diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs index c4b9987f01..c6f0d20743 100644 --- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs +++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/VirtualSecureModeEnclaveProviderBase.cs @@ -3,7 +3,6 @@ // See the LICENSE file in the project root for more information. using System; -using System.Linq; using System.Runtime.Caching; using System.Security.Cryptography; using System.Security.Cryptography.X509Certificates; @@ -220,7 +219,17 @@ private X509Certificate2Collection GetSigningCertificate(string attestationUrl, // Checks if any certificates in the collection are expired private bool AnyCertificatesExpired(X509Certificate2Collection certificates) { - return certificates.OfType().Any(c => c.NotAfter < DateTime.Now); + if (certificates != null) + { + foreach (object item in certificates) + { + if (item is X509Certificate2 certificate && certificate.NotAfter < DateTime.Now) + { + return true; + } + } + } + return false; } // Verifies that a chain of trust can be built from the health report provided @@ -285,7 +294,7 @@ private X509ChainStatusFlags VerifyHealthReportAgainstRootCertificate(X509Certif private void VerifyEnclaveReportSignature(EnclaveReportPackage enclaveReportPackage, X509Certificate2 healthReportCert) { // Check if report is formatted correctly - uint calculatedSize = Convert.ToUInt32(enclaveReportPackage.PackageHeader.GetSizeInPayload()) + enclaveReportPackage.PackageHeader.SignedStatementSize + enclaveReportPackage.PackageHeader.SignatureSize; + uint calculatedSize = Convert.ToUInt32(EnclaveReportPackageHeader.SizeInPayload) + enclaveReportPackage.PackageHeader.SignedStatementSize + enclaveReportPackage.PackageHeader.SignatureSize; if (calculatedSize != enclaveReportPackage.PackageHeader.PackageSize) { @@ -325,7 +334,31 @@ private void VerifyEnclavePolicy(EnclaveReportPackage enclaveReportPackage) // Verifies a byte[] enclave policy property private void VerifyEnclavePolicyProperty(string property, byte[] actual, byte[] expected) { - if (!actual.SequenceEqual(expected)) + bool different = false; + if (actual == null || expected == null) + { + different = true; + } + else + { + if (actual.Length != expected.Length) + { + different = true; + } + else + { + for (int index = 0; index < actual.Length; index++) + { + if (actual[index] != expected[index]) + { + different = true; + break; + } + } + } + } + + if (different) { string exceptionMessage = String.Format(Strings.VerifyEnclavePolicyFailedFormat, property, BitConverter.ToString(actual), BitConverter.ToString(expected)); throw new ArgumentException(exceptionMessage);