From 6c03bbdfd2d1837e590201618f08d439a38c9335 Mon Sep 17 00:00:00 2001 From: Johnny Pham <23270162+johnnypham@users.noreply.github.com> Date: Thu, 22 Apr 2021 15:54:38 -0700 Subject: [PATCH] add RegisterColumnEncryptionKeyStoreProvidersOnConnection --- .../SqlConnection.xml | 18 +++++++ .../netcore/ref/Microsoft.Data.SqlClient.cs | 4 +- .../Microsoft/Data/SqlClient/SqlConnection.cs | 35 ++++++++++++ .../netfx/ref/Microsoft.Data.SqlClient.cs | 2 + .../Microsoft/Data/SqlClient/SqlConnection.cs | 32 +++++++++++ .../ExceptionRegisterKeyStoreProvider.cs | 54 +++++++++++++++++++ .../ManualTests/AlwaysEncrypted/ApiShould.cs | 22 ++++++++ 7 files changed, 166 insertions(+), 1 deletion(-) diff --git a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml index 4cc4448f8b..7eee99980d 100644 --- a/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml +++ b/doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml @@ -1052,6 +1052,24 @@ GO This function was called more than once. + + Dictionary of custom column encryption key providers + Registers the encryption key store providers on the instance. If this function has been called, any providers registered using the static methods will be ignored. This function can be called more than once. This does shallow copying of the dictionary so that the app cannot alter the custom provider list once it has been set. + + A null dictionary was provided. + + -or- + + A string key in the dictionary was null or empty. + + -or- + + An EncryptionKeyStoreProvider value in the dictionary was null. + + + A string key in the dictionary started with "MSSQL_". This prefix is reserved for system providers. + + Gets or sets a value that specifies the diff --git a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs index b6956df25b..e630a03acb 100644 --- a/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netcore/ref/Microsoft.Data.SqlClient.cs @@ -689,7 +689,9 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden public static System.Collections.Generic.IDictionary> ColumnEncryptionTrustedMasterKeyPaths { get { throw null; } } /// public static void RegisterColumnEncryptionKeyStoreProviders(System.Collections.Generic.IDictionary customProviders) { } - /// + /// + public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collections.Generic.IDictionary customProviders) { } + /// [System.ComponentModel.BrowsableAttribute(false)] [System.ComponentModel.DesignerSerializationVisibilityAttribute(0)] public string AccessToken { get { throw null; } set { } } diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs index 33fd700673..7cdd3b56d5 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -91,6 +91,11 @@ private static readonly Dictionary /// private static IReadOnlyDictionary s_globalCustomColumnEncryptionKeyStoreProviders; + /// + /// Per-connection custom providers. It can be provided by the user and can be set more than once. + /// + private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders; + /// /// Dictionary object holding trusted key paths for various SQL Servers. /// Key to the dictionary is a SQL Server Name @@ -234,6 +239,13 @@ internal static bool TryGetColumnEncryptionKeyStoreProvider(string providerName, return true; } + // instance-level custom provider cache takes precedence over global cache + if (connection._customColumnEncryptionKeyStoreProviders != null && + connection._customColumnEncryptionKeyStoreProviders.Count > 0) + { + return connection._customColumnEncryptionKeyStoreProviders.TryGetValue(providerName, out columnKeyStoreProvider); + } + lock (s_globalCustomColumnEncryptionKeyProvidersLock) { // If custom provider is not set, then return false @@ -264,6 +276,11 @@ internal static List GetColumnEncryptionSystemKeyStoreProviders() /// Combined list of provider names internal static List GetColumnEncryptionCustomKeyStoreProviders(SqlConnection connection) { + if (connection._customColumnEncryptionKeyStoreProviders != null && + connection._customColumnEncryptionKeyStoreProviders.Count > 0) + { + return connection._customColumnEncryptionKeyStoreProviders.Keys.ToList(); + } if (s_globalCustomColumnEncryptionKeyStoreProviders != null) { return s_globalCustomColumnEncryptionKeyStoreProviders.Keys.ToList(); @@ -306,6 +323,24 @@ public static void RegisterColumnEncryptionKeyStoreProviders(IDictionary + public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary customProviders) + { + ValidateCustomProviders(customProviders); + + // Create a temporary dictionary and then add items from the provided dictionary. + // Dictionary constructor does shallow copying by simply copying the provider name and provider reference pairs + // in the provided customerProviders dictionary. + Dictionary customColumnEncryptionKeyStoreProviders = + new Dictionary(customProviders, StringComparer.OrdinalIgnoreCase); + + // Set the dictionary to the ReadOnly dictionary. + // This method can be called more than once. Re-registering a new collection will replace the + // old collection of providers. + _customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders; + } + private static void ValidateCustomProviders(IDictionary customProviders) { // Throw when the provided dictionary is null. diff --git a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs index 16227cc0c8..23c88ff5b5 100644 --- a/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs +++ b/src/Microsoft.Data.SqlClient/netfx/ref/Microsoft.Data.SqlClient.cs @@ -850,6 +850,8 @@ public void Open(SqlConnectionOverrides overrides) { } public override System.Threading.Tasks.Task OpenAsync(System.Threading.CancellationToken cancellationToken) { throw null; } /// public static void RegisterColumnEncryptionKeyStoreProviders(System.Collections.Generic.IDictionary customProviders) { } + /// + public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collections.Generic.IDictionary customProviders) { } /// public void ResetStatistics() { } /// diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs index a7b0c5cfd4..c0dd91e2d7 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs @@ -64,6 +64,9 @@ static private readonly Dictionary /// private static IReadOnlyDictionary s_globalCustomColumnEncryptionKeyStoreProviders; + /// Instance-level list of custom key store providers. It can be set more than once by the user. + private IReadOnlyDictionary _customColumnEncryptionKeyStoreProviders; + // Lock to control setting of s_globalCustomColumnEncryptionKeyStoreProviders private static readonly object s_globalCustomColumnEncryptionKeyProvidersLock = new object(); @@ -161,6 +164,23 @@ static public void RegisterColumnEncryptionKeyStoreProviders(IDictionary + public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary customProviders) + { + ValidateCustomProviders(customProviders); + + // Create a temporary dictionary and then add items from the provided dictionary. + // Dictionary constructor does shallow copying by simply copying the provider name and provider reference pairs + // in the provided customerProviders dictionary. + Dictionary customColumnEncryptionKeyStoreProviders = + new Dictionary(customProviders, StringComparer.OrdinalIgnoreCase); + + // Set the dictionary to the ReadOnly dictionary. + // This method can be called more than once. Re-registering a new collection will replace the + // old collection of providers. + _customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders; + } + private static void ValidateCustomProviders(IDictionary customProviders) { // Throw when the provided dictionary is null. @@ -213,6 +233,13 @@ static internal bool TryGetColumnEncryptionKeyStoreProvider(string providerName, return true; } + // instance-level custom provider cache takes precedence over global cache + if (connection._customColumnEncryptionKeyStoreProviders != null && + connection._customColumnEncryptionKeyStoreProviders.Count > 0) + { + return connection._customColumnEncryptionKeyStoreProviders.TryGetValue(providerName, out columnKeyStoreProvider); + } + lock (s_globalCustomColumnEncryptionKeyProvidersLock) { // If custom provider is not set, then return false @@ -243,6 +270,11 @@ internal static List GetColumnEncryptionSystemKeyStoreProviders() /// Combined list of provider names internal static List GetColumnEncryptionCustomKeyStoreProviders(SqlConnection connection) { + if (connection._customColumnEncryptionKeyStoreProviders != null && + connection._customColumnEncryptionKeyStoreProviders.Count > 0) + { + return connection._customColumnEncryptionKeyStoreProviders.Keys.ToList(); + } if (s_globalCustomColumnEncryptionKeyStoreProviders != null) { return s_globalCustomColumnEncryptionKeyStoreProviders.Keys.ToList(); diff --git a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionRegisterKeyStoreProvider.cs b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionRegisterKeyStoreProvider.cs index 65f875f9ce..d703d4f748 100644 --- a/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionRegisterKeyStoreProvider.cs +++ b/src/Microsoft.Data.SqlClient/tests/FunctionalTests/AlwaysEncryptedTests/ExceptionRegisterKeyStoreProvider.cs @@ -22,6 +22,9 @@ public void TestNullDictionary() ArgumentNullException e = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders)); Assert.Contains(expectedMessage, e.Message); + + e = Assert.Throws(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders)); + Assert.Contains(expectedMessage, e.Message); } [Fact] @@ -35,6 +38,9 @@ public void TestInvalidProviderName() ArgumentException e = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders)); Assert.Contains(expectedMessage, e.Message); + + e = Assert.Throws(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders)); + Assert.Contains(expectedMessage, e.Message); } [Fact] @@ -48,6 +54,9 @@ public void TestNullProviderValue() ArgumentNullException e = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders)); Assert.Contains(expectedMessage, e.Message); + + e = Assert.Throws(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders)); + Assert.Contains(expectedMessage, e.Message); } [Fact] @@ -60,6 +69,9 @@ public void TestEmptyProviderName() ArgumentNullException e = Assert.Throws(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders)); Assert.Contains(expectedMessage, e.Message); + + e = Assert.Throws(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders)); + Assert.Contains(expectedMessage, e.Message); } [Fact] @@ -81,5 +93,47 @@ public void TestCanSetGlobalProvidersOnlyOnce() Utility.ClearSqlConnectionGlobalProviders(); } + + [Fact] + public void TestCanSetInstanceProvidersMoreThanOnce() + { + const string dummyProviderName1 = "DummyProvider1"; + const string dummyProviderName2 = "DummyProvider2"; + const string dummyProviderName3 = "DummyProvider3"; + IDictionary singleKeyStoreProvider = + new Dictionary() + { + {dummyProviderName1, new DummyKeyStoreProvider() } + }; + + IDictionary multipleKeyStoreProviders = + new Dictionary() + { + { dummyProviderName2, new DummyKeyStoreProvider() }, + { dummyProviderName3, new DummyKeyStoreProvider() } + }; + + using (SqlConnection connection = new SqlConnection()) + { + connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(singleKeyStoreProvider); + IReadOnlyDictionary instanceCache = + GetInstanceCacheFromConnection(connection); + Assert.Single(instanceCache); + Assert.True(instanceCache.ContainsKey(dummyProviderName1)); + + connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(multipleKeyStoreProviders); + instanceCache = GetInstanceCacheFromConnection(connection); + Assert.Equal(2, instanceCache.Count); + Assert.True(instanceCache.ContainsKey(dummyProviderName2)); + Assert.True(instanceCache.ContainsKey(dummyProviderName3)); + } + + IReadOnlyDictionary GetInstanceCacheFromConnection(SqlConnection conn) + { + FieldInfo instanceCacheField = conn.GetType().GetField( + "_customColumnEncryptionKeyStoreProviders", BindingFlags.NonPublic | BindingFlags.Instance); + return instanceCacheField.GetValue(conn) as IReadOnlyDictionary; + } + } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs index b5cc271605..30e24f8960 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ApiShould.cs @@ -2160,6 +2160,28 @@ public void TestCustomKeyStoreProviderDuringAeQuery(string connectionString) () => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection)); Assert.Contains(failedToDecryptMessage, ex.Message); Assert.True(ex.InnerException is NotImplementedException); + + // not required provider in instance cache + // it should not fall back to the global cache so the right provider will not be found + connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(notRequiredProvider); + ex = Assert.Throws( + () => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection)); + Assert.Equal(providerNotFoundMessage, ex.Message); + + // required provider in instance cache + // if the instance cache is not empty, it is always checked for the provider. + // => if the provider is found, it must have been retrieved from the instance cache and not the global cache + connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(requiredProvider); + ex = Assert.Throws( + () => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection)); + Assert.Contains(failedToDecryptMessage, ex.Message); + Assert.True(ex.InnerException is NotImplementedException); + + // not required provider will replace the previous entry so required provider will not be found + connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(notRequiredProvider); + ex = Assert.Throws( + () => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection)); + Assert.Equal(providerNotFoundMessage, ex.Message); } void ExecuteQueryThatRequiresCustomKeyStoreProvider(SqlConnection connection)