From 9a4e332bd2f362317d23e5287a85659eb2798c1a Mon Sep 17 00:00:00 2001
From: Parminder Kaur <88398605+Kaur-Parminder@users.noreply.github.com>
Date: Tue, 11 Jan 2022 18:45:03 -0800
Subject: [PATCH 001/123] Move to Shared - SqlSer.cs (#1313)
---
.../src/Microsoft.Data.SqlClient.csproj | 5 +-
.../netfx/src/Microsoft.Data.SqlClient.csproj | 4 +-
.../Microsoft/Data/SqlClient/Server/sqlser.cs | 295 ------------------
.../Microsoft/Data/SqlClient/Server/SqlSer.cs | 59 ++--
4 files changed, 44 insertions(+), 319 deletions(-)
delete mode 100644 src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/Server/sqlser.cs
rename src/Microsoft.Data.SqlClient/{netcore => }/src/Microsoft/Data/SqlClient/Server/SqlSer.cs (89%)
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
index d905ad16a9..71b71aa33b 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj
@@ -364,6 +364,9 @@
Microsoft\Data\SqlClient\Server\ValueUtilsSmi.cs
+
+ Microsoft\Data\SqlClient\Server\SqlSer.cs
+
Microsoft\Data\SqlClient\SignatureVerificationCache.cs
@@ -558,7 +561,7 @@
Microsoft\Data\SqlClient\SqlSequentialStream.cs
-
+
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
index f3d4c3b7d9..49dc12bd0a 100644
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
+++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj
@@ -249,6 +249,9 @@
Microsoft\Data\SqlClient\Server\SmiMetaDataProperty.cs
+
+ Microsoft\Data\SqlClient\Server\SqlSer.cs
+
Microsoft\Data\SqlClient\ColumnEncryptionKeyInfo.cs
@@ -629,7 +632,6 @@
-
diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/Server/sqlser.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/Server/sqlser.cs
deleted file mode 100644
index 2967aaaa55..0000000000
--- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/Server/sqlser.cs
+++ /dev/null
@@ -1,295 +0,0 @@
-// Licensed to the .NET Foundation under one or more agreements.
-// The .NET Foundation licenses this file to you under the MIT license.
-// See the LICENSE file in the project root for more information.
-
-using System;
-using System.Collections;
-using System.IO;
-using System.Reflection;
-using System.Runtime.CompilerServices;
-using Microsoft.Data.Common;
-
-namespace Microsoft.Data.SqlClient.Server
-{
- internal class SerializationHelperSql9
- {
- // Don't let anyone create an instance of this class.
- private SerializationHelperSql9() { }
-
- // Get the m_size of the serialized stream for this type, in bytes.
- // This method creates an instance of the type using the public
- // no-argument constructor, serializes it, and returns the m_size
- // in bytes.
- // Prevent inlining so that reflection calls are not moved to caller that may be in a different assembly that may have a different grant set.
- [MethodImpl(MethodImplOptions.NoInlining)]
- internal static int SizeInBytes(Type t) => SizeInBytes(Activator.CreateInstance(t));
-
- // Get the m_size of the serialized stream for this type, in bytes.
- internal static int SizeInBytes(object instance)
- {
- Type t = instance.GetType();
- Format k = GetFormat(t);
- DummyStream stream = new DummyStream();
- Serializer ser = GetSerializer(instance.GetType());
- ser.Serialize(stream, instance);
- return (int)stream.Length;
- }
-
- internal static void Serialize(Stream s, object instance)
- {
- GetSerializer(instance.GetType()).Serialize(s, instance);
- }
-
- internal static object Deserialize(Stream s, Type resultType) => GetSerializer(resultType).Deserialize(s);
-
- private static Format GetFormat(Type t) => GetUdtAttribute(t).Format;
-
- // Cache the relationship between a type and its serializer.
- // This is expensive to compute since it involves traversing the
- // custom attributes of the type using reflection.
- //
- // Use a per-thread cache, so that there are no synchronization
- // issues when accessing cache entries from multiple threads.
- [ThreadStatic]
- private static Hashtable s_types2Serializers;
-
- private static Serializer GetSerializer(Type t)
- {
- if (s_types2Serializers == null)
- s_types2Serializers = new Hashtable();
-
- Serializer s = (Serializer)s_types2Serializers[t];
- if (s == null)
- {
- s = GetNewSerializer(t);
- s_types2Serializers[t] = s;
- }
- return s;
- }
-
- internal static int GetUdtMaxLength(Type t)
- {
- SqlUdtInfo udtInfo = SqlUdtInfo.GetFromType(t);
-
- if (Format.Native == udtInfo.SerializationFormat)
- {
- // In the native format, the user does not specify the
- // max byte size, it is computed from the type definition
- return SerializationHelperSql9.SizeInBytes(t);
- }
- else
- {
- // In all other formats, the user specifies the maximum size in bytes.
- return udtInfo.MaxByteSize;
- }
- }
-
- private static object[] GetCustomAttributes(Type t)
- {
- object[] attrs = t.GetCustomAttributes(typeof(SqlUserDefinedTypeAttribute), false);
-
- // If we don't find a Microsoft.Data.SqlClient.Server.SqlUserDefinedTypeAttribute,
- // search for a Microsoft.SqlServer.Server.SqlUserDefinedTypeAttribute from the
- // old System.Data.SqlClient assembly and copy it to our
- // Microsoft.Data.SqlClient.Server.SqlUserDefinedTypeAttribute for reference.
- if (attrs == null || attrs.Length == 0)
- {
- object[] attr = t.GetCustomAttributes(false);
- attrs = new object[0];
- if (attr != null && attr.Length > 0)
- {
- for (int i = 0; i < attr.Length; i++)
- {
- if (attr[i].GetType().FullName.Equals("Microsoft.SqlServer.Server.SqlUserDefinedTypeAttribute"))
- {
- SqlUserDefinedTypeAttribute newAttr = null;
- PropertyInfo[] sourceProps = attr[i].GetType().GetProperties();
-
- foreach (PropertyInfo sourceProp in sourceProps)
- {
- if (sourceProp.Name.Equals("Format"))
- {
- newAttr = new SqlUserDefinedTypeAttribute((Format)sourceProp.GetValue(attr[i], null));
- break;
- }
- }
- if (newAttr != null)
- {
- foreach (PropertyInfo targetProp in newAttr.GetType().GetProperties())
- {
- if (targetProp.CanRead && targetProp.CanWrite)
- {
- object copyValue = attr[i].GetType().GetProperty(targetProp.Name).GetValue(attr[i]);
- targetProp.SetValue(newAttr, copyValue);
- }
- }
- }
-
- attrs = new object[1] { newAttr };
- break;
- }
- }
- }
- }
-
- return attrs;
- }
-
- internal static SqlUserDefinedTypeAttribute GetUdtAttribute(Type t)
- {
- SqlUserDefinedTypeAttribute udtAttr = null;
- object[] attr = GetCustomAttributes(t);
-
- if (attr != null && attr.Length == 1)
- {
- udtAttr = (SqlUserDefinedTypeAttribute)attr[0];
- }
- else
- {
- Type InvalidUdtExceptionType = typeof(InvalidUdtException);
- var arguments = new Type[] { typeof(Type), typeof(String) };
- MethodInfo Create = InvalidUdtExceptionType.GetMethod("Create", arguments);
- Create.Invoke(null, new object[] { t, Strings.SqlUdtReason_NoUdtAttribute });
- }
- return udtAttr;
- }
-
- // Create a new serializer for the given type.
- private static Serializer GetNewSerializer(Type t)
- {
- SqlUserDefinedTypeAttribute udtAttr = GetUdtAttribute(t);
-
- switch (udtAttr.Format)
- {
- case Format.Native:
- return new NormalizedSerializer(t);
- case Format.UserDefined:
- return new BinarySerializeSerializer(t);
- case Format.Unknown: // should never happen, but fall through
- default:
- throw ADP.InvalidUserDefinedTypeSerializationFormat(udtAttr.Format);
- }
- }
- }
-
- // The base serializer class.
- internal abstract class Serializer
- {
- public abstract object Deserialize(Stream s);
- public abstract void Serialize(Stream s, object o);
- protected Type _type;
-
- protected Serializer(Type t)
- {
- _type = t;
- }
- }
-
- internal sealed class NormalizedSerializer : Serializer
- {
- private BinaryOrderedUdtNormalizer _normalizer;
- private bool _isFixedSize;
- private int _maxSize;
-
- internal NormalizedSerializer(Type t) : base(t)
- {
- SqlUserDefinedTypeAttribute udtAttr = SerializationHelperSql9.GetUdtAttribute(t);
- _normalizer = new BinaryOrderedUdtNormalizer(t, true);
- _isFixedSize = udtAttr.IsFixedLength;
- _maxSize = _normalizer.Size;
- }
-
- public override void Serialize(Stream s, object o) => _normalizer.NormalizeTopObject(o, s);
-
- public override object Deserialize(Stream s) => _normalizer.DeNormalizeTopObject(_type, s);
- }
-
- internal sealed class BinarySerializeSerializer : Serializer
- {
- internal BinarySerializeSerializer(Type t) : base(t)
- {
- }
-
- public override void Serialize(Stream s, object o)
- {
- BinaryWriter w = new BinaryWriter(s);
- if (o is Microsoft.SqlServer.Server.IBinarySerialize)
- {
- ((SqlServer.Server.IBinarySerialize)o).Write(w);
- }
- else
- {
- ((IBinarySerialize)o).Write(w);
- }
- }
-
- // Prevent inlining so that reflection calls are not moved
- // to a caller that may be in a different assembly that may
- // have a different grant set.
- [MethodImpl(MethodImplOptions.NoInlining)]
- public override object Deserialize(Stream s)
- {
- object instance = Activator.CreateInstance(_type);
- BinaryReader r = new BinaryReader(s);
- if (instance is Microsoft.SqlServer.Server.IBinarySerialize)
- {
- ((SqlServer.Server.IBinarySerialize)instance).Read(r);
- }
- else
- {
- ((IBinarySerialize)instance).Read(r);
- }
- return instance;
- }
- }
-
- // A dummy stream class, used to get the number of bytes written
- // to the stream.
- internal sealed class DummyStream : Stream
- {
- private long _size;
-
- public DummyStream()
- {
- }
-
- private void DontDoIt()
- {
- throw new Exception(StringsHelper.GetString(Strings.Sql_InternalError));
- }
-
- public override bool CanRead => false;
-
- public override bool CanWrite => true;
-
- public override bool CanSeek => false;
-
- public override long Position
- {
- get => _size;
- set =>_size = value;
- }
-
- public override long Length => _size;
-
- public override void SetLength(long value) => _size = value;
-
- public override long Seek(long value, SeekOrigin loc)
- {
- DontDoIt();
- return -1;
- }
-
- public override void Flush()
- {
- }
-
- public override int Read(byte[] buffer, int offset, int count)
- {
- DontDoIt();
- return -1;
- }
-
- public override void Write(byte[] buffer, int offset, int count) => _size += count;
- }
-}
diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/Server/SqlSer.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/SqlSer.cs
similarity index 89%
rename from src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/Server/SqlSer.cs
rename to src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/SqlSer.cs
index c9f1536f50..cf510834b6 100644
--- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/Server/SqlSer.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/Server/SqlSer.cs
@@ -3,7 +3,7 @@
// See the LICENSE file in the project root for more information.
using System;
-using System.Collections;
+using System.Collections.Concurrent;
using System.IO;
using System.Reflection;
using System.Runtime.CompilerServices;
@@ -11,7 +11,7 @@
namespace Microsoft.Data.SqlClient.Server
{
- internal class SerializationHelperSql9
+ internal sealed class SerializationHelperSql9
{
// Don't let anyone create an instance of this class.
private SerializationHelperSql9() { }
@@ -28,7 +28,8 @@ private SerializationHelperSql9() { }
internal static int SizeInBytes(object instance)
{
Type t = instance.GetType();
- Format k = GetFormat(t);
+
+ _ = GetFormat(t);
DummyStream stream = new DummyStream();
Serializer ser = GetSerializer(instance.GetType());
ser.Serialize(stream, instance);
@@ -50,20 +51,22 @@ internal static void Serialize(Stream s, object instance)
//
// Use a per-thread cache, so that there are no synchronization
// issues when accessing cache entries from multiple threads.
- [ThreadStatic]
- private static Hashtable s_types2Serializers;
+ private static ConcurrentDictionary s_types2Serializers;
private static Serializer GetSerializer(Type t)
{
if (s_types2Serializers == null)
- s_types2Serializers = new Hashtable();
+ {
+ s_types2Serializers = new ConcurrentDictionary();
+ }
- Serializer s = (Serializer)s_types2Serializers[t];
- if (s == null)
+ Serializer s;
+ if (!s_types2Serializers.TryGetValue(t, out s))
{
s = GetNewSerializer(t);
s_types2Serializers[t] = s;
}
+
return s;
}
@@ -137,9 +140,8 @@ private static object[] GetCustomAttributes(Type t)
internal static SqlUserDefinedTypeAttribute GetUdtAttribute(Type t)
{
- SqlUserDefinedTypeAttribute udtAttr = null;
+ SqlUserDefinedTypeAttribute udtAttr;
object[] attr = GetCustomAttributes(t);
-
if (attr != null && attr.Length == 1)
{
udtAttr = (SqlUserDefinedTypeAttribute)attr[0];
@@ -155,9 +157,8 @@ internal static SqlUserDefinedTypeAttribute GetUdtAttribute(Type t)
private static Serializer GetNewSerializer(Type t)
{
SqlUserDefinedTypeAttribute udtAttr = GetUdtAttribute(t);
- Format k = GetFormat(t);
-
- switch (k)
+
+ switch (udtAttr.Format)
{
case Format.Native:
return new NormalizedSerializer(t);
@@ -165,7 +166,7 @@ private static Serializer GetNewSerializer(Type t)
return new BinarySerializeSerializer(t);
case Format.Unknown: // should never happen, but fall through
default:
- throw ADP.InvalidUserDefinedTypeSerializationFormat(k);
+ throw ADP.InvalidUserDefinedTypeSerializationFormat(udtAttr.Format);
}
}
}
@@ -183,16 +184,12 @@ internal abstract class Serializer
internal sealed class NormalizedSerializer : Serializer
{
- private BinaryOrderedUdtNormalizer _normalizer;
- private bool _isFixedSize;
- private int _maxSize;
-
+ private readonly BinaryOrderedUdtNormalizer _normalizer;
+
internal NormalizedSerializer(Type t) : base(t)
{
- SqlUserDefinedTypeAttribute udtAttr = SerializationHelperSql9.GetUdtAttribute(t);
+ _ = SerializationHelperSql9.GetUdtAttribute(t);
_normalizer = new BinaryOrderedUdtNormalizer(t, true);
- _isFixedSize = udtAttr.IsFixedLength;
- _maxSize = _normalizer.Size;
}
public override void Serialize(Stream s, object o) => _normalizer.NormalizeTopObject(o, s);
@@ -209,7 +206,16 @@ internal BinarySerializeSerializer(Type t) : base(t)
public override void Serialize(Stream s, object o)
{
BinaryWriter w = new BinaryWriter(s);
+
+#if NETFRAMEWORK
+ if (o is SqlServer.Server.IBinarySerialize bs)
+ {
+ (bs).Write(w);
+ return;
+ }
+#endif
((IBinarySerialize)o).Write(w);
+
}
// Prevent inlining so that reflection calls are not moved
@@ -220,8 +226,17 @@ public override object Deserialize(Stream s)
{
object instance = Activator.CreateInstance(_type);
BinaryReader r = new BinaryReader(s);
- ((IBinarySerialize)instance).Read(r);
+
+#if NETFRAMEWORK
+ if (instance is SqlServer.Server.IBinarySerialize bs)
+ {
+ bs.Read(r);
+ return instance;
+ }
+#endif
+ ((IBinarySerialize)instance).Read(r);
return instance;
+
}
}
From 330de7652a41d9610d63ff0a752a37ff80e900f4 Mon Sep 17 00:00:00 2001
From: Johnny Pham
Date: Wed, 12 Jan 2022 12:07:25 -0800
Subject: [PATCH 002/123] change to ConcurrentDictionary (#1451)
---
.../src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs
index fbc2b50523..fdd2812d1b 100644
--- a/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs
+++ b/src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/EnclaveDelegate.Crypto.cs
@@ -10,7 +10,7 @@ namespace Microsoft.Data.SqlClient
{
internal sealed partial class EnclaveDelegate
{
- private static readonly Dictionary s_enclaveProviders = new Dictionary();
+ private static readonly ConcurrentDictionary s_enclaveProviders = new();
///
/// Create a new enclave session
From d9efa74a8934476236e531e12536b56e8e6cc2d1 Mon Sep 17 00:00:00 2001
From: Johnny Pham
Date: Thu, 13 Jan 2022 16:46:48 -0800
Subject: [PATCH 003/123] Test | Add lock when using
ClearSqlConnectionGlobalProvidersk (#1461)
---
.../ExceptionRegisterKeyStoreProvider.cs | 25 ++---
.../ExceptionsAlgorithmErrors.cs | 74 +++++++-------
...ncryptionCertificateStoreProviderShould.cs | 97 +++++++++----------
.../AlwaysEncryptedTests/Utility.cs | 4 +-
4 files changed, 101 insertions(+), 99 deletions(-)
diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionRegisterKeyStoreProvider.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionRegisterKeyStoreProvider.cs
index dfec766011..4c62013843 100644
--- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionRegisterKeyStoreProvider.cs
+++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionRegisterKeyStoreProvider.cs
@@ -85,21 +85,24 @@ public void TestEmptyProviderName()
[Fact]
public void TestCanSetGlobalProvidersOnlyOnce()
{
- Utility.ClearSqlConnectionGlobalProviders();
+ lock (Utility.ClearSqlConnectionGlobalProvidersLock)
+ {
+ Utility.ClearSqlConnectionGlobalProviders();
- IDictionary customProviders =
- new Dictionary()
- {
+ IDictionary customProviders =
+ new Dictionary()
+ {
{ DummyKeyStoreProvider.Name, new DummyKeyStoreProvider() }
- };
- SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders);
+ };
+ SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders);
- InvalidOperationException e = Assert.Throws(
- () => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
- string expectedMessage = SystemDataResourceManager.Instance.TCE_CanOnlyCallOnce;
- Assert.Contains(expectedMessage, e.Message);
+ InvalidOperationException e = Assert.Throws(
+ () => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
+ string expectedMessage = SystemDataResourceManager.Instance.TCE_CanOnlyCallOnce;
+ Assert.Contains(expectedMessage, e.Message);
- Utility.ClearSqlConnectionGlobalProviders();
+ Utility.ClearSqlConnectionGlobalProviders();
+ }
}
[Fact]
diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionsAlgorithmErrors.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionsAlgorithmErrors.cs
index 7395816fb1..e2d8e02b0b 100644
--- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionsAlgorithmErrors.cs
+++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionsAlgorithmErrors.cs
@@ -82,7 +82,7 @@ public void TestInvalidCipherText()
[PlatformSpecific(TestPlatforms.Windows)]
public void TestInvalidAlgorithmVersion()
{
- string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_InvalidAlgorithmVersion,
+ string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_InvalidAlgorithmVersion,
40, "01");
byte[] plainText = Encoding.Unicode.GetBytes("Hello World");
byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);
@@ -112,7 +112,7 @@ public void TestInvalidAuthenticationTag()
[PlatformSpecific(TestPlatforms.Windows)]
public void TestNullColumnEncryptionAlgorithm()
{
- string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_NullColumnEncryptionAlgorithm,
+ string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_NullColumnEncryptionAlgorithm,
"'AEAD_AES_256_CBC_HMAC_SHA256'");
Object cipherMD = GetSqlCipherMetadata(0, 0, null, 1, 0x01);
AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, "MSSQL_CERTIFICATE_STORE", "RSA_OAEP");
@@ -148,24 +148,27 @@ public void TestUnknownEncryptionAlgorithmId()
[PlatformSpecific(TestPlatforms.Windows)]
public void TestUnknownCustomKeyStoreProvider()
{
- // Clear out the existing providers (to ensure test reliability)
- ClearSqlConnectionGlobalProviders();
-
- const string invalidProviderName = "Dummy_Provider";
- string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_UnrecognizedKeyStoreProviderName,
- invalidProviderName, "'MSSQL_CERTIFICATE_STORE', 'MSSQL_CNG_STORE', 'MSSQL_CSP_PROVIDER'", "");
- Object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x03);
- AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, invalidProviderName, "RSA_OAEP");
- byte[] plainText = Encoding.Unicode.GetBytes("HelloWorld");
- byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);
+ lock (Utility.ClearSqlConnectionGlobalProvidersLock)
+ {
+ // Clear out the existing providers (to ensure test reliability)
+ ClearSqlConnectionGlobalProviders();
- Exception decryptEx = Assert.Throws(() => DecryptWithKey(plainText, cipherMD));
- Assert.Contains(expectedMessage, decryptEx.InnerException.Message);
+ const string invalidProviderName = "Dummy_Provider";
+ string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_UnrecognizedKeyStoreProviderName,
+ invalidProviderName, "'MSSQL_CERTIFICATE_STORE', 'MSSQL_CNG_STORE', 'MSSQL_CSP_PROVIDER'", "");
+ Object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x03);
+ AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, invalidProviderName, "RSA_OAEP");
+ byte[] plainText = Encoding.Unicode.GetBytes("HelloWorld");
+ byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);
- Exception encryptEx = Assert.Throws(() => EncryptWithKey(plainText, cipherMD));
- Assert.Contains(expectedMessage, encryptEx.InnerException.Message);
+ Exception decryptEx = Assert.Throws(() => DecryptWithKey(plainText, cipherMD));
+ Assert.Contains(expectedMessage, decryptEx.InnerException.Message);
+
+ Exception encryptEx = Assert.Throws(() => EncryptWithKey(plainText, cipherMD));
+ Assert.Contains(expectedMessage, encryptEx.InnerException.Message);
- ClearSqlConnectionGlobalProviders();
+ ClearSqlConnectionGlobalProviders();
+ }
}
[Fact]
@@ -173,7 +176,7 @@ public void TestUnknownCustomKeyStoreProvider()
public void TestTceUnknownEncryptionAlgorithm()
{
const string unknownEncryptionAlgorithm = "Dummy";
- string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_UnknownColumnEncryptionAlgorithm,
+ string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_UnknownColumnEncryptionAlgorithm,
unknownEncryptionAlgorithm, "'AEAD_AES_256_CBC_HMAC_SHA256'");
Object cipherMD = GetSqlCipherMetadata(0, 0, "Dummy", 1, 0x01);
AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, "MSSQL_CERTIFICATE_STORE", "RSA_OAEP");
@@ -193,7 +196,7 @@ public void TestExceptionsFromCertStore()
{
byte[] corruptedCek = GenerateInvalidEncryptedCek(CertFixture.cek, ECEKCorruption.SIGNATURE);
- string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_KeyDecryptionFailedCertStore,
+ string expectedMessage = string.Format(SystemDataResourceManager.Instance.TCE_KeyDecryptionFailedCertStore,
"MSSQL_CERTIFICATE_STORE", BitConverter.ToString(corruptedCek, corruptedCek.Length - 10, 10));
Object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x01);
@@ -209,27 +212,30 @@ public void TestExceptionsFromCertStore()
[PlatformSpecific(TestPlatforms.Windows)]
public void TestExceptionsFromCustomKeyStore()
{
- string expectedMessage = "Failed to decrypt a column encryption key";
+ lock (Utility.ClearSqlConnectionGlobalProvidersLock)
+ {
+ string expectedMessage = "Failed to decrypt a column encryption key";
- // Clear out the existing providers (to ensure test reliability)
- ClearSqlConnectionGlobalProviders();
+ // Clear out the existing providers (to ensure test reliability)
+ ClearSqlConnectionGlobalProviders();
- IDictionary customProviders = new Dictionary();
- customProviders.Add(DummyKeyStoreProvider.Name, new DummyKeyStoreProvider());
- SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders);
+ IDictionary customProviders = new Dictionary();
+ customProviders.Add(DummyKeyStoreProvider.Name, new DummyKeyStoreProvider());
+ SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders);
- object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x01);
- AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, "DummyProvider", "DummyAlgo");
- byte[] plainText = Encoding.Unicode.GetBytes("HelloWorld");
- byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);
+ object cipherMD = GetSqlCipherMetadata(0, 1, null, 1, 0x01);
+ AddEncryptionKeyToCipherMD(cipherMD, CertFixture.encryptedCek, 0, 0, 0, new byte[] { 0x01, 0x02, 0x03 }, CertFixture.certificatePath, "DummyProvider", "DummyAlgo");
+ byte[] plainText = Encoding.Unicode.GetBytes("HelloWorld");
+ byte[] cipherText = EncryptDataUsingAED(plainText, CertFixture.cek, CColumnEncryptionType.Deterministic);
- Exception decryptEx = Assert.Throws(() => DecryptWithKey(cipherText, cipherMD));
- Assert.Contains(expectedMessage, decryptEx.InnerException.Message);
+ Exception decryptEx = Assert.Throws(() => DecryptWithKey(cipherText, cipherMD));
+ Assert.Contains(expectedMessage, decryptEx.InnerException.Message);
- Exception encryptEx = Assert.Throws(() => EncryptWithKey(cipherText, cipherMD));
- Assert.Contains(expectedMessage, encryptEx.InnerException.Message);
+ Exception encryptEx = Assert.Throws(() => EncryptWithKey(cipherText, cipherMD));
+ Assert.Contains(expectedMessage, encryptEx.InnerException.Message);
- ClearSqlConnectionGlobalProviders();
+ ClearSqlConnectionGlobalProviders();
+ }
}
}
diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/SqlColumnEncryptionCertificateStoreProviderShould.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/SqlColumnEncryptionCertificateStoreProviderShould.cs
index 54dd6bc6be..b0c6297cda 100644
--- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/SqlColumnEncryptionCertificateStoreProviderShould.cs
+++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/SqlColumnEncryptionCertificateStoreProviderShould.cs
@@ -92,11 +92,6 @@ public class SqlColumnEncryptionCertificateStoreProviderWindowsShould : IClassFi
///
private const int CipherTextStartIndex = IVStartIndex + IVLengthInBytes;
- ///
- /// SetCustomColumnEncryptionKeyStoreProvider can be called only once in a process. To workaround that, we use this flag.
- ///
- private static bool s_testCustomEncryptioKeyStoreProviderExecutedOnce = false;
-
[Theory]
[InvalidDecryptionParameters]
[PlatformSpecific(TestPlatforms.Windows)]
@@ -326,55 +321,51 @@ public void TestAeadEncryptionReversal(string dataType, object data, Utility.CCo
[PlatformSpecific(TestPlatforms.Windows)]
public void TestCustomKeyProviderListSetter()
{
- // SqlConnection.RegisterColumnEncryptionKeyStoreProviders can be called only once in a process.
- // This is a workaround to ensure re-runnability of the test.
- if (s_testCustomEncryptioKeyStoreProviderExecutedOnce)
+ lock (Utility.ClearSqlConnectionGlobalProvidersLock)
{
- return;
+ string expectedMessage1 = "Column encryption key store provider dictionary cannot be null. Expecting a non-null value.";
+ // Verify that we are able to set it to null.
+ ArgumentException e1 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(null));
+ Assert.Contains(expectedMessage1, e1.Message);
+
+ // A dictionary holding custom providers.
+ IDictionary customProviders = new Dictionary();
+ customProviders.Add(new KeyValuePair(@"DummyProvider", new DummyKeyStoreProvider()));
+
+ // Verify that setting a provider in the list with null value throws an exception.
+ customProviders.Add(new KeyValuePair(@"CustomProvider", null));
+ string expectedMessage2 = "Null reference specified for key store provider 'CustomProvider'. Expecting a non-null value.";
+ ArgumentNullException e2 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
+ Assert.Contains(expectedMessage2, e2.Message);
+ customProviders.Remove(@"CustomProvider");
+
+ // Verify that setting a provider in the list with an empty provider name throws an exception.
+ customProviders.Add(new KeyValuePair(@"", new DummyKeyStoreProvider()));
+ string expectedMessage3 = "Invalid key store provider name specified. Key store provider names cannot be null or empty";
+ ArgumentNullException e3 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
+ Assert.Contains(expectedMessage3, e3.Message);
+
+ customProviders.Remove(@"");
+
+ // Verify that setting a provider in the list with name that starts with 'MSSQL_' throws an exception.
+ customProviders.Add(new KeyValuePair(@"MSSQL_MyStore", new SqlColumnEncryptionCertificateStoreProvider()));
+ string expectedMessage4 = "Invalid key store provider name 'MSSQL_MyStore'. 'MSSQL_' prefix is reserved for system key store providers.";
+ ArgumentException e4 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
+ Assert.Contains(expectedMessage4, e4.Message);
+
+ customProviders.Remove(@"MSSQL_MyStore");
+
+ // Verify that setting a provider in the list with name that starts with 'MSSQL_' but different case throws an exception.
+ customProviders.Add(new KeyValuePair(@"MsSqL_MyStore", new SqlColumnEncryptionCertificateStoreProvider()));
+ string expectedMessage5 = "Invalid key store provider name 'MsSqL_MyStore'. 'MSSQL_' prefix is reserved for system key store providers.";
+ ArgumentException e5 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
+ Assert.Contains(expectedMessage5, e5.Message);
+
+ customProviders.Remove(@"MsSqL_MyStore");
+
+ // Clear any providers set by other tests.
+ Utility.ClearSqlConnectionGlobalProviders();
}
-
- string expectedMessage1 = "Column encryption key store provider dictionary cannot be null. Expecting a non-null value.";
- // Verify that we are able to set it to null.
- ArgumentException e1 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(null));
- Assert.Contains(expectedMessage1, e1.Message);
-
- // A dictionary holding custom providers.
- IDictionary customProviders = new Dictionary();
- customProviders.Add(new KeyValuePair(@"DummyProvider", new DummyKeyStoreProvider()));
-
- // Verify that setting a provider in the list with null value throws an exception.
- customProviders.Add(new KeyValuePair(@"CustomProvider", null));
- string expectedMessage2 = "Null reference specified for key store provider 'CustomProvider'. Expecting a non-null value.";
- ArgumentNullException e2 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
- Assert.Contains(expectedMessage2, e2.Message);
- customProviders.Remove(@"CustomProvider");
-
- // Verify that setting a provider in the list with an empty provider name throws an exception.
- customProviders.Add(new KeyValuePair(@"", new DummyKeyStoreProvider()));
- string expectedMessage3 = "Invalid key store provider name specified. Key store provider names cannot be null or empty";
- ArgumentNullException e3 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
- Assert.Contains(expectedMessage3, e3.Message);
-
- customProviders.Remove(@"");
-
- // Verify that setting a provider in the list with name that starts with 'MSSQL_' throws an exception.
- customProviders.Add(new KeyValuePair(@"MSSQL_MyStore", new SqlColumnEncryptionCertificateStoreProvider()));
- string expectedMessage4 = "Invalid key store provider name 'MSSQL_MyStore'. 'MSSQL_' prefix is reserved for system key store providers.";
- ArgumentException e4 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
- Assert.Contains(expectedMessage4, e4.Message);
-
- customProviders.Remove(@"MSSQL_MyStore");
-
- // Verify that setting a provider in the list with name that starts with 'MSSQL_' but different case throws an exception.
- customProviders.Add(new KeyValuePair(@"MsSqL_MyStore", new SqlColumnEncryptionCertificateStoreProvider()));
- string expectedMessage5 = "Invalid key store provider name 'MsSqL_MyStore'. 'MSSQL_' prefix is reserved for system key store providers.";
- ArgumentException e5 = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
- Assert.Contains(expectedMessage5, e5.Message);
-
- customProviders.Remove(@"MsSqL_MyStore");
-
- // Clear any providers set by other tests.
- Utility.ClearSqlConnectionGlobalProviders();
}
[Theory]
@@ -502,7 +493,7 @@ public class CEKEncryptionReversalParameters : DataAttribute
{
public override IEnumerable