diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ConversionTests.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ConversionTests.cs new file mode 100644 index 0000000000..26dcf4ef36 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/ConversionTests.cs @@ -0,0 +1,1429 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information.using System; + +using System; +using System.Collections.Generic; +using System.Data; +using System.Data.SqlTypes; +using System.Diagnostics; +using System.Linq; +using System.Reflection; +using System.Text; +using System.Security.Cryptography.X509Certificates; +using Xunit; +using Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted.Setup; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted +{ + [PlatformSpecific(TestPlatforms.Windows)] + public class ConversionTests : IDisposable + { + + private const string IdentityColumnName = "IdentityColumn"; + private const string FirstColumnName = "Column1"; + private const string FirstParamName = "@Param1"; + private const string ColumnEncryptionAlgorithmName = @"AEAD_AES_256_CBC_HMAC_SHA_256"; + private const decimal SmallMoneyMaxValue = 214748.3647M; + private const decimal SmallMoneyMinValue = -214748.3648M; + private const int MaxLength = 10000; + private const int NumberOfRows = 100; + private readonly X509Certificate2 certificate; + private ColumnMasterKey columnMasterKey; + private ColumnEncryptionKey columnEncryptionKey; + private SqlColumnEncryptionCertificateStoreProvider certStoreProvider = new SqlColumnEncryptionCertificateStoreProvider(); + protected List databaseObjects = new List(); + + private class ColumnMetaData + { + public ColumnMetaData(SqlDbType columnType, int columnSize, int precision, int scale, bool useMax) + { + ColumnType = columnType; + ColumnSize = columnSize; + Precision = precision; + Scale = scale; + UseMax = useMax; + } + + public SqlDbType ColumnType { get; set; } + public int ColumnSize { get; set; } + public int Precision { get; set; } + public int Scale { get; set; } + public bool UseMax { get; set; } + } + + public ConversionTests() + { + certificate = CertificateUtility.CreateCertificate(); + columnMasterKey = new CspColumnMasterKey(DatabaseHelper.GenerateUniqueName("CMK"), certificate.Thumbprint); + databaseObjects.Add(columnMasterKey); + + columnEncryptionKey = new ColumnEncryptionKey(DatabaseHelper.GenerateUniqueName("CEK"), + columnMasterKey, + certStoreProvider); + databaseObjects.Add(columnEncryptionKey); + + using (SqlConnection sqlConnection = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnection.Open(); + databaseObjects.ForEach(o => o.Create(sqlConnection)); + } + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [InlineData(SqlDbType.SmallMoney, SqlDbType.Money)] + [InlineData(SqlDbType.Bit, SqlDbType.TinyInt)] + [InlineData(SqlDbType.Bit, SqlDbType.SmallInt)] + [InlineData(SqlDbType.Bit, SqlDbType.Int)] + [InlineData(SqlDbType.Bit, SqlDbType.BigInt)] + [InlineData(SqlDbType.TinyInt, SqlDbType.SmallInt)] + [InlineData(SqlDbType.TinyInt, SqlDbType.Int)] + [InlineData(SqlDbType.TinyInt, SqlDbType.BigInt)] + [InlineData(SqlDbType.SmallInt, SqlDbType.Int)] + [InlineData(SqlDbType.SmallInt, SqlDbType.BigInt)] + [InlineData(SqlDbType.Int, SqlDbType.BigInt)] + [InlineData(SqlDbType.Binary, SqlDbType.Binary)] + [InlineData(SqlDbType.Binary, SqlDbType.VarBinary)] + [InlineData(SqlDbType.VarBinary, SqlDbType.Binary)] + [InlineData(SqlDbType.VarBinary, SqlDbType.VarBinary)] + [InlineData(SqlDbType.Char, SqlDbType.Char)] + [InlineData(SqlDbType.Char, SqlDbType.VarChar)] // padding whitespace issue, trimEnd for now + [InlineData(SqlDbType.VarChar, SqlDbType.Char)] + [InlineData(SqlDbType.VarChar, SqlDbType.VarChar)] + [InlineData(SqlDbType.NChar, SqlDbType.NChar)] + [InlineData(SqlDbType.NChar, SqlDbType.NVarChar)] + [InlineData(SqlDbType.NVarChar, SqlDbType.NChar)] + [InlineData(SqlDbType.NVarChar, SqlDbType.NVarChar)] + [InlineData(SqlDbType.Time, SqlDbType.Time)] + [InlineData(SqlDbType.DateTime2, SqlDbType.DateTime2)] + [InlineData(SqlDbType.DateTimeOffset, SqlDbType.DateTimeOffset)] + [InlineData(SqlDbType.Float, SqlDbType.Float)] + [InlineData(SqlDbType.Real, SqlDbType.Real)] + public void ConversionSmallerToLargerInsertAndSelect(SqlDbType smallDbType, SqlDbType largeDbType) + { + ColumnMetaData largeColumnInfo = new ColumnMetaData(largeDbType, 0, 1, 1, false); + ColumnMetaData smallColumnInfo = new ColumnMetaData(smallDbType, 0, 1, 1, false); + + // Adjust the size, precision and scale for data types that have one. + AdjustSizePrecisionAndScale(ref largeColumnInfo, ref smallColumnInfo); + + // Create the encrypted and unencrypted table with the proper column types. + string encryptedTableName = DatabaseHelper.GenerateUniqueName("encrypted"); + string unencryptedTableName = DatabaseHelper.GenerateUniqueName("unencrypted"); + + // Create the encrypted and unencrypted table with the proper column types. + CreateTable(largeColumnInfo, encryptedTableName, isEncrypted: true); + CreateTable(largeColumnInfo, unencryptedTableName, isEncrypted: false); + + // Insert data using the smaller type to the tables with the large type. + object[] rawValues = PopulateTablesAndReturnRandomValue(encryptedTableName, unencryptedTableName, smallColumnInfo); + + // Keep the values from unencryptedTable other than the rawValues to perform a select later for DateTime2 and DateTimeOffset. + object[] valuesToSelect = RetriveDataFromDatabase(unencryptedTableName); + + // Now read back everything and make sure the values and types are identical. + CompareTables(encryptedTableName, unencryptedTableName); + + // Now send a query with a predicate using the larger type and confirm that the row that was inserted with the smaller type can still be found. + using (SqlConnection sqlConnectionEncrypted = new SqlConnection(DataTestUtility.TcpConnStr)) + using (SqlConnection sqlConnectionUnencrypted = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnectionEncrypted.Open(); + sqlConnectionUnencrypted.Open(); + + try + { + // Select each value we just inserted with a predicate and verify that encrypted and unencrypted return the same result. + for (int i = 0; i < NumberOfRows; i++) + { + object value; + + // Use the retrieved values for DateTime2 and DateTimeOffset due to fractional insertion adjustment + if (smallColumnInfo.ColumnType is SqlDbType.DateTime2 || smallColumnInfo.ColumnType is SqlDbType.DateTimeOffset) + { + value = valuesToSelect[i]; + } + else + { + value = rawValues[i]; + } + + using (SqlCommand cmdEncrypted = new SqlCommand(string.Format(@"SELECT {0} FROM [{1}] WHERE {0} = {2}", FirstColumnName, encryptedTableName, FirstParamName), sqlConnectionEncrypted, null, SqlCommandColumnEncryptionSetting.Enabled)) + using (SqlCommand cmdUnencrypted = new SqlCommand(string.Format(@"SELECT {0} FROM [{1}] WHERE {0} = {2}", FirstColumnName, unencryptedTableName, FirstParamName), sqlConnectionUnencrypted, null, SqlCommandColumnEncryptionSetting.Disabled)) + { + SqlParameter paramEncrypted = new SqlParameter(); + paramEncrypted.ParameterName = FirstParamName; + paramEncrypted.SqlDbType = largeDbType; + SetParamSizeScalePrecision(ref paramEncrypted, largeColumnInfo); + paramEncrypted.Value = value; + cmdEncrypted.Parameters.Add(paramEncrypted); + + SqlParameter paramUnencrypted = new SqlParameter(); + paramUnencrypted.ParameterName = FirstParamName; + paramUnencrypted.SqlDbType = largeDbType; + SetParamSizeScalePrecision(ref paramUnencrypted, largeColumnInfo); + paramUnencrypted.Value = value; + cmdUnencrypted.Parameters.Add(paramUnencrypted); + + using (SqlDataReader readerUnencrypted = cmdUnencrypted.ExecuteReader()) + using (SqlDataReader readerEncrypted = cmdEncrypted.ExecuteReader()) + { + // First check that we found some rows. + Assert.True(readerEncrypted.HasRows, @"We didn't find any rows."); + + // Now compare the result. + CompareResults(readerEncrypted, readerUnencrypted); + } + } + } + } + finally + { + // DropTables + DropTableIfExists(sqlConnectionEncrypted, encryptedTableName); + DropTableIfExists(sqlConnectionUnencrypted, unencryptedTableName); + } + } + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [InlineData(SqlDbType.SmallMoney, SqlDbType.Money)] + [InlineData(SqlDbType.Bit, SqlDbType.TinyInt)] + [InlineData(SqlDbType.Bit, SqlDbType.SmallInt)] + [InlineData(SqlDbType.Bit, SqlDbType.Int)] + [InlineData(SqlDbType.Bit, SqlDbType.BigInt)] + [InlineData(SqlDbType.TinyInt, SqlDbType.SmallInt)] + [InlineData(SqlDbType.TinyInt, SqlDbType.Int)] + [InlineData(SqlDbType.TinyInt, SqlDbType.BigInt)] + [InlineData(SqlDbType.SmallInt, SqlDbType.Int)] + [InlineData(SqlDbType.SmallInt, SqlDbType.BigInt)] + [InlineData(SqlDbType.Int, SqlDbType.BigInt)] + [InlineData(SqlDbType.Binary, SqlDbType.Binary)] + [InlineData(SqlDbType.Binary, SqlDbType.VarBinary)] + [InlineData(SqlDbType.VarBinary, SqlDbType.Binary)] + [InlineData(SqlDbType.VarBinary, SqlDbType.VarBinary)] + [InlineData(SqlDbType.Char, SqlDbType.Char)] // padding whitespace issue + [InlineData(SqlDbType.Char, SqlDbType.VarChar)] // padding whitespace issue + [InlineData(SqlDbType.VarChar, SqlDbType.Char)] + [InlineData(SqlDbType.VarChar, SqlDbType.VarChar)] + [InlineData(SqlDbType.NChar, SqlDbType.NChar)] + [InlineData(SqlDbType.NChar, SqlDbType.NVarChar)] + [InlineData(SqlDbType.NVarChar, SqlDbType.NChar)] + [InlineData(SqlDbType.NVarChar, SqlDbType.NVarChar)] + [InlineData(SqlDbType.Time, SqlDbType.Time)] + [InlineData(SqlDbType.DateTime2, SqlDbType.DateTime2)] + [InlineData(SqlDbType.DateTimeOffset, SqlDbType.DateTimeOffset)] + [InlineData(SqlDbType.Float, SqlDbType.Float)] + [InlineData(SqlDbType.Real, SqlDbType.Real)] + public void ConversionSmallerToLargerInsertAndSelectBulk(SqlDbType smallDbType, SqlDbType largeDbType) + { + ColumnMetaData largeColumnInfo = new ColumnMetaData(largeDbType, 0, 1, 1, false); + ColumnMetaData smallColumnInfo = new ColumnMetaData(smallDbType, 0, 1, 1, false); + + // Adjust the size, precision and scale for data types that have one. + AdjustSizePrecisionAndScale(ref largeColumnInfo, ref smallColumnInfo); + + string originTableName = DatabaseHelper.GenerateUniqueName("small_type_pt"); + string targetTableName = DatabaseHelper.GenerateUniqueName("large_type_enc"); + string witnessTableName = DatabaseHelper.GenerateUniqueName("large_type_pt"); + + // Create the encrypted and unencrypted table with the proper column types. + CreateTable(smallColumnInfo, originTableName, isEncrypted: false); + CreateTable(largeColumnInfo, targetTableName, isEncrypted: true); + CreateTable(largeColumnInfo, witnessTableName, isEncrypted: false); + + // Insert data using the smaller type to the tables with the large type. + // Also keep the values on the side to perform a select later. + object[] rawValues = PopulateTablesAndReturnRandomValuePlaintextOnly(originTableName, smallColumnInfo); + + // Keep the values from originTable other than the rawValues to perform a select later for DateTime2 and DateTimeOffset. + object[] valuesToSelect = RetriveDataFromDatabase(originTableName); + + // populate the witness table & the target table using bulk insert + portDataToTablePairViaBulkCopy(originTableName, SqlConnectionColumnEncryptionSetting.Disabled, targetTableName, SqlConnectionColumnEncryptionSetting.Enabled); + portDataToTablePairViaBulkCopy(originTableName, SqlConnectionColumnEncryptionSetting.Disabled, witnessTableName, SqlConnectionColumnEncryptionSetting.Disabled); + + // Now read back everything and make sure the values and types are identical. + CompareTables(targetTableName, witnessTableName); + + // Now send a query with a predicate using the larger type and confirm that the row that was inserted with the smaller type can still be found. + using (SqlConnection sqlConnectionEncrypted = new SqlConnection(DataTestUtility.TcpConnStr)) + using (SqlConnection sqlConnectionUnencrypted = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnectionEncrypted.Open(); + sqlConnectionUnencrypted.Open(); + + try + { + // Select each value we just inserted with a predicate and verify that encrypted and unencrypted return the same result. + for (int i = 0; i < NumberOfRows; i++) + { + object value; + + // Use the retrieved values for DateTime2 and DateTimeOffset due to fractional insertion adjustment + if (smallColumnInfo.ColumnType is SqlDbType.DateTime2 || + smallColumnInfo.ColumnType is SqlDbType.DateTimeOffset || + smallColumnInfo.ColumnType is SqlDbType.Char || + smallColumnInfo.ColumnType is SqlDbType.NChar) + { + value = valuesToSelect[i]; + } + else + { + value = rawValues[i]; + } + + using (SqlCommand cmdEncrypted = new SqlCommand(string.Format(@"SELECT {0} FROM [{1}] WHERE {0} = {2}", FirstColumnName, targetTableName, FirstParamName), sqlConnectionEncrypted, null, SqlCommandColumnEncryptionSetting.Enabled)) + using (SqlCommand cmdUnencrypted = new SqlCommand(string.Format(@"SELECT {0} FROM [{1}] WHERE {0} = {2}", FirstColumnName, witnessTableName, FirstParamName), sqlConnectionUnencrypted, null, SqlCommandColumnEncryptionSetting.Disabled)) + { + SqlParameter paramEncrypted = new SqlParameter(); + paramEncrypted.ParameterName = FirstParamName; + paramEncrypted.SqlDbType = largeDbType; + SetParamSizeScalePrecision(ref paramEncrypted, largeColumnInfo); + paramEncrypted.Value = value; + cmdEncrypted.Parameters.Add(paramEncrypted); + + SqlParameter paramUnencrypted = new SqlParameter(); + paramUnencrypted.ParameterName = FirstParamName; + paramUnencrypted.SqlDbType = largeDbType; + SetParamSizeScalePrecision(ref paramUnencrypted, largeColumnInfo); + paramUnencrypted.Value = value; + cmdUnencrypted.Parameters.Add(paramUnencrypted); + + using (SqlDataReader readerUnencrypted = cmdUnencrypted.ExecuteReader()) + using (SqlDataReader readerEncrypted = cmdEncrypted.ExecuteReader()) + { + // First check that we found some rows. + Assert.True(readerEncrypted.HasRows, @"We didn't find any rows."); + + // Now compare the result. + CompareResults(readerEncrypted, readerUnencrypted); + } + } + } + } + finally + { + // DropTables + DropTableIfExists(sqlConnectionEncrypted, targetTableName); + DropTableIfExists(sqlConnectionUnencrypted, witnessTableName); + DropTableIfExists(sqlConnectionUnencrypted, originTableName); + } + } + } + + [ConditionalTheory(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup))] + [InlineData(SqlDbType.BigInt)] + [InlineData(SqlDbType.Binary)] + [InlineData(SqlDbType.Bit)] + [InlineData(SqlDbType.Char)] + [InlineData(SqlDbType.Date)] + [InlineData(SqlDbType.DateTime)] + [InlineData(SqlDbType.DateTime2)] + [InlineData(SqlDbType.DateTimeOffset)] + [InlineData(SqlDbType.Decimal)] + [InlineData(SqlDbType.Float)] + [InlineData(SqlDbType.Int)] + [InlineData(SqlDbType.Money)] + [InlineData(SqlDbType.NChar)] + [InlineData(SqlDbType.NVarChar)] + [InlineData(SqlDbType.Real)] + [InlineData(SqlDbType.SmallDateTime)] + [InlineData(SqlDbType.SmallInt)] + [InlineData(SqlDbType.SmallMoney)] + [InlineData(SqlDbType.Time)] + [InlineData(SqlDbType.TinyInt)] + [InlineData(SqlDbType.UniqueIdentifier)] + [InlineData(SqlDbType.VarBinary)] + [InlineData(SqlDbType.VarChar)] + public void TestOutOfRangeValues(SqlDbType currentDbType) + { + ColumnMetaData currentColumnInfo = new ColumnMetaData(currentDbType, 0, 1, 1, false); + ColumnMetaData dummyColumnInfo = null; + + // Adjust size, precision and scale if the type has one. + AdjustSizePrecisionAndScale(ref currentColumnInfo, ref dummyColumnInfo); + + // Create the encrypted and unencrypted table with the proper column types. + string encryptedTableName = DatabaseHelper.GenerateUniqueName("encrypted"); + string unencryptedTableName = DatabaseHelper.GenerateUniqueName("unencrypted"); + + // Create the encrypted and unencrypted table with the proper column types. + CreateTable(currentColumnInfo, encryptedTableName, isEncrypted: true); + CreateTable(currentColumnInfo, unencryptedTableName, isEncrypted: false); + + // Generate a list of out of range values, indicating which should fail and which shouldn't. + List valueList = GenerateOutOfRangeValuesForType(currentDbType, currentColumnInfo.ColumnSize, currentColumnInfo.Precision, currentColumnInfo.Scale); + Assert.True(valueList.Count != 0, "Test bug, the list is empty!"); + + using (SqlConnection sqlConnectionEncrypted = new SqlConnection(DataTestUtility.TcpConnStr)) + using (SqlConnection sqlConnectionUnencrypted = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnectionEncrypted.Open(); + sqlConnectionUnencrypted.Open(); + + try + { + foreach (ValueErrorTuple tuple in valueList) + { + using (SqlCommand sqlCmd = new SqlCommand(String.Format("INSERT INTO [{0}] VALUES ({1})", encryptedTableName, FirstParamName), sqlConnectionEncrypted, null, SqlCommandColumnEncryptionSetting.Enabled)) + { + SqlParameter param = new SqlParameter(); + param.ParameterName = FirstParamName; + param.SqlDbType = currentColumnInfo.ColumnType; + SetParamSizeScalePrecision(ref param, currentColumnInfo); + param.Value = tuple.Value; + sqlCmd.Parameters.Add(param); + + ExecuteAndCheckForError(sqlCmd, tuple.ExpectsError); + } + + // Add same value to the unencrypted table + using (SqlCommand sqlCmd = new SqlCommand(String.Format("INSERT INTO [{0}] VALUES ({1})", unencryptedTableName, FirstParamName), sqlConnectionUnencrypted, null, SqlCommandColumnEncryptionSetting.Disabled)) + { + SqlParameter param = new SqlParameter(); + param.ParameterName = FirstParamName; + param.SqlDbType = currentColumnInfo.ColumnType; + SetParamSizeScalePrecision(ref param, currentColumnInfo); + param.Value = tuple.Value; + sqlCmd.Parameters.Add(param); + + ExecuteAndCheckForError(sqlCmd, tuple.ExpectsError); + } + + } + + CompareTables(encryptedTableName, unencryptedTableName); + } + finally + { + DropTableIfExists(sqlConnectionEncrypted, encryptedTableName); + DropTableIfExists(sqlConnectionUnencrypted, unencryptedTableName); + } + } + } + + + /// + /// Internal class to store a tupple of the value to insert and whether an exception is expected. + /// + private class ValueErrorTuple + { + public ValueErrorTuple(object value, bool expectsError) + { + Value = value; + ExpectsError = expectsError; + } + + public object Value { get; set; } + public bool ExpectsError { get; set; } + } + + /// + /// Generate out of bound values for each data type. + /// + /// + /// + /// + /// + /// + List GenerateOutOfRangeValuesForType(SqlDbType type, int length, int precision, int scale) + { + List list = new List(); + + switch (type) + { + case SqlDbType.Bit: + // Sql actually allows to insert out of bound values for bit and it converts them to a bit value. + list.Add(new ValueErrorTuple(2, false)); + list.Add(new ValueErrorTuple(-1, false)); + break; + case SqlDbType.BigInt: + list.Add(new ValueErrorTuple("9223372036854775808", true)); + list.Add(new ValueErrorTuple("-9223372036854775809", true)); + break; + case SqlDbType.Binary: + case SqlDbType.VarBinary: + { + byte[] upperValueArray = new byte[length + 1]; + Random random = new Random(); + random.NextBytes(upperValueArray); + list.Add(new ValueErrorTuple(upperValueArray, false)); + break; + } + case SqlDbType.Char: + case SqlDbType.NChar: + case SqlDbType.VarChar: + case SqlDbType.NVarChar: + { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < length + 1; i++) + { + sb.Append('a'); + } + list.Add(new ValueErrorTuple(sb.ToString(), false)); + break; + } + case SqlDbType.DateTime: + // This value is out of range and should fail. + list.Add(new ValueErrorTuple(new DateTime(1752, 12, 31, 23, 59, 59, 997), true)); + + // This value has greater scale and it should get truncated, but not fail. + list.Add(new ValueErrorTuple(new DateTime(2014, 1, 1, 23, 59, 59, 998), false)); + break; + case SqlDbType.Int: + list.Add(new ValueErrorTuple((Int64)Int32.MaxValue + 1, true)); + list.Add(new ValueErrorTuple((Int64)Int32.MinValue - 1, true)); + break; + case SqlDbType.Money: + list.Add(new ValueErrorTuple(SqlMoney.MaxValue.Value + (decimal)0.0001, true)); + list.Add(new ValueErrorTuple(SqlMoney.MinValue.Value - (decimal)0.0001, true)); + + // This value has greater scale and it should get truncated, but not fail. + list.Add(new ValueErrorTuple(1.00001, false)); + break; + case SqlDbType.UniqueIdentifier: + list.Add(new ValueErrorTuple(new Guid().ToString() + "1", true)); + list.Add(new ValueErrorTuple(new Guid().ToString().Substring(0, new Guid().ToString().Length - 1), true)); + break; + case SqlDbType.SmallDateTime: + list.Add(new ValueErrorTuple(new DateTime(2079, 6, 7, 0, 0, 0), true)); + + // This value is out of range and should fail. + list.Add(new ValueErrorTuple(new DateTime(1899, 12, 31, 23, 59, 29), true)); + + // This is rounded and inserted properly for both encrypted and unencrypted. + list.Add(new ValueErrorTuple(new DateTime(1899, 12, 31, 23, 59, 59), false)); + + // This value has greater scale and it should get truncated, but not fail. + list.Add(new ValueErrorTuple(new DateTime(2014, 1, 1, 23, 59, 59, 1), false)); + break; + case SqlDbType.SmallInt: + list.Add(new ValueErrorTuple((Int32)Int16.MaxValue + 1, true)); + list.Add(new ValueErrorTuple((Int32)Int16.MinValue - 1, true)); + break; + case SqlDbType.SmallMoney: + list.Add(new ValueErrorTuple((decimal)214748.3648, true)); + list.Add(new ValueErrorTuple((decimal)-214748.3649, true)); + + // This value has greater scale and it should get truncated, but not fail. + list.Add(new ValueErrorTuple((decimal)1.00001, false)); + break; + case SqlDbType.TinyInt: + list.Add(new ValueErrorTuple((Int16)byte.MaxValue + 1, true)); + list.Add(new ValueErrorTuple(-1, true)); + break; + case SqlDbType.Date: + // These values are out of range and should fail. + list.Add(new ValueErrorTuple("10000/1/1", true)); + list.Add(new ValueErrorTuple("0/12/31", true)); + break; + case SqlDbType.Time: + case SqlDbType.DateTime2: + case SqlDbType.DateTimeOffset: + // All values with higher precision will get truncated but not fail. + String timeStringUpper = "23:59:59."; + String timeStringLower = "00:00:00."; + + for (int i = 0; i < scale; i++) + { + timeStringUpper = timeStringUpper + "9"; + timeStringLower = timeStringLower + "0"; + } + + timeStringUpper = timeStringUpper + "1"; + timeStringLower = timeStringLower + "1"; + + if (type == SqlDbType.Time) + { + TimeSpan temp = new TimeSpan(); + TimeSpan.TryParse(timeStringUpper, out temp); + list.Add(new ValueErrorTuple(temp, false)); + TimeSpan.TryParse(timeStringLower, out temp); + list.Add(new ValueErrorTuple(temp, false)); + } + else if (type == SqlDbType.DateTime2) + { + // These values are out of range and should fail. + list.Add(new ValueErrorTuple("10000/1/1 00:00:00", true)); + list.Add(new ValueErrorTuple("0/12/31 23::59:59:999", true)); + + timeStringUpper = "2014/1/1 " + timeStringUpper; + timeStringLower = "2014/1/1 " + timeStringLower; + DateTime temp = new DateTime(); + DateTime.TryParse(timeStringUpper, out temp); + list.Add(new ValueErrorTuple(temp, false)); + DateTime.TryParse(timeStringLower, out temp); + list.Add(new ValueErrorTuple(temp, false)); + } + else if (type == SqlDbType.DateTimeOffset) + { + // These values are out of range and should fail. + list.Add(new ValueErrorTuple("10000/1/1 00:00:00 -14:00", true)); + list.Add(new ValueErrorTuple("0/12/31 23::59:59:999 +14:00", true)); + + timeStringUpper = "2014/1/1 " + timeStringUpper; + timeStringLower = "2014/1/1 " + timeStringLower; + DateTime temp = new DateTime(); + DateTime.TryParse(timeStringUpper, out temp); + list.Add(new ValueErrorTuple(temp, false)); + DateTime.TryParse(timeStringLower, out temp); + list.Add(new ValueErrorTuple(temp, false)); + + // These values are out of range and should fail. + list.Add(new ValueErrorTuple("2014/1/1 10:00 +14:01", true)); + list.Add(new ValueErrorTuple("2014/1/1 10:00 -14:01", true)); + } + break; + case SqlDbType.Decimal: + decimal highPart = 0; + decimal lowPart = 1; + + for (int i = 0; i < precision - scale; i++) + { + if (i == 0) + { + highPart = 1; + } + else + { + highPart *= 10; + } + } + + for (int i = 0; i < scale; i++) + { + lowPart /= 10; + } + + // Construct a value with higher precision than allowed. + list.Add(new ValueErrorTuple(highPart == 0 ? 1 : highPart * 10 + lowPart, true)); + + // This value has greater scale and it should get truncated, but not fail. + // If scale is 0 then this actually fails because of how .NET internally calculates the state. + list.Add(new ValueErrorTuple(highPart + lowPart / 10, scale == 0 ? true : false)); + break; + case SqlDbType.Float: + list.Add(new ValueErrorTuple("1.79770e+308", true)); + list.Add(new ValueErrorTuple("-1.79770e+308", true)); + list.Add(new ValueErrorTuple(Double.PositiveInfinity, true)); + list.Add(new ValueErrorTuple(Double.NegativeInfinity, true)); + list.Add(new ValueErrorTuple(Double.NaN, true)); + list.Add(new ValueErrorTuple(Double.Epsilon, false)); + break; + case SqlDbType.Real: + list.Add(new ValueErrorTuple((double)3.40283e+038, true)); + list.Add(new ValueErrorTuple((double)-3.40283e+038, true)); + list.Add(new ValueErrorTuple(Single.PositiveInfinity, true)); + list.Add(new ValueErrorTuple(Single.NegativeInfinity, true)); + list.Add(new ValueErrorTuple(Single.NaN, true)); + list.Add(new ValueErrorTuple(Single.Epsilon, false)); + break; + default: + Assert.True(false, "We should never get here"); + break; + } + + return list; + } + + /// + /// Check if this exception is expected. + /// + /// + /// + private bool IsExpectedException(Exception e) + { + return e is OverflowException || + e is InvalidCastException || + e is SqlTypeException || + e is ArgumentException || + e is FormatException || + e is SqlException; + } + + /// + /// Try to execute the command and check if there was an error if one was expected. + /// + /// + /// + private void ExecuteAndCheckForError(SqlCommand sqlCmd, bool expectError) + { + if (!expectError) + { + sqlCmd.ExecuteNonQuery(); + } + else + { + try + { + sqlCmd.ExecuteNonQuery(); + Assert.True(false, "We should have gotten an error but passed instead."); + } + catch (Exception e) + { + Type exceptionType = e.GetType(); + if (!IsExpectedException(e)) + { + throw; + } + } + } + } + + /// + /// Adjust the size, scale and precision for the data types that have one. + /// + /// + /// + private void AdjustSizePrecisionAndScale(ref ColumnMetaData largeColumnMeta, ref ColumnMetaData smallColumnMeta) + { + Random random = new Random(); + + if (TypeHasSize(largeColumnMeta.ColumnType)) + { + // 20% of the time use (max) as the length. + largeColumnMeta.UseMax = (largeColumnMeta.ColumnType is SqlDbType.VarChar || + largeColumnMeta.ColumnType is SqlDbType.NVarChar || + largeColumnMeta.ColumnType is SqlDbType.VarBinary) && + random.Next(0, 100) < 20; + + int unicodeMaxLength = 3500; + int maxLength = 7500; + + if (largeColumnMeta.UseMax) + { + largeColumnMeta.ColumnSize = -1; + + if (smallColumnMeta != null) + { + if (largeColumnMeta.ColumnType is SqlDbType.NChar || largeColumnMeta.ColumnType is SqlDbType.NVarChar) + { + smallColumnMeta.ColumnSize = random.Next(1, unicodeMaxLength); + } + else + { + smallColumnMeta.ColumnSize = random.Next(1, maxLength); + } + } + } + else + { + if (largeColumnMeta.ColumnType is SqlDbType.NChar || largeColumnMeta.ColumnType is SqlDbType.NVarChar) + { + largeColumnMeta.ColumnSize = random.Next(2, unicodeMaxLength); + } + else + { + largeColumnMeta.ColumnSize = random.Next(2, maxLength); + } + + if (smallColumnMeta != null) + { + smallColumnMeta.ColumnSize = random.Next(1, largeColumnMeta.ColumnSize); + } + } + } + else if (TypeHasScale(largeColumnMeta.ColumnType)) + { + int precision = 0; + int scale = random.Next(1, 8); + int minScale = 1; + + if (largeColumnMeta.ColumnType is SqlDbType.Decimal) + { + precision = random.Next(1, 28); + scale = random.Next(0, precision + 1); + minScale = 0; + } + + largeColumnMeta.Precision = precision; + largeColumnMeta.Scale = scale; + + if (smallColumnMeta != null) + { + smallColumnMeta.Precision = 0; + + // For Time / DateTime2 / DateTimeOffset types, actual scale is set to 7 when parameter.scale is zero. + // Active Issue in SQLParameter.cs when user wants to specify zero as the actual scale. + smallColumnMeta.Scale = random.Next(minScale, largeColumnMeta.Scale); + } + } + else if (TypeHasPrecision(largeColumnMeta.ColumnType)) + { + largeColumnMeta.Precision = random.Next(2, 54); + largeColumnMeta.Scale = 0; + + if (smallColumnMeta != null) + { + smallColumnMeta.Precision = random.Next(1, largeColumnMeta.Precision); + smallColumnMeta.Scale = 0; + } + } + } + + /// + /// Check if this data type has size. + /// + /// + /// + private bool TypeHasSize(SqlDbType type) + { + return type is SqlDbType.Binary || + type is SqlDbType.VarBinary || + type is SqlDbType.Char || + type is SqlDbType.VarChar || + type is SqlDbType.NChar || + type is SqlDbType.NVarChar; + } + + /// + /// Check if this data type has scale. + /// + /// + /// + private bool TypeHasScale(SqlDbType type) + { + return type is SqlDbType.Time || + type is SqlDbType.DateTime2 || + type is SqlDbType.DateTimeOffset || + type is SqlDbType.Decimal; + } + + /// + /// Check if this data type has precision. + /// + /// + /// + private bool TypeHasPrecision(SqlDbType type) + { + return type is SqlDbType.Decimal; + } + + /// + /// Populate the tables with data of the provided data type. + /// + /// + /// + /// + /// + private object[] PopulateTablesAndReturnRandomValue(string encryptedTableName, string unencryptedTableName, ColumnMetaData columnInfo) + { + object[] valueArray = new object[NumberOfRows]; + + using (SqlConnection sqlConnection = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnection.Open(); + + for (int i = 0; i < NumberOfRows; i++) + { + valueArray[i] = GenerateRandomValue(columnInfo); + + // Add value to the encrypted table + using (SqlCommand sqlCmd = new SqlCommand(String.Format("INSERT INTO [{0}] VALUES ({1})", encryptedTableName, FirstParamName), sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled)) + { + SqlParameter param = new SqlParameter(); + param.ParameterName = FirstParamName; + param.SqlDbType = columnInfo.ColumnType; + SetParamSizeScalePrecision(ref param, columnInfo); + param.Value = valueArray[i]; + sqlCmd.Parameters.Add(param); + + sqlCmd.ExecuteNonQuery(); + } + + // Add same value to the unencrypted table + using (SqlCommand sqlCmd = new SqlCommand(String.Format("INSERT INTO [{0}] VALUES ({1})", unencryptedTableName, FirstParamName), sqlConnection, null, SqlCommandColumnEncryptionSetting.Enabled)) + { + SqlParameter param = new SqlParameter(); + param.ParameterName = FirstParamName; + param.SqlDbType = columnInfo.ColumnType; + SetParamSizeScalePrecision(ref param, columnInfo); + param.Value = valueArray[i]; + sqlCmd.Parameters.Add(param); + + sqlCmd.ExecuteNonQuery(); + } + } + } + + return valueArray; + } + + /// + /// Populate the tables with data of the provided data type. + /// + /// + /// + /// + private object[] PopulateTablesAndReturnRandomValuePlaintextOnly(string unencryptedTableName, ColumnMetaData columnInfo) + { + object[] valueArray = new object[NumberOfRows]; + + using (SqlConnection sqlConnection = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnection.Open(); + + for (int i = 0; i < NumberOfRows; i++) + { + valueArray[i] = GenerateRandomValue(columnInfo); + + // Add same value to the unencrypted table + using (SqlCommand sqlCmd = new SqlCommand(String.Format("INSERT INTO [{0}] VALUES ({1})", unencryptedTableName, FirstParamName), sqlConnection, null, SqlCommandColumnEncryptionSetting.Disabled)) + { + SqlParameter param = new SqlParameter(); + param.ParameterName = FirstParamName; + param.SqlDbType = columnInfo.ColumnType; + SetParamSizeScalePrecision(ref param, columnInfo); + param.Value = valueArray[i]; + sqlCmd.Parameters.Add(param); + + sqlCmd.ExecuteNonQuery(); + } + } + } + + return valueArray; + } + + /// + /// Inserts identical data into two tables (for comparison purposes) + /// + /// + /// + /// + /// + private void portDataToTablePairViaBulkCopy(string sourceName, SqlConnectionColumnEncryptionSetting sourceConnectionFlag, string targetName, SqlConnectionColumnEncryptionSetting targetConnectionFlag) + { + SqlConnectionStringBuilder strbld = new SqlConnectionStringBuilder(DataTestUtility.TcpConnStr); + strbld.ColumnEncryptionSetting = sourceConnectionFlag; + + using (SqlConnection sourceConnection = new SqlConnection(strbld.ToString())) + { + sourceConnection.Open(); + + SqlCommand sourceCmd = sourceConnection.CreateCommand(); + sourceCmd.CommandText = String.Format(@"SELECT * FROM [{0}]", sourceName); + + SqlDataReader reader = sourceCmd.ExecuteReader(); + + strbld.ColumnEncryptionSetting = targetConnectionFlag; + using (SqlConnection targetConnection = new SqlConnection(strbld.ToString())) + { + targetConnection.Open(); + + using (SqlBulkCopy bulkCopy = new SqlBulkCopy(targetConnection)) + { + bulkCopy.DestinationTableName = targetName; + + try + { + bulkCopy.WriteToServer(reader); + } + finally + { + reader.Close(); + } + } + } + } + } + + /// + /// Retrive data from unecrypted table for comparison + /// + /// + /// + private object[] RetriveDataFromDatabase(string unencryptedTableName) + { + object[] valueArray = new object[NumberOfRows]; + int index = 0; + + using (SqlConnection sqlConnection = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnection.Open(); + + using (SqlCommand cmdUnencrypted = new SqlCommand(String.Format("SELECT {0} FROM [{1}] ORDER BY {2}", FirstColumnName, unencryptedTableName, IdentityColumnName), sqlConnection, null, SqlCommandColumnEncryptionSetting.Disabled)) + { + using (SqlDataReader readerUnencrypted = cmdUnencrypted.ExecuteReader()) + { + Assert.True(readerUnencrypted.HasRows, "We didn't find any rows in unEncryptedTable."); + + while (readerUnencrypted.Read()) + { + valueArray[index] = readerUnencrypted.GetValue(0); + index++; + } + + Assert.True(NumberOfRows == index, String.Format("The number of rows retrieved is {0}", index)); + } + } + } + + return valueArray; + } + + /// + /// Compare the two tables to check that they have identical rows. + /// + /// + /// + private void CompareTables(string encryptedTableName, string unencryptedTableName) + { + using (SqlConnection sqlConnectionEncrypted = new SqlConnection(DataTestUtility.TcpConnStr)) + using (SqlConnection sqlConnectionUnencrypted = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnectionEncrypted.Open(); + sqlConnectionUnencrypted.Open(); + + // Check that the tables contain identical data for the small types. + using (SqlCommand cmdEncrypted = new SqlCommand(String.Format("SELECT * FROM [{0}] ORDER BY {1}", encryptedTableName, IdentityColumnName), sqlConnectionEncrypted, null, SqlCommandColumnEncryptionSetting.Enabled)) + using (SqlCommand cmdUnencrypted = new SqlCommand(String.Format("SELECT * FROM [{0}] ORDER BY {1}", unencryptedTableName, IdentityColumnName), sqlConnectionUnencrypted, null, SqlCommandColumnEncryptionSetting.Disabled)) + { + using (SqlDataReader readerUnencrypted = cmdUnencrypted.ExecuteReader()) + using (SqlDataReader readerEncrypted = cmdEncrypted.ExecuteReader()) + { + CompareResults(readerEncrypted, readerUnencrypted); + } + } + } + } + + /// + /// Read data using two sqlDataReaders and compare the results. + /// + /// + /// + private void CompareResults(SqlDataReader sqlDataReaderEncrypted, SqlDataReader sqlDataReaderUnencrypted) + { + int rowId = 0; + + while (sqlDataReaderEncrypted.Read()) + { + rowId++; + + Assert.True(sqlDataReaderUnencrypted.HasRows, "Unencrypted reader has less rows than the encrypted."); + + sqlDataReaderUnencrypted.Read(); + + for (int i = 0; i < sqlDataReaderEncrypted.FieldCount; i++) + { + Assert.True(sqlDataReaderEncrypted.GetDataTypeName(i).Equals(sqlDataReaderUnencrypted.GetDataTypeName(i)), string.Format("The types for column '{0}' are not identical.", sqlDataReaderEncrypted.GetName(i))); + Assert.True(sqlDataReaderEncrypted.GetValue(i).GetType().Equals(sqlDataReaderUnencrypted.GetValue(i).GetType()), string.Format("The types of the value read for row '{0}' column '{1}' are not identical", rowId, sqlDataReaderEncrypted.GetName(i))); + + object encryptedValue = sqlDataReaderEncrypted.GetValue(i); + object unencryptedValue = sqlDataReaderUnencrypted.GetValue(i); + if (sqlDataReaderEncrypted.GetDataTypeName(i) == "binary" || sqlDataReaderEncrypted.GetDataTypeName(i) == "varbinary") + { + Assert.True(((byte[])encryptedValue).SequenceEqual((byte[])unencryptedValue), string.Format("The values read for row '{0}' column '{1}' are not identical", rowId, sqlDataReaderEncrypted.GetName(i))); + } + else if (sqlDataReaderEncrypted.GetDataTypeName(i) == "char" || sqlDataReaderEncrypted.GetDataTypeName(i) == "varchar" || + sqlDataReaderEncrypted.GetDataTypeName(i) == "nchar" || sqlDataReaderEncrypted.GetDataTypeName(i) == "nvarchar" ) + { + Assert.True(((string)encryptedValue).TrimEnd().Equals(((string)unencryptedValue).TrimEnd()), string.Format("The values read for row '{0}' column '{1}' are not identical", rowId, sqlDataReaderEncrypted.GetName(i))); + } + else + { + Assert.True(encryptedValue.Equals(unencryptedValue), string.Format("The values read for row '{0}' column '{1}' are not identical", rowId, sqlDataReaderEncrypted.GetName(i))); + } + } + } + + Assert.False(sqlDataReaderUnencrypted.Read(), "Unencrypted reader has more rows than the encrypted."); + } + + + /// + /// Generate random value for insertion according to database column + /// + /// + /// + private object GenerateRandomValue(ColumnMetaData columnInfo) + { + object returnValue; + int year; + int month; + int day; + int hour; + int minute; + int second; + int millisecond; + int count; + long ticks; + + Random rand = new Random(); + bool isNegative = Convert.ToBoolean(rand.Next(0, 2)); + StringBuilder strBuilder = new StringBuilder(); + TimeSpan tempTime; + + switch (columnInfo.ColumnType) + { + case SqlDbType.BigInt: + returnValue = isNegative ? Convert.ToInt64(rand.NextDouble() * Int64.MinValue) : Convert.ToInt64(rand.NextDouble() * Int64.MaxValue); + break; + + case SqlDbType.Bit: + returnValue = Convert.ToBoolean(rand.Next(0, 2)); + break; + + case SqlDbType.Int: + returnValue = rand.Next(); + break; + + case SqlDbType.Date: + year = rand.Next(1, 9999); + month = rand.Next(1, 13); + day = rand.Next(1, 29); + + returnValue = new System.DateTime(year, month, day); + break; + + case SqlDbType.DateTime: + year = rand.Next(1753, 9999); + month = rand.Next(1, 13); + day = rand.Next(1, 28); + hour = rand.Next(0, 24); + minute = rand.Next(0, 60); + second = rand.Next(0, 60); + millisecond = rand.Next(0, 998); + + returnValue = new DateTime(year, month, day, hour, minute, second, millisecond); + break; + + case SqlDbType.Money: + returnValue = isNegative ? Convert.ToDecimal((SqlMoney)rand.NextDouble() * SqlMoney.MinValue) : Convert.ToDecimal((SqlMoney)rand.NextDouble() * SqlMoney.MaxValue); + break; + + case SqlDbType.Real: + returnValue = isNegative ? Convert.ToSingle(rand.NextDouble() * Single.MinValue) : Convert.ToSingle(rand.NextDouble() * Single.MaxValue); + break; + + case SqlDbType.SmallDateTime: + year = rand.Next(1900, 2079); + month = rand.Next(1, 13); + day = rand.Next(1, 28); + hour = rand.Next(0, 24); + minute = rand.Next(0, 60); + second = rand.Next(0, 60); + + returnValue = new DateTime(year, month, day, hour, minute, second); + break; + + case SqlDbType.SmallInt: + returnValue = isNegative ? Convert.ToInt16(rand.NextDouble() * Int16.MinValue) : Convert.ToInt16(rand.NextDouble() * Int16.MaxValue); + break; + + case SqlDbType.SmallMoney: + returnValue = isNegative ? Convert.ToDecimal((decimal)rand.NextDouble() * SmallMoneyMinValue) : Convert.ToDecimal((decimal)rand.NextDouble() * SmallMoneyMaxValue); + break; + + case SqlDbType.TinyInt: + returnValue = Convert.ToByte(rand.Next(Byte.MinValue, Byte.MaxValue + 1)); + break; + + case SqlDbType.Binary: + returnValue = DatabaseHelper.GenerateRandomBytes(columnInfo.ColumnSize); + break; + + case SqlDbType.Char: + returnValue = Encoding.UTF8.GetString(DatabaseHelper.GenerateRandomBytes(columnInfo.ColumnSize)).TrimEnd(); + break; + + case SqlDbType.NChar: + returnValue = Encoding.Unicode.GetString(DatabaseHelper.GenerateRandomBytes(2 * columnInfo.ColumnSize)).TrimEnd(); + break; + + case SqlDbType.DateTime2: + case SqlDbType.DateTimeOffset: + year = rand.Next(1, 9999); + month = rand.Next(1, 13); + day = rand.Next(1, 28); + hour = rand.Next(0, 24); + minute = rand.Next(0, 60); + second = rand.Next(0, 60); + + strBuilder.Clear(); + count = columnInfo.Scale > 3 ? 3 : columnInfo.Scale; + + while (count > 0) + { + strBuilder.Append("9"); + count--; + } + + millisecond = (0 == strBuilder.Length) ? 0 : rand.Next(0, Int32.Parse(strBuilder.ToString())); + + if (SqlDbType.DateTime2 == columnInfo.ColumnType) + { + returnValue = new DateTime(year, month, day, hour, minute, second, millisecond); + } + else + { + returnValue = new DateTimeOffset(year, month, day, hour, minute, second, millisecond, new TimeSpan(rand.Next(-14, 15), 0, 0)); + } + break; + + case SqlDbType.Time: + ticks = Convert.ToInt64(rand.NextDouble() * (TimeSpan.TicksPerDay - 1)); + strBuilder.Clear(); + + count = columnInfo.Scale; + + if (0 == count) + { + strBuilder.Append(@"hh\:mm\:ss"); + } + else + { + strBuilder.Append(@"hh\:mm\:ss\."); + } + + while (count > 0) + { + strBuilder.Append("f"); + count--; + } + + tempTime = new TimeSpan(ticks); + returnValue = TimeSpan.Parse(tempTime.ToString(strBuilder.ToString())); + break; + + case SqlDbType.Decimal: + returnValue = isNegative ? Convert.ToDecimal((decimal)rand.NextDouble() * Decimal.MinValue) : Convert.ToDecimal((decimal)rand.NextDouble() * Decimal.MaxValue); + break; + + case SqlDbType.Float: + returnValue = isNegative ? rand.NextDouble() * Double.MinValue : rand.NextDouble() * Double.MaxValue; + break; + + case SqlDbType.VarChar: + if (columnInfo.UseMax) + { + returnValue = Encoding.UTF8.GetString(DatabaseHelper.GenerateRandomBytes(MaxLength)).TrimEnd(); + } + else + { + returnValue = Encoding.UTF8.GetString(DatabaseHelper.GenerateRandomBytes(columnInfo.ColumnSize)).TrimEnd(); + } + break; + + case SqlDbType.VarBinary: + if (columnInfo.UseMax) + { + returnValue = DatabaseHelper.GenerateRandomBytes(MaxLength); + } + else + { + returnValue = DatabaseHelper.GenerateRandomBytes(columnInfo.ColumnSize); + } + break; + + case SqlDbType.NVarChar: + if (columnInfo.UseMax) + { + returnValue = Encoding.Unicode.GetString(DatabaseHelper.GenerateRandomBytes(2 * MaxLength)).TrimEnd(); + } + else + { + returnValue = Encoding.Unicode.GetString(DatabaseHelper.GenerateRandomBytes(2 * columnInfo.ColumnSize)).TrimEnd(); + } + break; + + default: + returnValue = Encoding.Unicode.GetString(DatabaseHelper.GenerateRandomBytes(100)).TrimEnd(); + break; + } + + return returnValue; + } + + /// + /// Creates a table with the specified column type. + /// + /// + /// + /// + private void CreateTable(ColumnMetaData columnMeta, string tableName, bool isEncrypted) + { + string columnType = columnMeta.ColumnType.ToString().ToLower(); + string columnInfo = ""; + StringBuilder builder = new StringBuilder(); + + switch (columnMeta.ColumnType) + { + case SqlDbType.BigInt: + case SqlDbType.Bit: + case SqlDbType.Int: + case SqlDbType.Date: + case SqlDbType.DateTime: + case SqlDbType.Money: + case SqlDbType.Real: + case SqlDbType.Float: + case SqlDbType.SmallDateTime: + case SqlDbType.SmallInt: + case SqlDbType.SmallMoney: + case SqlDbType.TinyInt: + case SqlDbType.UniqueIdentifier: + columnInfo = columnType; + break; + + case SqlDbType.Binary: + columnInfo = $@"{columnMeta.ColumnType}({columnMeta.ColumnSize})"; + break; + + case SqlDbType.Char: + case SqlDbType.NChar: + columnInfo = $@"{columnMeta.ColumnType}({columnMeta.ColumnSize}) COLLATE Latin1_General_BIN2"; + break; + + case SqlDbType.DateTime2: + case SqlDbType.DateTimeOffset: + if (columnMeta.Scale >= 0 && columnMeta.Scale <= 7) + { + columnInfo = $@"{columnType}({columnMeta.Scale})"; + } + else + { + columnInfo = $@"{columnType}"; + } + break; + + case SqlDbType.Time: + if (columnMeta.Scale >= 0 && columnMeta.Scale <= 7) + { + columnInfo = $@"{columnType}({columnMeta.Scale})"; + } + break; + + case SqlDbType.Decimal: + builder.Clear(); + builder.Append(columnType); + + // If we have a valid precision + if (columnMeta.Precision >= 1 && columnMeta.Precision <= 38) + { + builder.AppendFormat("({0}", columnMeta.Precision); + + // If we have a valid scale + if (columnMeta.Scale >= 0 && columnMeta.Scale <= columnMeta.Precision) + { + builder.AppendFormat(",{0}", columnMeta.Scale); + } + + builder.Append(")"); + } + + columnInfo = builder.ToString(); + break; + + case SqlDbType.VarBinary: + if (columnMeta.UseMax) + { + columnInfo = $@"{columnType}(max)"; + } + else + { + columnInfo = $@"{columnType}({columnMeta.ColumnSize})"; + } + break; + + case SqlDbType.VarChar: + case SqlDbType.NVarChar: + if (columnMeta.UseMax) + { + columnInfo = $@"{columnType}(max) COLLATE Latin1_General_BIN2"; + } + else + { + columnInfo = $@"{columnType}({columnMeta.ColumnSize}) COLLATE Latin1_General_BIN2"; + } + break; + + default: + columnInfo = "nvarchar(50) COLLATE Latin1_General_BIN2"; + break; + } + + string sql; + + if (isEncrypted) + { + sql = $@"CREATE TABLE [dbo].[{tableName}] + ( + [{IdentityColumnName}] int IDENTITY(1,1), + [{FirstColumnName}] {columnInfo} ENCRYPTED WITH (COLUMN_ENCRYPTION_KEY = [{columnEncryptionKey.Name}], ENCRYPTION_TYPE = DETERMINISTIC, ALGORITHM = '{ColumnEncryptionAlgorithmName}'), + )"; + } + else + { + sql = $@"CREATE TABLE [dbo].[{tableName}] + ( + [{IdentityColumnName}] int IDENTITY(1,1), + [{FirstColumnName}] {columnInfo} + )"; + } + + using (SqlConnection sqlConn = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConn.Open(); + + using (SqlCommand command = sqlConn.CreateCommand()) + { + command.CommandText = sql; + command.ExecuteNonQuery(); + } + } + } + + /// + /// Drop the table if it exists. + /// + private void DropTableIfExists(SqlConnection sqlConnection, string tableName) + { + string cmdText = $@"IF EXISTS (select * from sys.objects where name = '{tableName}') BEGIN DROP TABLE [{tableName}] END"; + using (SqlCommand command = sqlConnection.CreateCommand()) + { + command.CommandText = cmdText; + command.ExecuteNonQuery(); + } + } + + /// + /// Set the parameter size, precision and scale. + /// + /// + /// + private void SetParamSizeScalePrecision(ref SqlParameter param, ColumnMetaData columnMeta) + { + if (TypeHasSize(columnMeta.ColumnType)) + { + param.Size = columnMeta.ColumnSize; + } + + if (TypeHasScale(columnMeta.ColumnType)) + { + param.Scale = (byte)columnMeta.Scale; + } + + if (TypeHasPrecision(columnMeta.ColumnType)) + { + param.Precision = (byte)columnMeta.Precision; + } + } + + + public void Dispose() + { + databaseObjects.Reverse(); + using (SqlConnection sqlConnection = new SqlConnection(DataTestUtility.TcpConnStr)) + { + sqlConnection.Open(); + databaseObjects.ForEach(o => o.Drop(sqlConnection)); + } + } + + + + } + +} \ No newline at end of file diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/CspProviderExt.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/CspProviderExt.cs index 981cddc843..a2c27d41b7 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/CspProviderExt.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/CspProviderExt.cs @@ -115,7 +115,7 @@ public void TestRoundTripWithCSPAndCertStoreProvider() SqlColumnEncryptionCertificateStoreProvider certProvider = new SqlColumnEncryptionCertificateStoreProvider(); SqlColumnEncryptionCspProvider cspProvider = new SqlColumnEncryptionCspProvider(); - byte[] columnEncryptionKey = CertificateUtilityWin.GenerateRandomBytes(32); + byte[] columnEncryptionKey = DatabaseHelper.GenerateRandomBytes(32); byte[] encryptedColumnEncryptionKeyUsingCert = certProvider.EncryptColumnEncryptionKey(certificatePath, @"RSA_OAEP", columnEncryptionKey); byte[] columnEncryptionKeyReturnedCert2CSP = cspProvider.DecryptColumnEncryptionKey(cspPath, @"RSA_OAEP", encryptedColumnEncryptionKeyUsingCert); diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/DatabaseHelper.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/DatabaseHelper.cs index 0c2cc76fb9..0c6c4fd94e 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/DatabaseHelper.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/DatabaseHelper.cs @@ -2,14 +2,13 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; +using System.Security.Cryptography; + namespace Microsoft.Data.SqlClient.ManualTesting.Tests.AlwaysEncrypted { class DatabaseHelper { - private DatabaseHelper() - { - } - /// /// Insert Customer record into table /// @@ -28,5 +27,22 @@ internal static void InsertCustomerData(SqlConnection sqlConnection, string tabl sqlCommand.ExecuteNonQuery(); } } + + /// + /// Generates cryptographically random bytes + /// + /// No of cryptographically random bytes to be generated + /// A byte array containing cryptographically generated random bytes + internal static byte[] GenerateRandomBytes(int length) + { + // Generate random bytes cryptographically. + byte[] randomBytes = new byte[length]; + RNGCryptoServiceProvider rngCsp = new RNGCryptoServiceProvider(); + rngCsp.GetBytes(randomBytes); + + return randomBytes; + } + + internal static string GenerateUniqueName(string baseName) => string.Concat("AE_", baseName, "_", Guid.NewGuid().ToString().Replace('-', '_')); } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/CertificateUtilityWin.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/CertificateUtilityWin.cs index 6d3281fe71..1212fcf5c4 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/CertificateUtilityWin.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/AlwaysEncrypted/TestFixtures/Setup/CertificateUtilityWin.cs @@ -186,21 +186,6 @@ internal static string GetCspPathFromCertificate(X509Certificate2 certificate) return string.Concat(rsaProvider.CspKeyContainerInfo.ProviderName, @"/", rsaProvider.CspKeyContainerInfo.KeyContainerName); } - /// - /// Generates cryptographically random bytes - /// - /// No of cryptographically random bytes to be generated - /// A byte array containing cryptographically generated random bytes - internal static byte[] GenerateRandomBytes(int length) - { - // Generate random bytes cryptographically. - byte[] randomBytes = new byte[length]; - RNGCryptoServiceProvider rngCsp = new RNGCryptoServiceProvider(); - rngCsp.GetBytes(randomBytes); - - return randomBytes; - } - /// /// Removes a certificate from the store. Cleanup purposes. /// diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 451a838768..2e1f03fef6 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -20,6 +20,7 @@ +