From 86c4a9b855e70395d1ae895eecfa6fb868c83624 Mon Sep 17 00:00:00 2001 From: Carl Meyertons Date: Mon, 26 Apr 2021 10:38:04 -0500 Subject: [PATCH 1/6] SqlBulkCopy - Leverage Generics to Eliminate Boxing --- .../src/Microsoft.Data.SqlClient.csproj | 1 + .../Data/SqlClient/GenericCastExtensions.cs | 17 + .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 371 ++++++++++----- .../Microsoft/Data/SqlClient/SqlParameter.cs | 42 +- .../src/Microsoft/Data/SqlClient/TdsParser.cs | 258 ++++++----- ....Data.SqlClient.ManualTesting.Tests.csproj | 2 + .../DataConversionErrorMessageTest.cs | 2 +- .../SqlBulkCopyTest/NoBoxingValuesTypes.cs | 431 ++++++++++++++++++ tools/props/Versions.props | 1 + 9 files changed, 882 insertions(+), 243 deletions(-) create mode 100644 src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericCastExtensions.cs create mode 100644 src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj index 4546d7d1ab..e8035d5990 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj @@ -806,6 +806,7 @@ + True True diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericCastExtensions.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericCastExtensions.cs new file mode 100644 index 0000000000..dd2cce4869 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/GenericCastExtensions.cs @@ -0,0 +1,17 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +namespace Microsoft.Data.SqlClient +{ + /// + /// Serves to convert generic to out type by casting to object first. Relies on JIT to optimize out unneccessary casts and prevent double boxing. + /// + internal static class GenericCastExtensions + { + public static TOut GenericCast(this TIn value) + { + return (TOut)(object)value; + } + } +} diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index b0dce3a6f9..8a5dae6785 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -120,8 +120,8 @@ private enum ValueSourceType DbDataReader } - // Enum for specifying SqlDataReader.Get method used - private enum ValueMethod : byte + // Enum for specifying SqlDataReader.Get / IDataReader method used + private enum ValueMethod { GetValue, SqlTypeSqlDecimal, @@ -129,7 +129,19 @@ private enum ValueMethod : byte SqlTypeSqlSingle, DataFeedStream, DataFeedText, - DataFeedXml + DataFeedXml, + GetInt32, + GetString, + GetDouble, + GetDecimal, + GetInt16, + GetInt64, + GetChar, + GetByte, + GetBoolean, + GetDateTime, + GetGuid, + GetFloat } // Used to hold column metadata for SqlDataReader case @@ -873,11 +885,19 @@ private void Dispose(bool disposing) } // Unified method to read a value from the current row - private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out bool isDataFeed, out bool isNull) + // Reads a cell and then writes it. + // Read may block at this moment since there is no getValueAsync or DownStream async at this moment. + // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance. + // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end. + private Task ReadWriteColumnValueAsync(int destRowIndex) { _SqlMetaData metadata = _sortedColumnMappings[destRowIndex]._metadata; int sourceOrdinal = _sortedColumnMappings[destRowIndex]._sourceColumnOrdinal; + bool isSqlType = false; + bool isDataFeed = false; + bool isNull = false; + switch (_rowSourceType) { case ValueSourceType.IDataReader: @@ -887,34 +907,40 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (_dbDataReaderRowSource.IsDBNull(sourceOrdinal)) { - isSqlType = false; - isDataFeed = false; isNull = true; - return DBNull.Value; + return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull); } else { - isSqlType = false; isDataFeed = true; - isNull = false; + + object feedColumnValue; + switch (_currentRowMetadata[destRowIndex].Method) { case ValueMethod.DataFeedStream: - return new StreamDataFeed(_dbDataReaderRowSource.GetStream(sourceOrdinal)); + feedColumnValue = new StreamDataFeed(_dbDataReaderRowSource.GetStream(sourceOrdinal)); + + break; case ValueMethod.DataFeedText: - return new TextDataFeed(_dbDataReaderRowSource.GetTextReader(sourceOrdinal)); + feedColumnValue = new TextDataFeed(_dbDataReaderRowSource.GetTextReader(sourceOrdinal)); + break; case ValueMethod.DataFeedXml: // Only SqlDataReader supports an XmlReader // There is no GetXmlReader on DbDataReader, however if GetValue returns XmlReader we will read it as stream if it is assigned to XML field Debug.Assert(_sqlDataReaderRowSource != null, "Should not be reading row as an XmlReader if bulk copy source is not a SqlDataReader"); - return new XmlDataFeed(_sqlDataReaderRowSource.GetXmlReader(sourceOrdinal)); + feedColumnValue = new XmlDataFeed(_sqlDataReaderRowSource.GetXmlReader(sourceOrdinal)); + break; default: Debug.Fail($"Current column is marked as being a DataFeed, but no DataFeed compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); isDataFeed = false; - object columnValue = _dbDataReaderRowSource.GetValue(sourceOrdinal); - ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); - return columnValue; + feedColumnValue = _dbDataReaderRowSource.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(feedColumnValue, out isNull, out isSqlType); + break; } + + //specifically choosing to use the object overload here to simplify TdsParser logic for the XmlReader scenario + return WriteValueAsync(feedColumnValue, destRowIndex, isSqlType, isDataFeed, isNull); } } // SqlDataReader-specific logic @@ -922,36 +948,30 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (_currentRowMetadata[destRowIndex].IsSqlType) { - INullable value; isSqlType = true; - isDataFeed = false; switch (_currentRowMetadata[destRowIndex].Method) { case ValueMethod.SqlTypeSqlDecimal: - value = _sqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal); - break; + var value = _sqlDataReaderRowSource.GetSqlDecimal(sourceOrdinal); + return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, value.IsNull); case ValueMethod.SqlTypeSqlDouble: // use cast to handle IsNull correctly because no public constructor allows it - value = (SqlDecimal)_sqlDataReaderRowSource.GetSqlDouble(sourceOrdinal); - break; + var dblValue = (SqlDecimal)_sqlDataReaderRowSource.GetSqlDouble(sourceOrdinal); + return WriteValueAsync(dblValue, destRowIndex, isSqlType, isDataFeed, dblValue.IsNull); case ValueMethod.SqlTypeSqlSingle: - // use cast to handle IsNull correctly because no public constructor allows it - value = (SqlDecimal)_sqlDataReaderRowSource.GetSqlSingle(sourceOrdinal); - break; + // use cast to handle value.IsNull correctly because no public constructor allows it + var singleValue = (SqlDecimal)_sqlDataReaderRowSource.GetSqlSingle(sourceOrdinal); + return WriteValueAsync(singleValue, destRowIndex, isSqlType, isDataFeed, singleValue.IsNull); default: Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); - value = (INullable)_sqlDataReaderRowSource.GetSqlValue(sourceOrdinal); - break; + var sqlValue = (INullable)_sqlDataReaderRowSource.GetSqlValue(sourceOrdinal); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, sqlValue.IsNull); } - isNull = value.IsNull; - return value; + } else { - isSqlType = false; - isDataFeed = false; - object value = _sqlDataReaderRowSource.GetValue(sourceOrdinal); isNull = ((value == null) || (value == DBNull.Value)); if ((!isNull) && (metadata.type == SqlDbType.Udt)) @@ -965,30 +985,58 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b Debug.Assert(!(value is INullable) || !((INullable)value).IsNull, "IsDBNull returned false, but GetValue returned a null INullable"); } #endif - return value; + return WriteValueAsync(value, destRowIndex, isSqlType, isDataFeed, isNull); } } else { - isDataFeed = false; - IDataReader rowSourceAsIDataReader = (IDataReader)_rowSource; - - // Only use IsDbNull when streaming is enabled and only for non-SqlDataReader - if ((_enableStreaming) && (_sqlDataReaderRowSource == null) && (rowSourceAsIDataReader.IsDBNull(sourceOrdinal))) + // previously, IsDbNull was only invoked in a non-streaming scenario with a non-SqlDataReader. + // based on the else if above, the non-SqlDataReader check was superfluous + // the new logic to not rely only on IDataReader.GetValue needs DbNull + // this could potentially be a breaking change to custom IDataReader implementations that incorrectly return IsDbNull + if (rowSourceAsIDataReader.IsDBNull(sourceOrdinal)) { - isSqlType = false; isNull = true; - return DBNull.Value; + return WriteValueAsync(DBNull.Value, destRowIndex, isSqlType, isDataFeed, isNull); } else { - object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); - ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); - return columnValue; + switch (_currentRowMetadata[destRowIndex].Method) + { + case ValueMethod.GetInt32: + return WriteValueAsync(rowSourceAsIDataReader.GetInt32(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, false); + case ValueMethod.GetString: + var strValue = rowSourceAsIDataReader.GetString(sourceOrdinal); + isNull = strValue == null; + return WriteValueAsync(strValue, destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDouble: + return WriteValueAsync(rowSourceAsIDataReader.GetDouble(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDecimal: + return WriteValueAsync(rowSourceAsIDataReader.GetDecimal(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetInt16: + return WriteValueAsync(rowSourceAsIDataReader.GetInt16(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetInt64: + return WriteValueAsync(rowSourceAsIDataReader.GetInt64(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetChar: + return WriteValueAsync(rowSourceAsIDataReader.GetChar(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetByte: + return WriteValueAsync(rowSourceAsIDataReader.GetByte(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetBoolean: + return WriteValueAsync(rowSourceAsIDataReader.GetBoolean(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetDateTime: + return WriteValueAsync(rowSourceAsIDataReader.GetDateTime(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetGuid: + return WriteValueAsync(rowSourceAsIDataReader.GetGuid(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + case ValueMethod.GetFloat: + return WriteValueAsync(rowSourceAsIDataReader.GetFloat(sourceOrdinal), destRowIndex, isSqlType, isDataFeed, isNull); + default: + object columnValue = rowSourceAsIDataReader.GetValue(sourceOrdinal); + ADP.IsNullOrSqlType(columnValue, out isNull, out isSqlType); + return WriteValueAsync(columnValue, destRowIndex, isSqlType, isDataFeed, isNull); + } } } - case ValueSourceType.DataTable: case ValueSourceType.RowArray: { @@ -996,6 +1044,7 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b Debug.Assert(sourceOrdinal < _currentRowLength, "inconsistency of length of rows from rowsource!"); isDataFeed = false; + // unfortunately this has to be boxed due to DataRow's API. object currentRowValue = _currentRow[sourceOrdinal]; ADP.IsNullOrSqlType(currentRowValue, out isNull, out isSqlType); @@ -1008,7 +1057,8 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b { if (isSqlType) { - return new SqlDecimal(((SqlSingle)currentRowValue).Value); + var sqlDec = new SqlDecimal(((SqlSingle)currentRowValue).Value); + return WriteValueAsync(sqlDec, destRowIndex, isSqlType, isDataFeed, isNull); } else { @@ -1016,16 +1066,20 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b if (!float.IsNaN(f)) { isSqlType = true; - return new SqlDecimal(f); + return WriteValueAsync(new SqlDecimal(f), destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } - break; } } case ValueMethod.SqlTypeSqlDouble: { if (isSqlType) { - return new SqlDecimal(((SqlDouble)currentRowValue).Value); + var sqlValue = new SqlDecimal(((SqlDouble)currentRowValue).Value); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } else { @@ -1033,33 +1087,40 @@ private object GetValueFromSourceRow(int destRowIndex, out bool isSqlType, out b if (!double.IsNaN(d)) { isSqlType = true; - return new SqlDecimal(d); + return WriteValueAsync(new SqlDecimal(d), destRowIndex, isSqlType, isDataFeed, isNull); + } + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } - break; } } case ValueMethod.SqlTypeSqlDecimal: { if (isSqlType) { - return (SqlDecimal)currentRowValue; + var sqlValue = (SqlDecimal)currentRowValue; + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } else { isSqlType = true; - return new SqlDecimal((decimal)currentRowValue); + var sqlValue = new SqlDecimal((decimal)currentRowValue); + return WriteValueAsync(sqlValue, destRowIndex, isSqlType, isDataFeed, isNull); } } default: { Debug.Fail($"Current column is marked as being a SqlType, but no SqlType compatible method was provided. Method: {_currentRowMetadata[destRowIndex].Method}"); - break; + // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN) + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); } } } - - // If we are here then either the value is null, there was no special storage type for this column or the special storage type wasn't handled (e.g. if the currentRowValue is NaN) - return currentRowValue; + else + { + return WriteValueAsync(currentRowValue, destRowIndex, isSqlType, isDataFeed, isNull); + } } default: { @@ -1260,6 +1321,66 @@ private SourceColumnMetadata GetColumnMetadata(int ordinal) method = ValueMethod.GetValue; } } + else if (_rowSourceType == ValueSourceType.IDataReader) + { + isSqlType = false; + isDataFeed = false; + + Type t = ((IDataReader)_rowSource).GetFieldType(ordinal); + + if (t == typeof(bool)) + { + method = ValueMethod.GetBoolean; + } + else if (t == typeof(byte)) + { + method = ValueMethod.GetByte; + } + else if (t == typeof(char)) + { + method = ValueMethod.GetChar; + } + else if (t == typeof(DateTime)) + { + method = ValueMethod.GetDateTime; + } + else if (t == typeof(decimal)) + { + method = ValueMethod.GetDecimal; + } + else if (t == typeof(double)) + { + method = ValueMethod.GetDouble; + } + else if (t == typeof(float)) + { + method = ValueMethod.GetFloat; + } + else if (t == typeof(Guid)) + { + method = ValueMethod.GetGuid; + } + else if (t == typeof(short)) + { + method = ValueMethod.GetInt16; + } + else if (t == typeof(int)) + { + method = ValueMethod.GetInt32; + } + else if (t == typeof(long)) + { + method = ValueMethod.GetInt64; + } + else if (t == typeof(string)) + { + method = ValueMethod.GetString; + } + else + { + method = ValueMethod.GetValue; + } + } else { isSqlType = false; @@ -1394,8 +1515,10 @@ private string UnquotedName(string name) return name; } - private object ValidateBulkCopyVariant(object value) + private bool ValidateBulkCopyVariantIfNeeded(T value, out object variantValue) { + variantValue = null; + // From the spec: // "The only acceptable types are ..." // GUID, BIGVARBINARY, BIGBINARY, BIGVARCHAR, BIGCHAR, NVARCHAR, NCHAR, BIT, INT1, INT2, INT4, INT8, @@ -1423,20 +1546,21 @@ private object ValidateBulkCopyVariant(object value) case TdsEnums.SQLDATETIMEOFFSET: if (value is INullable) { // Current limitation in the SqlBulkCopy Variant code limits BulkCopy to CLR/COM Types. - return MetaType.GetComValueFromSqlVariant(value); + variantValue = MetaType.GetComValueFromSqlVariant(value); + return true; } else { - return value; + return false; } default: throw SQL.BulkLoadInvalidVariantValue(); } } - private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, ref bool isSqlType, out bool coercedToDataFeed) + private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, bool isNull, bool isSqlType) { - coercedToDataFeed = false; + bool coercedToDataFeed = false; if (isNull) { @@ -1444,11 +1568,13 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re { throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column); } - return value; + + return DoWriteValueAsync(value, col, isSqlType, coercedToDataFeed, isNull, metadata); } MetaType type = metadata.metaType; bool typeChanged = false; + object objValue = null; // If the column is encrypted then we are going to transparently encrypt this column // (based on connection string setting)- Use the metaType for the underlying @@ -1473,24 +1599,26 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re { case TdsEnums.SQLNUMERICN: case TdsEnums.SQLDECIMALN: - mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); - - // Convert Source Decimal Precision and Scale to Destination Precision and Scale - // Sql decimal data could get corrupted on insert if the scale of - // the source and destination weren't the same. The BCP protocol, specifies the - // scale of the incoming data in the insert statement, we just tell the server we - // are inserting the same scale back. SqlDecimal sqlValue; - if ((isSqlType) && (!typeChanged)) + if (typeof(T) == typeof(decimal)) + { + sqlValue = new SqlDecimal(value.GenericCast()); + } + else if (typeof(T) == typeof(SqlDecimal)) { - sqlValue = (SqlDecimal)value; + sqlValue = value.GenericCast(); } else { - sqlValue = new SqlDecimal((decimal)value); + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); + sqlValue = new SqlDecimal((decimal)SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false)); } + // Convert Source Decimal Precision and Scale to Destination Precision and Scale + // Sql decimal data could get corrupted on insert if the scale of + // the source and destination weren't the same. The BCP protocol, specifies the + // scale of the incoming data in the insert statement, we just tell the server we + // are inserting the same scale back. if (sqlValue.Scale != scale) { sqlValue = TdsParser.AdjustSqlDecimalScale(sqlValue, scale); @@ -1504,15 +1632,17 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } catch (SqlTruncateException) { + mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); throw SQL.BulkLoadCannotConvertValue(value.GetType(), mt, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), ADP.ParameterValueOutOfRange(sqlValue)); } } // Perf: It is more efficient to write a SqlDecimal than a decimal since we need to break it into its 'bits' when writing - value = sqlValue; isSqlType = true; typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type - break; + + // returning here to avoid unnecessary decValue initialization for all types + return WriteConvertedValue(sqlValue, col, isSqlType, isNull, coercedToDataFeed, metadata); case TdsEnums.SQLINTN: case TdsEnums.SQLFLTN: @@ -1536,16 +1666,22 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re case TdsEnums.SQLDATETIME2: case TdsEnums.SQLDATETIMEOFFSET: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed); break; case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - value = SqlParameter.CoerceValue(value, mt, out coercedToDataFeed, out typeChanged, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed, false); if (!coercedToDataFeed) { // We do not need to test for TextDataFeed as it is only assigned to (N)VARCHAR(MAX) - string str = ((isSqlType) && (!typeChanged)) ? ((SqlString)value).Value : ((string)value); + string str = typeChanged + ? (string)objValue + : isSqlType + ? value.GenericCast().Value + : value.GenericCast() + ; + int maxStringLength = length / 2; if (str.Length > maxStringLength) { @@ -1564,8 +1700,7 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } break; case TdsEnums.SQLVARIANT: - value = ValidateBulkCopyVariant(value); - typeChanged = true; + typeChanged = ValidateBulkCopyVariantIfNeeded(value, out objValue); break; case TdsEnums.SQLUDT: // UDTs are sent as varbinary so we need to get the raw bytes @@ -1576,16 +1711,16 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re // in byte[] form. if (!(value is byte[])) { - value = _connection.GetBytes(value); + objValue = _connection.GetBytes(value); typeChanged = true; } break; case TdsEnums.SQLXMLTYPE: // Could be either string, SqlCachedBuffer, XmlReader or XmlDataFeed Debug.Assert((value is XmlReader) || (value is SqlCachedBuffer) || (value is string) || (value is SqlString) || (value is XmlDataFeed), "Invalid value type of Xml datatype"); - if (value is XmlReader) + if (value is XmlReader xmlReader) { - value = new XmlDataFeed((XmlReader)value); + objValue = new XmlDataFeed(xmlReader); typeChanged = true; coercedToDataFeed = true; } @@ -1595,14 +1730,6 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re Debug.Fail("Unknown TdsType!" + type.NullableType.ToString("x2", (IFormatProvider)null)); throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), null); } - - if (typeChanged) - { - // All type changes change to CLR types - isSqlType = false; - } - - return value; } catch (Exception e) { @@ -1612,6 +1739,17 @@ private object ConvertValue(object value, _SqlMetaData metadata, bool isNull, re } throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e); } + + if (typeChanged) + { + // All type changes change to CLR types + isSqlType = false; + return WriteConvertedValue(objValue, col, isSqlType, isNull, coercedToDataFeed, metadata); + } + else + { + return WriteConvertedValue(value, col, isSqlType, isNull, coercedToDataFeed, metadata); + } } /// @@ -2135,33 +2273,40 @@ private bool FireRowsCopiedEvent(long rowsCopied) return eventArgs.Abort; } - // Reads a cell and then writes it. - // Read may block at this moment since there is no getValueAsync or DownStream async at this moment. - // When _isAsyncBulkCopy == true: Write will return Task (when async method runs asynchronously) or Null (when async call actually ran synchronously) for performance. - // When _isAsyncBulkCopy == false: Writes are purely sync. This method return null at the end. - private Task ReadWriteColumnValueAsync(int col) + private Task WriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull) { - bool isSqlType; - bool isDataFeed; - bool isNull; - object value = GetValueFromSourceRow(col, out isSqlType, out isDataFeed, out isNull); //this will return Task/null in future: as rTask - _SqlMetaData metadata = _sortedColumnMappings[col]._metadata; - if (!isDataFeed) + if (isDataFeed) + { + //nothing to convert, skip straight to write + return DoWriteValueAsync(value, col, isSqlType, isDataFeed, isNull, metadata); + } + else { - value = ConvertValue(value, metadata, isNull, ref isSqlType, out isDataFeed); + return ConvertWriteValueAsync(value, col, metadata, isNull, isSqlType); + } + } - // If column encryption is requested via connection string option, perform encryption here - if (!isNull && // if value is not NULL - metadata.isEncrypted) - { // If we are transparently encrypting - Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); - value = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDataFeed, isSqlType); - isSqlType = false; // Its not a sql type anymore - } + private Task WriteConvertedValue(T value, int col, bool isSqlType, bool isNull, bool isDatafeed, _SqlMetaData metadata) + { + // If column encryption is requested via connection string option, perform encryption here + if (!isNull && // if value is not NULL + metadata.isEncrypted) + { // If we are transparently encrypting + Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); + var bytesValue = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDatafeed, isSqlType); + isSqlType = false; // Its not a sql type anymore + + return DoWriteValueAsync(bytesValue, col, isSqlType, isDatafeed, isNull, metadata); } + else + { + return DoWriteValueAsync(value, col, isSqlType, isDatafeed, isNull, metadata); + } + } - //write part + private Task DoWriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull, _SqlMetaData metadata) + { Task writeTask = null; if (metadata.type != SqlDbType.Variant) { @@ -2180,11 +2325,11 @@ private Task ReadWriteColumnValueAsync(int col) if (variantInternalType == SqlBuffer.StorageType.DateTime2) { - _parser.WriteSqlVariantDateTime2(((DateTime)value), _stateObj); + _parser.WriteSqlVariantDateTime2(value.GenericCast(), _stateObj); } else if (variantInternalType == SqlBuffer.StorageType.Date) { - _parser.WriteSqlVariantDate(((DateTime)value), _stateObj); + _parser.WriteSqlVariantDate(value.GenericCast(), _stateObj); } else { diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs index 2bab7e4dbd..7fd2dec220 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs @@ -2084,13 +2084,21 @@ private int ValueSizeCore(object value) // Coerced Value is also used in SqlBulkCopy.ConvertValue(object value, _SqlMetaData metadata) internal static object CoerceValue(object value, MetaType destinationType, out bool coercedToDataFeed, out bool typeChanged, bool allowStreaming = true) + { + typeChanged = CoerceValueIfNeeded(value, destinationType, out var objValue, out coercedToDataFeed, allowStreaming); + + return typeChanged ? objValue : value; + } + + internal static bool CoerceValueIfNeeded(T value, MetaType destinationType, out object objValue, out bool coercedToDataFeed, bool allowStreaming = true) { Debug.Assert(!(value is DataFeed), "Value provided should not already be a data feed"); Debug.Assert(!ADP.IsNull(value), "Value provided should not be null"); Debug.Assert(null != destinationType, "null destinationType"); + objValue = null; coercedToDataFeed = false; - typeChanged = false; + var typeChanged = false; Type currentType = value.GetType(); if ( @@ -2111,45 +2119,45 @@ internal static object CoerceValue(object value, MetaType destinationType, out b // For Xml data, destination Type is always string if (currentType == typeof(SqlXml)) { - value = MetaType.GetStringFromXml((XmlReader)(((SqlXml)value).CreateReader())); + objValue = MetaType.GetStringFromXml(value.GenericCast().CreateReader()); } else if (currentType == typeof(SqlString)) { typeChanged = false; // Do nothing } - else if (typeof(XmlReader).IsAssignableFrom(currentType)) + else if (value is XmlReader xmlReader) { if (allowStreaming) { coercedToDataFeed = true; - value = new XmlDataFeed((XmlReader)value); + objValue = new XmlDataFeed(xmlReader); } else { - value = MetaType.GetStringFromXml((XmlReader)value); + objValue = MetaType.GetStringFromXml(xmlReader); } } else if (currentType == typeof(char[])) { - value = new string((char[])value); + objValue = new string(value.GenericCast()); } else if (currentType == typeof(SqlChars)) { - value = new string(((SqlChars)value).Value); + objValue = new string(value.GenericCast().Value); } else if (value is TextReader textReader && allowStreaming) { coercedToDataFeed = true; - value = new TextDataFeed(textReader); + objValue = new TextDataFeed(textReader); } else { - value = Convert.ChangeType(value, destinationType.ClassType, null); + objValue = Convert.ChangeType(value, destinationType.ClassType, null); } } else if ((destinationType.DbType == DbType.Currency) && (currentType == typeof(string))) { - value = decimal.Parse((string)value, NumberStyles.Currency, null); + objValue = decimal.Parse(value.GenericCast(), NumberStyles.Currency, null); } else if ((currentType == typeof(SqlBytes)) && (destinationType.ClassType == typeof(byte[]))) { @@ -2157,15 +2165,15 @@ internal static object CoerceValue(object value, MetaType destinationType, out b } else if ((currentType == typeof(string)) && (destinationType.SqlDbType == SqlDbType.Time)) { - value = TimeSpan.Parse((string)value); + objValue = TimeSpan.Parse(value.GenericCast()); } else if ((currentType == typeof(string)) && (destinationType.SqlDbType == SqlDbType.DateTimeOffset)) { - value = DateTimeOffset.Parse((string)value, (IFormatProvider)null); + objValue = DateTimeOffset.Parse(value.GenericCast(), (IFormatProvider)null); } else if ((currentType == typeof(DateTime)) && (destinationType.SqlDbType == SqlDbType.DateTimeOffset)) { - value = new DateTimeOffset((DateTime)value); + objValue = new DateTimeOffset(value.GenericCast()); } else if ( TdsEnums.SQLTABLE == destinationType.TDSType && @@ -2182,11 +2190,11 @@ value is IEnumerable else if (destinationType.ClassType == typeof(byte[]) && allowStreaming && value is Stream stream) { coercedToDataFeed = true; - value = new StreamDataFeed(stream); + objValue = new StreamDataFeed(stream); } else { - value = Convert.ChangeType(value, destinationType.ClassType, null); + objValue = Convert.ChangeType(value, destinationType.ClassType, null); } } catch (Exception e) @@ -2201,8 +2209,8 @@ value is IEnumerable } Debug.Assert(allowStreaming || !coercedToDataFeed, "Streaming is not allowed, but type was coerced into a data feed"); - Debug.Assert(value.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged"); - return value; + Debug.Assert(objValue == null || objValue.GetType() == currentType ^ typeChanged, "Incorrect value for typeChanged"); + return typeChanged; } private static int StringSize(object value, bool isSqlType) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs index e0ebfdf669..78da1a5a03 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs @@ -6642,10 +6642,10 @@ internal Task WriteSqlVariantValue(object value, int length, int offset, TdsPars // Therefore the sql_variant value must not include the MaxLength. This is the major difference // between this method and WriteSqlVariantValue above. // - internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject stateObj, bool canAccumulate = true) + internal Task WriteSqlVariantDataRowValue(T value, TdsParserStateObject stateObj, bool canAccumulate = true) { // handle null values - if ((null == value) || (DBNull.Value == value)) + if (null == value || value is DBNull) { WriteInt(TdsEnums.FIXEDNULL, stateObj); return null; @@ -6656,44 +6656,44 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta if (metatype.IsAnsiType) { - length = GetEncodingCharLength((string)value, length, 0, _defaultEncoding); + length = GetEncodingCharLength(value.GenericCast(), length, 0, _defaultEncoding); } switch (metatype.TDSType) { case TdsEnums.SQLFLT4: WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj); - WriteFloat((float)value, stateObj); + WriteFloat(value.GenericCast(), stateObj); break; case TdsEnums.SQLFLT8: WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteDouble((double)value, stateObj); + WriteDouble(value.GenericCast(), stateObj); break; case TdsEnums.SQLINT8: WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteLong((long)value, stateObj); + WriteLong(value.GenericCast(), stateObj); break; case TdsEnums.SQLINT4: WriteSqlVariantHeader(6, metatype.TDSType, metatype.PropBytes, stateObj); - WriteInt((int)value, stateObj); + WriteInt(value.GenericCast(), stateObj); break; case TdsEnums.SQLINT2: WriteSqlVariantHeader(4, metatype.TDSType, metatype.PropBytes, stateObj); - WriteShort((short)value, stateObj); + WriteShort(value.GenericCast(), stateObj); break; case TdsEnums.SQLINT1: WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj); - stateObj.WriteByte((byte)value); + stateObj.WriteByte(value.GenericCast()); break; case TdsEnums.SQLBIT: WriteSqlVariantHeader(3, metatype.TDSType, metatype.PropBytes, stateObj); - if ((bool)value == true) + if (value.GenericCast()) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -6702,7 +6702,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLBIGVARBINARY: { - byte[] b = (byte[])value; + byte[] b = value.GenericCast(); length = b.Length; WriteSqlVariantHeader(4 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6712,7 +6712,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLBIGVARCHAR: { - string s = (string)value; + string s = value.GenericCast(); length = s.Length; WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6724,7 +6724,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLUNIQUEID: { - System.Guid guid = (System.Guid)value; + Guid guid = value.GenericCast(); Span b = stackalloc byte[16]; FillGuidBytes(guid, b); @@ -6737,7 +6737,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLNVARCHAR: { - string s = (string)value; + string s = value.GenericCast(); length = s.Length * 2; WriteSqlVariantHeader(9 + length, metatype.TDSType, metatype.PropBytes, stateObj); @@ -6752,7 +6752,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLDATETIME: { - TdsDateTime dt = MetaType.FromDateTime((DateTime)value, 8); + TdsDateTime dt = MetaType.FromDateTime(value.GenericCast(), 8); WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); WriteInt(dt.days, stateObj); @@ -6763,7 +6763,7 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta case TdsEnums.SQLMONEY: { WriteSqlVariantHeader(10, metatype.TDSType, metatype.PropBytes, stateObj); - WriteCurrency((decimal)value, 8, stateObj); + WriteCurrency(value.GenericCast(), 8, stateObj); break; } @@ -6771,21 +6771,22 @@ internal Task WriteSqlVariantDataRowValue(object value, TdsParserStateObject sta { WriteSqlVariantHeader(21, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Precision); //propbytes: precision - stateObj.WriteByte((byte)((decimal.GetBits((decimal)value)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale - WriteDecimal((decimal)value, stateObj); + var decValue = value.GenericCast(); + stateObj.WriteByte((byte)((decimal.GetBits(decValue)[3] & 0x00ff0000) >> 0x10)); // propbytes: scale + WriteDecimal(decValue, stateObj); break; } case TdsEnums.SQLTIME: WriteSqlVariantHeader(8, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Scale); //propbytes: scale - WriteTime((TimeSpan)value, metatype.Scale, 5, stateObj); + WriteTime(value.GenericCast(), metatype.Scale, 5, stateObj); break; case TdsEnums.SQLDATETIMEOFFSET: WriteSqlVariantHeader(13, metatype.TDSType, metatype.PropBytes, stateObj); stateObj.WriteByte(metatype.Scale); //propbytes: scale - WriteDateTimeOffset((DateTimeOffset)value, metatype.Scale, 10, stateObj); + WriteDateTimeOffset(value.GenericCast(), metatype.Scale, 10, stateObj); break; default: @@ -10397,7 +10398,7 @@ internal bool ShouldEncryptValuesForBulkCopy() /// Encrypts a column value (for SqlBulkCopy) /// /// - internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType) + internal byte[] EncryptColumnValue(T value, SqlMetaDataPriv metadata, string column, TdsParserStateObject stateObj, bool isDataFeed, bool isSqlType) { Debug.Assert(IsColumnEncryptionSupported, "Server doesn't support encryption, yet we received encryption metadata"); Debug.Assert(ShouldEncryptValuesForBulkCopy(), "Encryption attempted when not requested"); @@ -10422,7 +10423,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin // when we normalize and serialize the data buffers. The serialization routine expects us // to report the size of data to be copied out (for serialization). If we underreport the // size, truncation will happen for us! - actualLengthInBytes = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; + actualLengthInBytes = (isSqlType) + ? value.GenericCast().Length + : value.GenericCast().Length; + if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) { @@ -10442,7 +10446,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin ThrowUnsupportedCollationEncountered(null); // stateObject only when reading } - string stringValue = (isSqlType) ? ((SqlString)value).Value : (string)value; + string stringValue = (isSqlType) + ? value.GenericCast().Value + : value.GenericCast(); + actualLengthInBytes = _defaultEncoding.GetByteCount(stringValue); // If the string length is > max length, then use the max length (see comments above) @@ -10456,7 +10463,10 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: - actualLengthInBytes = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + actualLengthInBytes = (isSqlType + ? value.GenericCast().Value.Length + : value.GenericCast().Length) + * 2; if (metadata.baseTI.length > 0 && actualLengthInBytes > metadata.baseTI.length) @@ -10501,7 +10511,7 @@ internal object EncryptColumnValue(object value, SqlMetaDataPriv metadata, strin _connHandler.Connection); } - internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull) + internal Task WriteBulkCopyValue(T value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isSqlType, bool isDataFeed, bool isNull) { Debug.Assert(!isSqlType || value is INullable, "isSqlType is true, but value can not be type cast to an INullable"); Debug.Assert(!isDataFeed ^ value is DataFeed, "Incorrect value for isDataFeed"); @@ -10558,6 +10568,8 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars return resultTask; } + string stringValue = null; + if (!isDataFeed) { switch (metatype.NullableType) @@ -10566,7 +10578,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLBIGVARBINARY: case TdsEnums.SQLIMAGE: case TdsEnums.SQLUDT: - ccb = (isSqlType) ? ((SqlBinary)value).Length : ((byte[])value).Length; + ccb = (isSqlType) + ? value.GenericCast().Length + : value.GenericCast().Length; break; case TdsEnums.SQLUNIQUEID: ccb = GUID_SIZE; @@ -10579,15 +10593,9 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars ThrowUnsupportedCollationEncountered(null); // stateObject only when reading } - string stringValue = null; - if (isSqlType) - { - stringValue = ((SqlString)value).Value; - } - else - { - stringValue = (string)value; - } + stringValue = isSqlType + ? value.GenericCast().Value + : value.GenericCast(); ccb = stringValue.Length; ccbStringBytes = _defaultEncoding.GetByteCount(stringValue); @@ -10595,15 +10603,29 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: - ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + stringValue = stringValue = isSqlType + ? value.GenericCast().Value + : value.GenericCast(); + + ccb = stringValue.Length * 2; break; case TdsEnums.SQLXMLTYPE: // Value here could be string or XmlReader - if (value is XmlReader) + + if (value is XmlReader xr) { - value = MetaType.GetStringFromXml((XmlReader)value); + stringValue = MetaType.GetStringFromXml(xr); } - ccb = ((isSqlType) ? ((SqlString)value).Value.Length : ((string)value).Length) * 2; + else if (isSqlType) + { + stringValue = value.GenericCast().Value; + } + else + { + stringValue = value.GenericCast(); + } + + ccb = stringValue.Length * 2; break; default: @@ -10653,19 +10675,18 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars { internalWriteTask = WriteSqlValue(value, metatype, ccb, ccbStringBytes, 0, stateObj); } + else if (stringValue != null) + { + internalWriteTask = WriteValueWithWait(stringValue, metadata, stateObj, isDataFeed, metatype, ccb, ccbStringBytes); + } else if (metatype.SqlDbType != SqlDbType.Udt || metatype.IsLong) { - internalWriteTask = WriteValue(value, metatype, metadata.scale, ccb, ccbStringBytes, 0, stateObj, metadata.length, isDataFeed); - if ((internalWriteTask == null) && (_asyncWrite)) - { - internalWriteTask = stateObj.WaitForAccumulatedWrites(); - } - Debug.Assert(_asyncWrite || stateObj.WaitForAccumulatedWrites() == null, "Should not have accumulated writes when writing sync"); + internalWriteTask = WriteValueWithWait(value, metadata, stateObj, isDataFeed, metatype, ccb, ccbStringBytes); } else { WriteShort(ccb, stateObj); - internalWriteTask = stateObj.WriteByteArray((byte[])value, ccb, 0); + internalWriteTask = stateObj.WriteByteArray(value.GenericCast(), ccb, 0); } #if DEBUG @@ -10693,6 +10714,17 @@ internal Task WriteBulkCopyValue(object value, SqlMetaDataPriv metadata, TdsPars return resultTask; } + private Task WriteValueWithWait(T value, SqlMetaDataPriv metadata, TdsParserStateObject stateObj, bool isDataFeed, MetaType metatype, int ccb, int ccbStringBytes) + { + Task internalWriteTask = WriteValue(value, metatype, metadata.scale, ccb, ccbStringBytes, 0, stateObj, metadata.length, isDataFeed); + if ((internalWriteTask == null) && (_asyncWrite)) + { + internalWriteTask = stateObj.WaitForAccumulatedWrites(); + } + Debug.Assert(_asyncWrite || stateObj.WaitForAccumulatedWrites() == null, "Should not have accumulated writes when writing sync"); + return internalWriteTask; + } + // This is in its own method to avoid always allocating the lambda in WriteBulkCopyValue private Task WriteBulkCopyValueSetupContinuation(Task internalWriteTask, Encoding saveEncoding, SqlCollation saveCollation, int saveCodePage, int saveLCID) { @@ -10974,7 +11006,7 @@ private bool IsBOMNeeded(MetaType type, object value) return false; } - private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed) + private Task GetTerminationTask(Task unterminatedWriteTask, MetaType type, int actualLength, TdsParserStateObject stateObj, bool isDataFeed) { if (type.IsPlp && ((actualLength > 0) || isDataFeed)) { @@ -10995,16 +11027,16 @@ private Task GetTerminationTask(Task unterminatedWriteTask, object value, MetaTy } - private Task WriteSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) + private Task WriteSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) { return GetTerminationTask( WriteUnterminatedSqlValue(value, type, actualLength, codePageByteSize, offset, stateObj), - value, type, actualLength, stateObj, false); + type, actualLength, stateObj, false); } // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) + private Task WriteUnterminatedSqlValue(T value, MetaType type, int actualLength, int codePageByteSize, int offset, TdsParserStateObject stateObj) { Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) || (value is INullable && !((INullable)value).IsNull)), @@ -11015,11 +11047,11 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - WriteFloat(((SqlSingle)value).Value, stateObj); + WriteFloat(value.GenericCast().Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - WriteDouble(((SqlDouble)value).Value, stateObj); + WriteDouble(value.GenericCast().Value, stateObj); } break; @@ -11035,12 +11067,12 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe if (value is SqlBinary) { - return stateObj.WriteByteArray(((SqlBinary)value).Value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast().Value, actualLength, offset, canAccumulate: false); } else { Debug.Assert(value is SqlBytes); - return stateObj.WriteByteArray(((SqlBytes)value).Value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast().Value, actualLength, offset, canAccumulate: false); } } @@ -11048,7 +11080,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe { Debug.Assert(actualLength == 16, "Invalid length for guid type in com+ object"); Span b = stackalloc byte[16]; - SqlGuid sqlGuid = (SqlGuid)value; + SqlGuid sqlGuid = value.GenericCast(); if (sqlGuid.IsNull) { @@ -11066,7 +11098,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLBITN: { Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); - if (((SqlBoolean)value).Value == true) + if (value.GenericCast().Value == true) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -11076,17 +11108,17 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLINTN: if (type.FixedLength == 1) - stateObj.WriteByte(((SqlByte)value).Value); + stateObj.WriteByte(value.GenericCast().Value); else if (type.FixedLength == 2) - WriteShort(((SqlInt16)value).Value, stateObj); + WriteShort(value.GenericCast().Value, stateObj); else if (type.FixedLength == 4) - WriteInt(((SqlInt32)value).Value, stateObj); + WriteInt(value.GenericCast().Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - WriteLong(((SqlInt64)value).Value, stateObj); + WriteLong(value.GenericCast().Value, stateObj); } break; @@ -11100,14 +11132,14 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe } if (value is SqlChars) { - string sch = new string(((SqlChars)value).Value); + string sch = new string(value.GenericCast().Value); return WriteEncodingChar(sch, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } else { Debug.Assert(value is SqlString); - return WriteEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); + return WriteEncodingChar(value.GenericCast().Value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } @@ -11136,21 +11168,21 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe if (value is SqlChars) { - return WriteCharArray(((SqlChars)value).Value, actualLength, offset, stateObj, canAccumulate: false); + return WriteCharArray(value.GenericCast().Value, actualLength, offset, stateObj, canAccumulate: false); } else { Debug.Assert(value is SqlString); - return WriteString(((SqlString)value).Value, actualLength, offset, stateObj, canAccumulate: false); + return WriteString(value.GenericCast().Value, actualLength, offset, stateObj, canAccumulate: false); } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - WriteSqlDecimal((SqlDecimal)value, stateObj); + WriteSqlDecimal(value.GenericCast(), stateObj); break; case TdsEnums.SQLDATETIMN: - SqlDateTime dt = (SqlDateTime)value; + SqlDateTime dt = value.GenericCast(); if (type.FixedLength == 4) { @@ -11170,7 +11202,7 @@ private Task WriteUnterminatedSqlValue(object value, MetaType type, int actualLe case TdsEnums.SQLMONEYN: { - WriteSqlMoney((SqlMoney)value, type.FixedLength, stateObj); + WriteSqlMoney(value.GenericCast(), type.FixedLength, stateObj); break; } @@ -11640,28 +11672,28 @@ private Task NullIfCompletedWriteTask(Task task) } } - private Task WriteValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) + private Task WriteValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) { return GetTerminationTask(WriteUnterminatedValue(value, type, scale, actualLength, encodingByteSize, offset, stateObj, paramSize, isDataFeed), - value, type, actualLength, stateObj, isDataFeed); + type, actualLength, stateObj, isDataFeed); } // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) + private Task WriteUnterminatedValue(T value, MetaType type, byte scale, int actualLength, int encodingByteSize, int offset, TdsParserStateObject stateObj, int paramSize, bool isDataFeed) { - Debug.Assert((null != value) && (DBNull.Value != value), "unexpected missing or empty object"); + Debug.Assert((null != value) && !(value is DBNull), "unexpected missing or empty object"); // parameters are always sent over as BIG or N types switch (type.NullableType) { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - WriteFloat((float)value, stateObj); + WriteFloat(value.GenericCast(), stateObj); else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - WriteDouble((double)value, stateObj); + WriteDouble(value.GenericCast(), stateObj); } break; @@ -11678,7 +11710,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int if (isDataFeed) { Debug.Assert(type.IsPlp, "Stream assigned to non-PLP was not converted!"); - return NullIfCompletedWriteTask(WriteStreamFeed((StreamDataFeed)value, stateObj, paramSize)); + return NullIfCompletedWriteTask(WriteStreamFeed(value.GenericCast(), stateObj, paramSize)); } else { @@ -11686,7 +11718,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { WriteInt(actualLength, stateObj); // chunk length } - return stateObj.WriteByteArray((byte[])value, actualLength, offset, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast(), actualLength, offset, canAccumulate: false); } } @@ -11694,7 +11726,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { Debug.Assert(actualLength == 16, "Invalid length for guid type in com+ object"); Span b = stackalloc byte[16]; - FillGuidBytes((System.Guid)value, b); + FillGuidBytes(value.GenericCast(), b); stateObj.WriteByteSpan(b); break; } @@ -11702,7 +11734,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLBITN: { Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); - if ((bool)value == true) + if (value.GenericCast() == true) stateObj.WriteByte(1); else stateObj.WriteByte(0); @@ -11712,15 +11744,15 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLINTN: if (type.FixedLength == 1) - stateObj.WriteByte((byte)value); + stateObj.WriteByte(value.GenericCast()); else if (type.FixedLength == 2) - WriteShort((short)value, stateObj); + WriteShort(value.GenericCast(), stateObj); else if (type.FixedLength == 4) - WriteInt((int)value, stateObj); + WriteInt(value.GenericCast(), stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - WriteLong((long)value, stateObj); + WriteLong(value.GenericCast(), stateObj); } break; @@ -11738,7 +11770,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int TextDataFeed tdf = value as TextDataFeed; if (tdf == null) { - return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize)); + return NullIfCompletedWriteTask(WriteXmlFeed(value.GenericCast(), stateObj, needBom: true, encoding: _defaultEncoding, size: paramSize)); } else { @@ -11753,11 +11785,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int } if (value is byte[]) { // If LazyMat non-filled blob, send cookie rather than value - return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast(), actualLength, 0, canAccumulate: false); } else { - return WriteEncodingChar((string)value, actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); + return WriteEncodingChar(value.GenericCast(), actualLength, offset, _defaultEncoding, stateObj, canAccumulate: false); } } } @@ -11775,7 +11807,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int TextDataFeed tdf = value as TextDataFeed; if (tdf == null) { - return NullIfCompletedWriteTask(WriteXmlFeed((XmlDataFeed)value, stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize)); + return NullIfCompletedWriteTask(WriteXmlFeed(value.GenericCast(), stateObj, IsBOMNeeded(type, value), Encoding.Unicode, paramSize)); } else { @@ -11798,25 +11830,25 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int } if (value is byte[]) { // If LazyMat non-filled blob, send cookie rather than value - return stateObj.WriteByteArray((byte[])value, actualLength, 0, canAccumulate: false); + return stateObj.WriteByteArray(value.GenericCast(), actualLength, 0, canAccumulate: false); } else { // convert to cchars instead of cbytes actualLength >>= 1; - return WriteString((string)value, actualLength, offset, stateObj, canAccumulate: false); + return WriteString(value.GenericCast(), actualLength, offset, stateObj, canAccumulate: false); } } } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - WriteDecimal((decimal)value, stateObj); + WriteDecimal(value.GenericCast(), stateObj); break; case TdsEnums.SQLDATETIMN: Debug.Assert(type.FixedLength <= 0xff, "Invalid Fixed Length"); - TdsDateTime dt = MetaType.FromDateTime((DateTime)value, (byte)type.FixedLength); + TdsDateTime dt = MetaType.FromDateTime(value.GenericCast(), (byte)type.FixedLength); if (type.FixedLength == 4) { @@ -11836,13 +11868,13 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int case TdsEnums.SQLMONEYN: { - WriteCurrency((decimal)value, type.FixedLength, stateObj); + WriteCurrency(value.GenericCast(), type.FixedLength, stateObj); break; } case TdsEnums.SQLDATE: { - WriteDate((DateTime)value, stateObj); + WriteDate(value.GenericCast(), stateObj); break; } @@ -11851,7 +11883,7 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { throw SQL.TimeScaleValueOutOfRange(scale); } - WriteTime((TimeSpan)value, scale, actualLength, stateObj); + WriteTime(value.GenericCast(), scale, actualLength, stateObj); break; case TdsEnums.SQLDATETIME2: @@ -11859,11 +11891,11 @@ private Task WriteUnterminatedValue(object value, MetaType type, byte scale, int { throw SQL.TimeScaleValueOutOfRange(scale); } - WriteDateTime2((DateTime)value, scale, actualLength, stateObj); + WriteDateTime2(value.GenericCast(), scale, actualLength, stateObj); break; case TdsEnums.SQLDATETIMEOFFSET: - WriteDateTimeOffset((DateTimeOffset)value, scale, actualLength, stateObj); + WriteDateTimeOffset(value.GenericCast(), scale, actualLength, stateObj); break; default: @@ -12107,7 +12139,7 @@ private byte[] SerializeUnencryptedValue(object value, MetaType type, byte scale // For MAX types, this method can only write everything in one big chunk. If multiple // chunk writes needed, please use WritePlpBytes/WritePlpChars - private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj) + private byte[] SerializeUnencryptedSqlValue(T value, MetaType type, int actualLength, int offset, byte normalizationVersion, TdsParserStateObject stateObj) { Debug.Assert(((type.NullableType == TdsEnums.SQLXMLTYPE) || (value is INullable && !((INullable)value).IsNull)), @@ -12123,11 +12155,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act { case TdsEnums.SQLFLTN: if (type.FixedLength == 4) - return SerializeFloat(((SqlSingle)value).Value); + { + return SerializeFloat(value.GenericCast().Value); + } else { Debug.Assert(type.FixedLength == 8, "Invalid length for SqlDouble type!"); - return SerializeDouble(((SqlDouble)value).Value); + return SerializeDouble(value.GenericCast().Value); } case TdsEnums.SQLBIGBINARY: @@ -12138,19 +12172,19 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act if (value is SqlBinary) { - Buffer.BlockCopy(((SqlBinary)value).Value, offset, b, 0, actualLength); + Buffer.BlockCopy(value.GenericCast().Value, offset, b, 0, actualLength); } else { Debug.Assert(value is SqlBytes); - Buffer.BlockCopy(((SqlBytes)value).Value, offset, b, 0, actualLength); + Buffer.BlockCopy(value.GenericCast().Value, offset, b, 0, actualLength); } return b; } case TdsEnums.SQLUNIQUEID: { - byte[] b = ((SqlGuid)value).ToByteArray(); + byte[] b = value.GenericCast().ToByteArray(); Debug.Assert((actualLength == b.Length) && (actualLength == 16), "Invalid length for guid type in com+ object"); return b; @@ -12161,23 +12195,23 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act Debug.Assert(type.FixedLength == 1, "Invalid length for SqlBoolean type"); // We normalize to allow conversion across data types. BIT is serialized into a BIGINT. - return SerializeLong(((SqlBoolean)value).Value == true ? 1 : 0, stateObj); + return SerializeLong(value.GenericCast().Value == true ? 1 : 0, stateObj); } case TdsEnums.SQLINTN: // We normalize to allow conversion across data types. All data types below are serialized into a BIGINT. if (type.FixedLength == 1) - return SerializeLong(((SqlByte)value).Value, stateObj); + return SerializeLong(value.GenericCast().Value, stateObj); if (type.FixedLength == 2) - return SerializeLong(((SqlInt16)value).Value, stateObj); + return SerializeLong(value.GenericCast().Value, stateObj); if (type.FixedLength == 4) - return SerializeLong(((SqlInt32)value).Value, stateObj); + return SerializeLong(value.GenericCast().Value, stateObj); else { Debug.Assert(type.FixedLength == 8, "invalid length for SqlIntN type: " + type.FixedLength.ToString(CultureInfo.InvariantCulture)); - return SerializeLong(((SqlInt64)value).Value, stateObj); + return SerializeLong(value.GenericCast().Value, stateObj); } case TdsEnums.SQLBIGCHAR: @@ -12185,13 +12219,13 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act case TdsEnums.SQLTEXT: if (value is SqlChars) { - String sch = new String(((SqlChars)value).Value); + String sch = new String(value.GenericCast().Value); return SerializeEncodingChar(sch, actualLength, offset, _defaultEncoding); } else { Debug.Assert(value is SqlString); - return SerializeEncodingChar(((SqlString)value).Value, actualLength, offset, _defaultEncoding); + return SerializeEncodingChar(value.GenericCast().Value, actualLength, offset, _defaultEncoding); } @@ -12206,20 +12240,20 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act if (value is SqlChars) { - return SerializeCharArray(((SqlChars)value).Value, actualLength, offset); + return SerializeCharArray(value.GenericCast().Value, actualLength, offset); } else { Debug.Assert(value is SqlString); - return SerializeString(((SqlString)value).Value, actualLength, offset); + return SerializeString(value.GenericCast().Value, actualLength, offset); } case TdsEnums.SQLNUMERICN: Debug.Assert(type.FixedLength <= 17, "Decimal length cannot be greater than 17 bytes"); - return SerializeSqlDecimal((SqlDecimal)value, stateObj); + return SerializeSqlDecimal(value.GenericCast(), stateObj); case TdsEnums.SQLDATETIMN: - SqlDateTime dt = (SqlDateTime)value; + SqlDateTime dt = value.GenericCast(); if (type.FixedLength == 4) { @@ -12265,7 +12299,7 @@ private byte[] SerializeUnencryptedSqlValue(object value, MetaType type, int act case TdsEnums.SQLMONEYN: { - return SerializeSqlMoney((SqlMoney)value, type.FixedLength, stateObj); + return SerializeSqlMoney(value.GenericCast(), type.FixedLength, stateObj); } default: diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index def31ffc10..46908dcbec 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -70,6 +70,7 @@ Common\System\Collections\DictionaryExtensions.cs + @@ -302,6 +303,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs index 4c3d594ad1..caf5dea633 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/DataConversionErrorMessageTest.cs @@ -162,7 +162,7 @@ private bool StringToIntTest(SqlConnection cnn, string targetTable, SourceType s string expectedErrorMsg = string.Format(pattern, args); - Assert.True(ex.Message.Contains(expectedErrorMsg), "Unexpected error message: " + ex.Message); + Assert.True(ex.Message.Contains(expectedErrorMsg), $"Unexpected error message: {ex}"); hitException = true; } return hitException; diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs new file mode 100644 index 0000000000..ad431a920d --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs @@ -0,0 +1,431 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Data; +using System.Linq; +using System.Linq.Expressions; +using BenchmarkDotNet.Attributes; +using BenchmarkDotNet.Configs; +using BenchmarkDotNet.Diagnosers; +using BenchmarkDotNet.Jobs; +using BenchmarkDotNet.Running; +using BenchmarkDotNet.Validators; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public class NoBoxingValueTypes : IDisposable + { + private static readonly string _table = DataTestUtility.GetUniqueNameForSqlServer(nameof(NoBoxingValueTypes)); + private const int _count = 5000; + private static readonly ItemToCopy _item; + private static readonly IEnumerable _items; + private static readonly IDataReader _reader; + + private static readonly string _connString = DataTestUtility.TCPConnectionString; + + private class ItemToCopy + { + // keeping this data static so the performance of the benchmark is not varied by the data size & shape + public int IntColumn { get; } = 123456; + public bool BoolColumn { get; } = true; + } + + static NoBoxingValueTypes() + { + _item = new ItemToCopy(); + + _items = Enumerable.Range(0, _count).Select(x => _item).ToArray(); + + // It would've been great to use mgravell/FastMember here to make thing logic much cleaner and not have to include the custom reader + // however, that package is not available on the dotnet-public Nuget source, which is a bummer. + _reader = new EnumerableDataReaderFactoryBuilder(_table) + .Add("IntColumn", i => i.IntColumn) + .Add("BoolColumn", i => i.BoolColumn) + .BuildFactory() + .CreateReader(_items) + ; + } + + public NoBoxingValueTypes() + { + using (var conn = new SqlConnection(_connString)) + using (var cmd = conn.CreateCommand()) + { + conn.Open(); + Helpers.TryExecute(cmd, $@" + CREATE TABLE {_table} ( + IntColumn INT NOT NULL, + BoolColumn BIT NOT NULL + ) + "); + } + } + + private class RunOnceConfig : ManualConfig + { + public RunOnceConfig() + { + Add(Job.InProcess.WithLaunchCount(1).WithIterationCount(1).WithWarmupCount(0)); + Add(MemoryDiagnoser.Default); + + Add(JitOptimizationsValidator.DontFailOnError); + } + } + + + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer))] + public void Should_Not_Box() + { // in debug mode, the double boxing DOES occur as the JIT optimizes less code, which causes the test to fail +#if DEBUG + return; +#else + //cannot figure out an easy way to get this to work on all platforms + + var config = new RunOnceConfig(); // cannot use fluent syntax to still support net461 + + var summary = BenchmarkRunner.Run(config); + + var numValueTypeColumns = 2; + var totalBytesWhenBoxed = IntPtr.Size * _count * numValueTypeColumns; + + var report = summary.Reports.First(); + + Assert.Equal(1, report.AllMeasurements.Count); + Assert.True(report.GcStats.BytesAllocatedPerOperation < totalBytesWhenBoxed); +#endif + } + + public class NoBoxingValueTypesBenchmark + { + [Benchmark] + public void BulkCopy() + { + _reader.Close(); // this resets the reader + + using (var bc = new SqlBulkCopy(DataTestUtility.TCPConnectionString, SqlBulkCopyOptions.TableLock)) + { + bc.BatchSize = _count; + bc.DestinationTableName = _table; + bc.BulkCopyTimeout = 60; + + bc.WriteToServer(_reader); + } + } + } + + public void Dispose() + { + using (var conn = new SqlConnection(_connString)) + using (var cmd = conn.CreateCommand()) + { + conn.Open(); + Helpers.TryExecute(cmd, $@" + DROP TABLE IF EXISTS {_table} + "); + } + } + + //all code here and below is a custom data reader implementation to support the benchmark + private class EnumerableDataReaderFactoryBuilder + { + private readonly List _expressions = new List(); + private readonly List> _objExpressions = new List>(); + private readonly DataTable _schemaTable; + + public EnumerableDataReaderFactoryBuilder(string tableName) + { + Name = tableName; + _schemaTable = new DataTable(); + } + + private static readonly HashSet _validTypes = new HashSet + { + typeof(decimal), + typeof(decimal?), + typeof(string), + typeof(int), + typeof(int?), + typeof(double), + typeof(bool), + typeof(bool?), + typeof(Guid), + typeof(DateTime), + }; + + public EnumerableDataReaderFactoryBuilder Add(string column, Expression> expression) + { + var t = typeof(TColumn); + + var func = expression.Compile(); + + // don't do any optimizations for boxing bools here to detect boxing occurring properly. + Expression> objExpression = o => func(o); + + _objExpressions.Add(objExpression.Compile()); + + if (_validTypes.Contains(t)) + { + t = Nullable.GetUnderlyingType(t) ?? t; // data table doesn't accept nullable. + _schemaTable.Columns.Add(column, t); + _expressions.Add(expression); + } + else + { + Console.WriteLine($"Could not matching return type for {Name}.{column} of: {t.Name}"); + _schemaTable.Columns.Add(column); //add w/o type to force using GetValue + + _expressions.Add(objExpression); + } + + return this; + } + + public EnumerableDataReaderFactory BuildFactory() => new EnumerableDataReaderFactory(_schemaTable, _expressions, _objExpressions); + + public string Name { get; } + } + + public class EnumerableDataReaderFactory + { + public DataTable SchemaTable { get; } + public Func[] ObjectGetters { get; } + public Func[] DecimalGetters { get; } + public Func[] NullableDecimalGetters { get; } + public Func[] StringGetters { get; } + public Func[] DoubleGetters { get; } + public Func[] IntGetters { get; } + public Func[] NullableIntGetters { get; } + public Func[] BoolGetters { get; } + + public Func[] NullableBoolGetters { get; } + + public Func[] GuidGetters { get; } + public Func[] DateTimeGetters { get; } + public bool[] NullableIndexes { get; } + + public EnumerableDataReaderFactory(DataTable schemaTable, List expressions, List> objectGetters) + { + SchemaTable = schemaTable; + DecimalGetters = new Func[expressions.Count]; + NullableDecimalGetters = new Func[expressions.Count]; + StringGetters = new Func[expressions.Count]; + DoubleGetters = new Func[expressions.Count]; + IntGetters = new Func[expressions.Count]; + NullableIntGetters = new Func[expressions.Count]; + BoolGetters = new Func[expressions.Count]; + NullableBoolGetters = new Func[expressions.Count]; + GuidGetters = new Func[expressions.Count]; + DateTimeGetters = new Func[expressions.Count]; + NullableIndexes = new bool[expressions.Count]; + + ObjectGetters = objectGetters.ToArray(); + + for (int i = 0; i < expressions.Count; i++) + { + var expression = expressions[i]; + + NullableIndexes[i] = !expression.ReturnType.IsValueType || Nullable.GetUnderlyingType(expression.ReturnType) != null; + + switch (expression) + { + case Expression> e: + break; // do nothing + case Expression> e: + DecimalGetters[i] = e.Compile(); + break; + case Expression> e: + NullableDecimalGetters[i] = e.Compile(); + break; + case Expression> e: + StringGetters[i] = e.Compile(); + break; + case Expression> e: + DoubleGetters[i] = e.Compile(); + break; + case Expression> e: + IntGetters[i] = e.Compile(); + break; + case Expression> e: + NullableIntGetters[i] = e.Compile(); + break; + case Expression> e: + BoolGetters[i] = e.Compile(); + break; + case Expression> e: + NullableBoolGetters[i] = e.Compile(); + break; + case Expression> e: + GuidGetters[i] = e.Compile(); + break; + case Expression> e: + DateTimeGetters[i] = e.Compile(); + break; + default: + throw new Exception($"Type missing: {expression.GetType().FullName}"); + } + } + } + + public IDataReader CreateReader(IEnumerable items) => new EnumerableDataReader(this, items.GetEnumerator()); + } + + public class EnumerableDataReader : IDataReader + { + private readonly IEnumerator _source; + private readonly EnumerableDataReaderFactory _context; + + public EnumerableDataReader(EnumerableDataReaderFactory context, IEnumerator source) + { + _source = source; + _context = context; + } + + public object GetValue(int i) + { + var v = _context.ObjectGetters[i](_source.Current); + return v; + } + + public int FieldCount => _context.ObjectGetters.Length; + + public bool Read() => _source.MoveNext(); + + public void Close() => _source.Reset(); + + public void Dispose() => this.Close(); + + public bool NextResult() => throw new NotImplementedException(); + + public int Depth => 0; + + public bool IsClosed => false; + + public int RecordsAffected => -1; + + public DataTable GetSchemaTable() => _context.SchemaTable; + + public object this[string name] => throw new NotImplementedException(); + + public object this[int i] => GetValue(i); + + public bool GetBoolean(int i) + { + var g = _context.BoolGetters[i]; + + if (g != null) + return g(_source.Current); + + return _context.NullableBoolGetters[i](_source.Current).Value; + } + + public byte GetByte(int i) => throw new NotImplementedException(); + + public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) => throw new NotImplementedException(); + + public char GetChar(int i) => throw new NotImplementedException(); + public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) => -1; + + public IDataReader GetData(int i) => throw new NotImplementedException(); + + public string GetDataTypeName(int i) => throw new NotImplementedException(); + + public DateTime GetDateTime(int i) => _context.DateTimeGetters[i](_source.Current); + + public decimal GetDecimal(int i) + { + var g = _context.DecimalGetters[i]; + + if (g != null) + return g(_source.Current); + + return _context.NullableDecimalGetters[i](_source.Current).Value; + } + + public double GetDouble(int i) => _context.DoubleGetters[i](_source.Current); + + public Type GetFieldType(int i) => _context.SchemaTable.Columns[i].DataType; + + public float GetFloat(int i) => throw new NotImplementedException(); + + public Guid GetGuid(int i) => _context.GuidGetters[i](_source.Current); + + public short GetInt16(int i) => throw new NotImplementedException(); + + public int GetInt32(int i) + { + var g = _context.IntGetters[i]; + + if (g != null) + return g(_source.Current); + + return _context.NullableIntGetters[i](_source.Current).Value; + } + + public long GetInt64(int i) => throw new NotImplementedException(); + + public string GetName(int i) + { + if (_context.SchemaTable.Columns.Count > i) + { + return _context.SchemaTable.Columns[i].ColumnName; + } + throw new IndexOutOfRangeException($"No column for index {i}"); + } + + public int GetOrdinal(string name) + { + if (_context.SchemaTable.Columns.Count == 0) + { + throw new Exception("Schema table is empty"); + } + return _context.SchemaTable.Columns.IndexOf(name); + } + + public string GetString(int i) => _context.StringGetters[i](_source.Current); + + public int GetValues(object[] values) => throw new NotImplementedException(); + + public bool IsDBNull(int i) + { + // short circuit for non-nullable types + if (!_context.NullableIndexes[i]) + { + return false; + } + + // otherwise find the first one -- starting w/ most occurring to least + + var ig = _context.NullableIntGetters[i]; + if (ig != null) + { + return ig(_source.Current) == null; + } + + var sg = _context.StringGetters[i]; + if (sg != null) + { + return sg(_source.Current) == null; + } + + var bg = _context.NullableBoolGetters[i]; + if (bg != null) + { + return bg(_source.Current) == null; + } + + var dg = _context.NullableDecimalGetters[i]; + if (dg != null) + { + return dg(_source.Current) == null; + } + + return false; + } + } + } +} diff --git a/tools/props/Versions.props b/tools/props/Versions.props index ec4365798c..b7b11684d7 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -50,6 +50,7 @@ + 0.11.3 3.1.1 5.2.6 15.9.0 From 667ffb15dbcc9dd0a43bf2a04dda3d743af511a0 Mon Sep 17 00:00:00 2001 From: Carl Meyertons Date: Mon, 26 Apr 2021 12:01:05 -0500 Subject: [PATCH 2/6] tests Dispose pattern --- .../SqlBulkCopyTest/NoBoxingValuesTypes.cs | 82 ++++++++++++++++--- 1 file changed, 69 insertions(+), 13 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs index ad431a920d..c218339894 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs @@ -11,13 +11,12 @@ using BenchmarkDotNet.Configs; using BenchmarkDotNet.Diagnosers; using BenchmarkDotNet.Jobs; -using BenchmarkDotNet.Running; using BenchmarkDotNet.Validators; using Xunit; namespace Microsoft.Data.SqlClient.ManualTesting.Tests { - public class NoBoxingValueTypes : IDisposable + public sealed class NoBoxingValueTypes : IDisposable { private static readonly string _table = DataTestUtility.GetUniqueNameForSqlServer(nameof(NoBoxingValueTypes)); private const int _count = 5000; @@ -26,6 +25,7 @@ public class NoBoxingValueTypes : IDisposable private static readonly IDataReader _reader; private static readonly string _connString = DataTestUtility.TCPConnectionString; + private bool _disposedValue; private class ItemToCopy { @@ -118,20 +118,36 @@ public void BulkCopy() } } - public void Dispose() + private void Dispose(bool disposing) { - using (var conn = new SqlConnection(_connString)) - using (var cmd = conn.CreateCommand()) + if (!_disposedValue) { - conn.Open(); - Helpers.TryExecute(cmd, $@" - DROP TABLE IF EXISTS {_table} - "); + if (disposing) + { + _reader.Dispose(); + using (var conn = new SqlConnection(_connString)) + using (var cmd = conn.CreateCommand()) + { + conn.Open(); + Helpers.TryExecute(cmd, $@" + DROP TABLE IF EXISTS {_table} + "); + } + } + + _disposedValue = true; } } + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } + //all code here and below is a custom data reader implementation to support the benchmark - private class EnumerableDataReaderFactoryBuilder + private sealed class EnumerableDataReaderFactoryBuilder: IDisposable { private readonly List _expressions = new List(); private readonly List> _objExpressions = new List>(); @@ -156,6 +172,7 @@ public EnumerableDataReaderFactoryBuilder(string tableName) typeof(Guid), typeof(DateTime), }; + private bool _disposedValue; public EnumerableDataReaderFactoryBuilder Add(string column, Expression> expression) { @@ -188,6 +205,26 @@ public EnumerableDataReaderFactoryBuilder Add(string column, Express public EnumerableDataReaderFactory BuildFactory() => new EnumerableDataReaderFactory(_schemaTable, _expressions, _objExpressions); public string Name { get; } + + private void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + _schemaTable.Dispose(); + } + + _disposedValue = true; + } + } + + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } } public class EnumerableDataReaderFactory @@ -274,10 +311,11 @@ public EnumerableDataReaderFactory(DataTable schemaTable, List public IDataReader CreateReader(IEnumerable items) => new EnumerableDataReader(this, items.GetEnumerator()); } - public class EnumerableDataReader : IDataReader + public sealed class EnumerableDataReader : IDataReader { private readonly IEnumerator _source; private readonly EnumerableDataReaderFactory _context; + private bool _disposedValue; public EnumerableDataReader(EnumerableDataReaderFactory context, IEnumerator source) { @@ -297,8 +335,6 @@ public object GetValue(int i) public void Close() => _source.Reset(); - public void Dispose() => this.Close(); - public bool NextResult() => throw new NotImplementedException(); public int Depth => 0; @@ -426,6 +462,26 @@ public bool IsDBNull(int i) return false; } + + private void Dispose(bool disposing) + { + if (!_disposedValue) + { + if (disposing) + { + this.Close(); + } + + _disposedValue = true; + } + } + + public void Dispose() + { + // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method + Dispose(disposing: true); + GC.SuppressFinalize(this); + } } } } From b3ed9a03e8fcc63698abdd687ec2830331bc99f0 Mon Sep 17 00:00:00 2001 From: Carl Meyertons Date: Mon, 26 Apr 2021 12:52:00 -0500 Subject: [PATCH 3/6] removing test --- ....Data.SqlClient.ManualTesting.Tests.csproj | 2 - .../SqlBulkCopyTest/NoBoxingValuesTypes.cs | 936 +++++++++--------- tools/props/Versions.props | 1 - 3 files changed, 467 insertions(+), 472 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index 46908dcbec..def31ffc10 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -70,7 +70,6 @@ Common\System\Collections\DictionaryExtensions.cs - @@ -303,7 +302,6 @@ - diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs index c218339894..2a0f0b537d 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs @@ -1,487 +1,485 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. -// See the LICENSE file in the project root for more information. - -using System; -using System.Collections.Generic; -using System.Data; -using System.Linq; -using System.Linq.Expressions; -using BenchmarkDotNet.Attributes; -using BenchmarkDotNet.Configs; -using BenchmarkDotNet.Diagnosers; -using BenchmarkDotNet.Jobs; -using BenchmarkDotNet.Validators; -using Xunit; - -namespace Microsoft.Data.SqlClient.ManualTesting.Tests -{ - public sealed class NoBoxingValueTypes : IDisposable - { - private static readonly string _table = DataTestUtility.GetUniqueNameForSqlServer(nameof(NoBoxingValueTypes)); - private const int _count = 5000; - private static readonly ItemToCopy _item; - private static readonly IEnumerable _items; - private static readonly IDataReader _reader; - - private static readonly string _connString = DataTestUtility.TCPConnectionString; - private bool _disposedValue; - - private class ItemToCopy - { - // keeping this data static so the performance of the benchmark is not varied by the data size & shape - public int IntColumn { get; } = 123456; - public bool BoolColumn { get; } = true; - } - - static NoBoxingValueTypes() - { - _item = new ItemToCopy(); - - _items = Enumerable.Range(0, _count).Select(x => _item).ToArray(); - - // It would've been great to use mgravell/FastMember here to make thing logic much cleaner and not have to include the custom reader - // however, that package is not available on the dotnet-public Nuget source, which is a bummer. - _reader = new EnumerableDataReaderFactoryBuilder(_table) - .Add("IntColumn", i => i.IntColumn) - .Add("BoolColumn", i => i.BoolColumn) - .BuildFactory() - .CreateReader(_items) - ; - } - - public NoBoxingValueTypes() - { - using (var conn = new SqlConnection(_connString)) - using (var cmd = conn.CreateCommand()) - { - conn.Open(); - Helpers.TryExecute(cmd, $@" - CREATE TABLE {_table} ( - IntColumn INT NOT NULL, - BoolColumn BIT NOT NULL - ) - "); - } - } - - private class RunOnceConfig : ManualConfig - { - public RunOnceConfig() - { - Add(Job.InProcess.WithLaunchCount(1).WithIterationCount(1).WithWarmupCount(0)); - Add(MemoryDiagnoser.Default); - - Add(JitOptimizationsValidator.DontFailOnError); - } - } - - - - [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer))] - public void Should_Not_Box() - { // in debug mode, the double boxing DOES occur as the JIT optimizes less code, which causes the test to fail -#if DEBUG - return; -#else - //cannot figure out an easy way to get this to work on all platforms - - var config = new RunOnceConfig(); // cannot use fluent syntax to still support net461 - - var summary = BenchmarkRunner.Run(config); - - var numValueTypeColumns = 2; - var totalBytesWhenBoxed = IntPtr.Size * _count * numValueTypeColumns; - - var report = summary.Reports.First(); - - Assert.Equal(1, report.AllMeasurements.Count); - Assert.True(report.GcStats.BytesAllocatedPerOperation < totalBytesWhenBoxed); -#endif - } - - public class NoBoxingValueTypesBenchmark - { - [Benchmark] - public void BulkCopy() - { - _reader.Close(); // this resets the reader - - using (var bc = new SqlBulkCopy(DataTestUtility.TCPConnectionString, SqlBulkCopyOptions.TableLock)) - { - bc.BatchSize = _count; - bc.DestinationTableName = _table; - bc.BulkCopyTimeout = 60; - - bc.WriteToServer(_reader); - } - } - } - - private void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _reader.Dispose(); - using (var conn = new SqlConnection(_connString)) - using (var cmd = conn.CreateCommand()) - { - conn.Open(); - Helpers.TryExecute(cmd, $@" - DROP TABLE IF EXISTS {_table} - "); - } - } - - _disposedValue = true; - } - } - - public void Dispose() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - - //all code here and below is a custom data reader implementation to support the benchmark - private sealed class EnumerableDataReaderFactoryBuilder: IDisposable - { - private readonly List _expressions = new List(); - private readonly List> _objExpressions = new List>(); - private readonly DataTable _schemaTable; - - public EnumerableDataReaderFactoryBuilder(string tableName) - { - Name = tableName; - _schemaTable = new DataTable(); - } - - private static readonly HashSet _validTypes = new HashSet - { - typeof(decimal), - typeof(decimal?), - typeof(string), - typeof(int), - typeof(int?), - typeof(double), - typeof(bool), - typeof(bool?), - typeof(Guid), - typeof(DateTime), - }; - private bool _disposedValue; - - public EnumerableDataReaderFactoryBuilder Add(string column, Expression> expression) - { - var t = typeof(TColumn); - - var func = expression.Compile(); - - // don't do any optimizations for boxing bools here to detect boxing occurring properly. - Expression> objExpression = o => func(o); - - _objExpressions.Add(objExpression.Compile()); - - if (_validTypes.Contains(t)) - { - t = Nullable.GetUnderlyingType(t) ?? t; // data table doesn't accept nullable. - _schemaTable.Columns.Add(column, t); - _expressions.Add(expression); - } - else - { - Console.WriteLine($"Could not matching return type for {Name}.{column} of: {t.Name}"); - _schemaTable.Columns.Add(column); //add w/o type to force using GetValue - - _expressions.Add(objExpression); - } - - return this; - } - - public EnumerableDataReaderFactory BuildFactory() => new EnumerableDataReaderFactory(_schemaTable, _expressions, _objExpressions); - - public string Name { get; } - - private void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - _schemaTable.Dispose(); - } - - _disposedValue = true; - } - } - - public void Dispose() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - } - - public class EnumerableDataReaderFactory - { - public DataTable SchemaTable { get; } - public Func[] ObjectGetters { get; } - public Func[] DecimalGetters { get; } - public Func[] NullableDecimalGetters { get; } - public Func[] StringGetters { get; } - public Func[] DoubleGetters { get; } - public Func[] IntGetters { get; } - public Func[] NullableIntGetters { get; } - public Func[] BoolGetters { get; } - - public Func[] NullableBoolGetters { get; } - - public Func[] GuidGetters { get; } - public Func[] DateTimeGetters { get; } - public bool[] NullableIndexes { get; } - - public EnumerableDataReaderFactory(DataTable schemaTable, List expressions, List> objectGetters) - { - SchemaTable = schemaTable; - DecimalGetters = new Func[expressions.Count]; - NullableDecimalGetters = new Func[expressions.Count]; - StringGetters = new Func[expressions.Count]; - DoubleGetters = new Func[expressions.Count]; - IntGetters = new Func[expressions.Count]; - NullableIntGetters = new Func[expressions.Count]; - BoolGetters = new Func[expressions.Count]; - NullableBoolGetters = new Func[expressions.Count]; - GuidGetters = new Func[expressions.Count]; - DateTimeGetters = new Func[expressions.Count]; - NullableIndexes = new bool[expressions.Count]; - - ObjectGetters = objectGetters.ToArray(); - - for (int i = 0; i < expressions.Count; i++) - { - var expression = expressions[i]; - - NullableIndexes[i] = !expression.ReturnType.IsValueType || Nullable.GetUnderlyingType(expression.ReturnType) != null; - - switch (expression) - { - case Expression> e: - break; // do nothing - case Expression> e: - DecimalGetters[i] = e.Compile(); - break; - case Expression> e: - NullableDecimalGetters[i] = e.Compile(); - break; - case Expression> e: - StringGetters[i] = e.Compile(); - break; - case Expression> e: - DoubleGetters[i] = e.Compile(); - break; - case Expression> e: - IntGetters[i] = e.Compile(); - break; - case Expression> e: - NullableIntGetters[i] = e.Compile(); - break; - case Expression> e: - BoolGetters[i] = e.Compile(); - break; - case Expression> e: - NullableBoolGetters[i] = e.Compile(); - break; - case Expression> e: - GuidGetters[i] = e.Compile(); - break; - case Expression> e: - DateTimeGetters[i] = e.Compile(); - break; - default: - throw new Exception($"Type missing: {expression.GetType().FullName}"); - } - } - } - - public IDataReader CreateReader(IEnumerable items) => new EnumerableDataReader(this, items.GetEnumerator()); - } - - public sealed class EnumerableDataReader : IDataReader - { - private readonly IEnumerator _source; - private readonly EnumerableDataReaderFactory _context; - private bool _disposedValue; - - public EnumerableDataReader(EnumerableDataReaderFactory context, IEnumerator source) - { - _source = source; - _context = context; - } - - public object GetValue(int i) - { - var v = _context.ObjectGetters[i](_source.Current); - return v; - } - - public int FieldCount => _context.ObjectGetters.Length; - - public bool Read() => _source.MoveNext(); - - public void Close() => _source.Reset(); - - public bool NextResult() => throw new NotImplementedException(); - - public int Depth => 0; - - public bool IsClosed => false; - - public int RecordsAffected => -1; - - public DataTable GetSchemaTable() => _context.SchemaTable; - - public object this[string name] => throw new NotImplementedException(); - - public object this[int i] => GetValue(i); - - public bool GetBoolean(int i) - { - var g = _context.BoolGetters[i]; - - if (g != null) - return g(_source.Current); - - return _context.NullableBoolGetters[i](_source.Current).Value; - } - - public byte GetByte(int i) => throw new NotImplementedException(); - - public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) => throw new NotImplementedException(); - - public char GetChar(int i) => throw new NotImplementedException(); - public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) => -1; - - public IDataReader GetData(int i) => throw new NotImplementedException(); - - public string GetDataTypeName(int i) => throw new NotImplementedException(); - - public DateTime GetDateTime(int i) => _context.DateTimeGetters[i](_source.Current); - - public decimal GetDecimal(int i) - { - var g = _context.DecimalGetters[i]; - - if (g != null) - return g(_source.Current); - - return _context.NullableDecimalGetters[i](_source.Current).Value; - } - - public double GetDouble(int i) => _context.DoubleGetters[i](_source.Current); - - public Type GetFieldType(int i) => _context.SchemaTable.Columns[i].DataType; - - public float GetFloat(int i) => throw new NotImplementedException(); +//// Licensed to the .NET Foundation under one or more agreements. +//// The .NET Foundation licenses this file to you under the MIT license. +//// See the LICENSE file in the project root for more information. + +//using System; +//using System.Collections.Generic; +//using System.Data; +//using System.Linq; +//using System.Linq.Expressions; +//using BenchmarkDotNet.Attributes; +//using BenchmarkDotNet.Configs; +//using BenchmarkDotNet.Diagnosers; +//using BenchmarkDotNet.Jobs; +//using BenchmarkDotNet.Validators; +//using Xunit; + +//namespace Microsoft.Data.SqlClient.ManualTesting.Tests +//{ +// public sealed class NoBoxingValueTypes : IDisposable +// { +// private static readonly string _table = DataTestUtility.GetUniqueNameForSqlServer(nameof(NoBoxingValueTypes)); +// private const int _count = 5000; +// private static readonly ItemToCopy _item; +// private static readonly IEnumerable _items; +// private static readonly IDataReader _reader; + +// private static readonly string _connString = DataTestUtility.TCPConnectionString; +// private bool _disposedValue; + +// private class ItemToCopy +// { +// // keeping this data static so the performance of the benchmark is not varied by the data size & shape +// public int IntColumn { get; } = 123456; +// public bool BoolColumn { get; } = true; +// } + +// static NoBoxingValueTypes() +// { +// _item = new ItemToCopy(); + +// _items = Enumerable.Range(0, _count).Select(x => _item).ToArray(); + +// // It would've been great to use mgravell/FastMember here to make thing logic much cleaner and not have to include the custom reader +// // however, that package is not available on the dotnet-public Nuget source, which is a bummer. +// _reader = new EnumerableDataReaderFactoryBuilder(_table) +// .Add("IntColumn", i => i.IntColumn) +// .Add("BoolColumn", i => i.BoolColumn) +// .BuildFactory() +// .CreateReader(_items) +// ; +// } + +// public NoBoxingValueTypes() +// { +// using (var conn = new SqlConnection(_connString)) +// using (var cmd = conn.CreateCommand()) +// { +// conn.Open(); +// Helpers.TryExecute(cmd, $@" +// CREATE TABLE {_table} ( +// IntColumn INT NOT NULL, +// BoolColumn BIT NOT NULL +// ) +// "); +// } +// } + +// private class RunOnceConfig : ManualConfig +// { +// public RunOnceConfig() +// { +// Add(Job.InProcess.WithLaunchCount(1).WithIterationCount(1).WithWarmupCount(0)); +// Add(MemoryDiagnoser.Default); + +// Add(JitOptimizationsValidator.DontFailOnError); +// } +// } + +// [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer))] +// public void Should_Not_Box() +// { // in debug mode, the double boxing DOES occur as the JIT optimizes less code, which causes the test to fail +//#if DEBUG +// return; +//#else +// //cannot figure out an easy way to get this to work on all platforms + +// var config = new RunOnceConfig(); // cannot use fluent syntax to still support net461 + +// var summary = BenchmarkRunner.Run(config); + +// var numValueTypeColumns = 2; +// var totalBytesWhenBoxed = IntPtr.Size * _count * numValueTypeColumns; + +// var report = summary.Reports.First(); + +// Assert.Equal(1, report.AllMeasurements.Count); +// Assert.True(report.GcStats.BytesAllocatedPerOperation < totalBytesWhenBoxed); +//#endif +// } + +// public class NoBoxingValueTypesBenchmark +// { +// [Benchmark] +// public void BulkCopy() +// { +// _reader.Close(); // this resets the reader + +// using (var bc = new SqlBulkCopy(DataTestUtility.TCPConnectionString, SqlBulkCopyOptions.TableLock)) +// { +// bc.BatchSize = _count; +// bc.DestinationTableName = _table; +// bc.BulkCopyTimeout = 60; + +// bc.WriteToServer(_reader); +// } +// } +// } + +// private void Dispose(bool disposing) +// { +// if (!_disposedValue) +// { +// if (disposing) +// { +// _reader.Dispose(); +// using (var conn = new SqlConnection(_connString)) +// using (var cmd = conn.CreateCommand()) +// { +// conn.Open(); +// Helpers.TryExecute(cmd, $@" +// DROP TABLE IF EXISTS {_table} +// "); +// } +// } + +// _disposedValue = true; +// } +// } + +// public void Dispose() +// { +// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method +// Dispose(disposing: true); +// GC.SuppressFinalize(this); +// } + +// //all code here and below is a custom data reader implementation to support the benchmark +// private sealed class EnumerableDataReaderFactoryBuilder: IDisposable +// { +// private readonly List _expressions = new List(); +// private readonly List> _objExpressions = new List>(); +// private readonly DataTable _schemaTable; + +// public EnumerableDataReaderFactoryBuilder(string tableName) +// { +// Name = tableName; +// _schemaTable = new DataTable(); +// } + +// private static readonly HashSet _validTypes = new HashSet +// { +// typeof(decimal), +// typeof(decimal?), +// typeof(string), +// typeof(int), +// typeof(int?), +// typeof(double), +// typeof(bool), +// typeof(bool?), +// typeof(Guid), +// typeof(DateTime), +// }; +// private bool _disposedValue; + +// public EnumerableDataReaderFactoryBuilder Add(string column, Expression> expression) +// { +// var t = typeof(TColumn); + +// var func = expression.Compile(); + +// // don't do any optimizations for boxing bools here to detect boxing occurring properly. +// Expression> objExpression = o => func(o); + +// _objExpressions.Add(objExpression.Compile()); + +// if (_validTypes.Contains(t)) +// { +// t = Nullable.GetUnderlyingType(t) ?? t; // data table doesn't accept nullable. +// _schemaTable.Columns.Add(column, t); +// _expressions.Add(expression); +// } +// else +// { +// Console.WriteLine($"Could not matching return type for {Name}.{column} of: {t.Name}"); +// _schemaTable.Columns.Add(column); //add w/o type to force using GetValue + +// _expressions.Add(objExpression); +// } + +// return this; +// } + +// public EnumerableDataReaderFactory BuildFactory() => new EnumerableDataReaderFactory(_schemaTable, _expressions, _objExpressions); + +// public string Name { get; } + +// private void Dispose(bool disposing) +// { +// if (!_disposedValue) +// { +// if (disposing) +// { +// _schemaTable.Dispose(); +// } + +// _disposedValue = true; +// } +// } + +// public void Dispose() +// { +// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method +// Dispose(disposing: true); +// GC.SuppressFinalize(this); +// } +// } + +// public class EnumerableDataReaderFactory +// { +// public DataTable SchemaTable { get; } +// public Func[] ObjectGetters { get; } +// public Func[] DecimalGetters { get; } +// public Func[] NullableDecimalGetters { get; } +// public Func[] StringGetters { get; } +// public Func[] DoubleGetters { get; } +// public Func[] IntGetters { get; } +// public Func[] NullableIntGetters { get; } +// public Func[] BoolGetters { get; } + +// public Func[] NullableBoolGetters { get; } + +// public Func[] GuidGetters { get; } +// public Func[] DateTimeGetters { get; } +// public bool[] NullableIndexes { get; } + +// public EnumerableDataReaderFactory(DataTable schemaTable, List expressions, List> objectGetters) +// { +// SchemaTable = schemaTable; +// DecimalGetters = new Func[expressions.Count]; +// NullableDecimalGetters = new Func[expressions.Count]; +// StringGetters = new Func[expressions.Count]; +// DoubleGetters = new Func[expressions.Count]; +// IntGetters = new Func[expressions.Count]; +// NullableIntGetters = new Func[expressions.Count]; +// BoolGetters = new Func[expressions.Count]; +// NullableBoolGetters = new Func[expressions.Count]; +// GuidGetters = new Func[expressions.Count]; +// DateTimeGetters = new Func[expressions.Count]; +// NullableIndexes = new bool[expressions.Count]; + +// ObjectGetters = objectGetters.ToArray(); + +// for (int i = 0; i < expressions.Count; i++) +// { +// var expression = expressions[i]; + +// NullableIndexes[i] = !expression.ReturnType.IsValueType || Nullable.GetUnderlyingType(expression.ReturnType) != null; + +// switch (expression) +// { +// case Expression> e: +// break; // do nothing +// case Expression> e: +// DecimalGetters[i] = e.Compile(); +// break; +// case Expression> e: +// NullableDecimalGetters[i] = e.Compile(); +// break; +// case Expression> e: +// StringGetters[i] = e.Compile(); +// break; +// case Expression> e: +// DoubleGetters[i] = e.Compile(); +// break; +// case Expression> e: +// IntGetters[i] = e.Compile(); +// break; +// case Expression> e: +// NullableIntGetters[i] = e.Compile(); +// break; +// case Expression> e: +// BoolGetters[i] = e.Compile(); +// break; +// case Expression> e: +// NullableBoolGetters[i] = e.Compile(); +// break; +// case Expression> e: +// GuidGetters[i] = e.Compile(); +// break; +// case Expression> e: +// DateTimeGetters[i] = e.Compile(); +// break; +// default: +// throw new Exception($"Type missing: {expression.GetType().FullName}"); +// } +// } +// } + +// public IDataReader CreateReader(IEnumerable items) => new EnumerableDataReader(this, items.GetEnumerator()); +// } + +// public sealed class EnumerableDataReader : IDataReader +// { +// private readonly IEnumerator _source; +// private readonly EnumerableDataReaderFactory _context; +// private bool _disposedValue; + +// public EnumerableDataReader(EnumerableDataReaderFactory context, IEnumerator source) +// { +// _source = source; +// _context = context; +// } + +// public object GetValue(int i) +// { +// var v = _context.ObjectGetters[i](_source.Current); +// return v; +// } + +// public int FieldCount => _context.ObjectGetters.Length; + +// public bool Read() => _source.MoveNext(); + +// public void Close() => _source.Reset(); + +// public bool NextResult() => throw new NotImplementedException(); + +// public int Depth => 0; + +// public bool IsClosed => false; + +// public int RecordsAffected => -1; + +// public DataTable GetSchemaTable() => _context.SchemaTable; + +// public object this[string name] => throw new NotImplementedException(); + +// public object this[int i] => GetValue(i); + +// public bool GetBoolean(int i) +// { +// var g = _context.BoolGetters[i]; + +// if (g != null) +// return g(_source.Current); + +// return _context.NullableBoolGetters[i](_source.Current).Value; +// } + +// public byte GetByte(int i) => throw new NotImplementedException(); + +// public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) => throw new NotImplementedException(); + +// public char GetChar(int i) => throw new NotImplementedException(); +// public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) => -1; + +// public IDataReader GetData(int i) => throw new NotImplementedException(); + +// public string GetDataTypeName(int i) => throw new NotImplementedException(); + +// public DateTime GetDateTime(int i) => _context.DateTimeGetters[i](_source.Current); + +// public decimal GetDecimal(int i) +// { +// var g = _context.DecimalGetters[i]; + +// if (g != null) +// return g(_source.Current); + +// return _context.NullableDecimalGetters[i](_source.Current).Value; +// } + +// public double GetDouble(int i) => _context.DoubleGetters[i](_source.Current); + +// public Type GetFieldType(int i) => _context.SchemaTable.Columns[i].DataType; + +// public float GetFloat(int i) => throw new NotImplementedException(); - public Guid GetGuid(int i) => _context.GuidGetters[i](_source.Current); +// public Guid GetGuid(int i) => _context.GuidGetters[i](_source.Current); - public short GetInt16(int i) => throw new NotImplementedException(); +// public short GetInt16(int i) => throw new NotImplementedException(); - public int GetInt32(int i) - { - var g = _context.IntGetters[i]; +// public int GetInt32(int i) +// { +// var g = _context.IntGetters[i]; - if (g != null) - return g(_source.Current); +// if (g != null) +// return g(_source.Current); - return _context.NullableIntGetters[i](_source.Current).Value; - } +// return _context.NullableIntGetters[i](_source.Current).Value; +// } - public long GetInt64(int i) => throw new NotImplementedException(); +// public long GetInt64(int i) => throw new NotImplementedException(); - public string GetName(int i) - { - if (_context.SchemaTable.Columns.Count > i) - { - return _context.SchemaTable.Columns[i].ColumnName; - } - throw new IndexOutOfRangeException($"No column for index {i}"); - } +// public string GetName(int i) +// { +// if (_context.SchemaTable.Columns.Count > i) +// { +// return _context.SchemaTable.Columns[i].ColumnName; +// } +// throw new IndexOutOfRangeException($"No column for index {i}"); +// } - public int GetOrdinal(string name) - { - if (_context.SchemaTable.Columns.Count == 0) - { - throw new Exception("Schema table is empty"); - } - return _context.SchemaTable.Columns.IndexOf(name); - } +// public int GetOrdinal(string name) +// { +// if (_context.SchemaTable.Columns.Count == 0) +// { +// throw new Exception("Schema table is empty"); +// } +// return _context.SchemaTable.Columns.IndexOf(name); +// } - public string GetString(int i) => _context.StringGetters[i](_source.Current); +// public string GetString(int i) => _context.StringGetters[i](_source.Current); - public int GetValues(object[] values) => throw new NotImplementedException(); +// public int GetValues(object[] values) => throw new NotImplementedException(); - public bool IsDBNull(int i) - { - // short circuit for non-nullable types - if (!_context.NullableIndexes[i]) - { - return false; - } +// public bool IsDBNull(int i) +// { +// // short circuit for non-nullable types +// if (!_context.NullableIndexes[i]) +// { +// return false; +// } - // otherwise find the first one -- starting w/ most occurring to least +// // otherwise find the first one -- starting w/ most occurring to least - var ig = _context.NullableIntGetters[i]; - if (ig != null) - { - return ig(_source.Current) == null; - } +// var ig = _context.NullableIntGetters[i]; +// if (ig != null) +// { +// return ig(_source.Current) == null; +// } - var sg = _context.StringGetters[i]; - if (sg != null) - { - return sg(_source.Current) == null; - } +// var sg = _context.StringGetters[i]; +// if (sg != null) +// { +// return sg(_source.Current) == null; +// } - var bg = _context.NullableBoolGetters[i]; - if (bg != null) - { - return bg(_source.Current) == null; - } +// var bg = _context.NullableBoolGetters[i]; +// if (bg != null) +// { +// return bg(_source.Current) == null; +// } - var dg = _context.NullableDecimalGetters[i]; - if (dg != null) - { - return dg(_source.Current) == null; - } +// var dg = _context.NullableDecimalGetters[i]; +// if (dg != null) +// { +// return dg(_source.Current) == null; +// } - return false; - } +// return false; +// } - private void Dispose(bool disposing) - { - if (!_disposedValue) - { - if (disposing) - { - this.Close(); - } - - _disposedValue = true; - } - } - - public void Dispose() - { - // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method - Dispose(disposing: true); - GC.SuppressFinalize(this); - } - } - } -} +// private void Dispose(bool disposing) +// { +// if (!_disposedValue) +// { +// if (disposing) +// { +// this.Close(); +// } + +// _disposedValue = true; +// } +// } + +// public void Dispose() +// { +// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method +// Dispose(disposing: true); +// GC.SuppressFinalize(this); +// } +// } +// } +//} diff --git a/tools/props/Versions.props b/tools/props/Versions.props index b7b11684d7..ec4365798c 100644 --- a/tools/props/Versions.props +++ b/tools/props/Versions.props @@ -50,7 +50,6 @@ - 0.11.3 3.1.1 5.2.6 15.9.0 From ee187feb57e2681554d6f26727473970b4b9e6b7 Mon Sep 17 00:00:00 2001 From: Carl Meyertons Date: Tue, 27 Apr 2021 07:57:50 -0500 Subject: [PATCH 4/6] removing test file --- .../SqlBulkCopyTest/NoBoxingValuesTypes.cs | 485 ------------------ 1 file changed, 485 deletions(-) delete mode 100644 src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs deleted file mode 100644 index 2a0f0b537d..0000000000 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/NoBoxingValuesTypes.cs +++ /dev/null @@ -1,485 +0,0 @@ -//// Licensed to the .NET Foundation under one or more agreements. -//// The .NET Foundation licenses this file to you under the MIT license. -//// See the LICENSE file in the project root for more information. - -//using System; -//using System.Collections.Generic; -//using System.Data; -//using System.Linq; -//using System.Linq.Expressions; -//using BenchmarkDotNet.Attributes; -//using BenchmarkDotNet.Configs; -//using BenchmarkDotNet.Diagnosers; -//using BenchmarkDotNet.Jobs; -//using BenchmarkDotNet.Validators; -//using Xunit; - -//namespace Microsoft.Data.SqlClient.ManualTesting.Tests -//{ -// public sealed class NoBoxingValueTypes : IDisposable -// { -// private static readonly string _table = DataTestUtility.GetUniqueNameForSqlServer(nameof(NoBoxingValueTypes)); -// private const int _count = 5000; -// private static readonly ItemToCopy _item; -// private static readonly IEnumerable _items; -// private static readonly IDataReader _reader; - -// private static readonly string _connString = DataTestUtility.TCPConnectionString; -// private bool _disposedValue; - -// private class ItemToCopy -// { -// // keeping this data static so the performance of the benchmark is not varied by the data size & shape -// public int IntColumn { get; } = 123456; -// public bool BoolColumn { get; } = true; -// } - -// static NoBoxingValueTypes() -// { -// _item = new ItemToCopy(); - -// _items = Enumerable.Range(0, _count).Select(x => _item).ToArray(); - -// // It would've been great to use mgravell/FastMember here to make thing logic much cleaner and not have to include the custom reader -// // however, that package is not available on the dotnet-public Nuget source, which is a bummer. -// _reader = new EnumerableDataReaderFactoryBuilder(_table) -// .Add("IntColumn", i => i.IntColumn) -// .Add("BoolColumn", i => i.BoolColumn) -// .BuildFactory() -// .CreateReader(_items) -// ; -// } - -// public NoBoxingValueTypes() -// { -// using (var conn = new SqlConnection(_connString)) -// using (var cmd = conn.CreateCommand()) -// { -// conn.Open(); -// Helpers.TryExecute(cmd, $@" -// CREATE TABLE {_table} ( -// IntColumn INT NOT NULL, -// BoolColumn BIT NOT NULL -// ) -// "); -// } -// } - -// private class RunOnceConfig : ManualConfig -// { -// public RunOnceConfig() -// { -// Add(Job.InProcess.WithLaunchCount(1).WithIterationCount(1).WithWarmupCount(0)); -// Add(MemoryDiagnoser.Default); - -// Add(JitOptimizationsValidator.DontFailOnError); -// } -// } - -// [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer))] -// public void Should_Not_Box() -// { // in debug mode, the double boxing DOES occur as the JIT optimizes less code, which causes the test to fail -//#if DEBUG -// return; -//#else -// //cannot figure out an easy way to get this to work on all platforms - -// var config = new RunOnceConfig(); // cannot use fluent syntax to still support net461 - -// var summary = BenchmarkRunner.Run(config); - -// var numValueTypeColumns = 2; -// var totalBytesWhenBoxed = IntPtr.Size * _count * numValueTypeColumns; - -// var report = summary.Reports.First(); - -// Assert.Equal(1, report.AllMeasurements.Count); -// Assert.True(report.GcStats.BytesAllocatedPerOperation < totalBytesWhenBoxed); -//#endif -// } - -// public class NoBoxingValueTypesBenchmark -// { -// [Benchmark] -// public void BulkCopy() -// { -// _reader.Close(); // this resets the reader - -// using (var bc = new SqlBulkCopy(DataTestUtility.TCPConnectionString, SqlBulkCopyOptions.TableLock)) -// { -// bc.BatchSize = _count; -// bc.DestinationTableName = _table; -// bc.BulkCopyTimeout = 60; - -// bc.WriteToServer(_reader); -// } -// } -// } - -// private void Dispose(bool disposing) -// { -// if (!_disposedValue) -// { -// if (disposing) -// { -// _reader.Dispose(); -// using (var conn = new SqlConnection(_connString)) -// using (var cmd = conn.CreateCommand()) -// { -// conn.Open(); -// Helpers.TryExecute(cmd, $@" -// DROP TABLE IF EXISTS {_table} -// "); -// } -// } - -// _disposedValue = true; -// } -// } - -// public void Dispose() -// { -// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method -// Dispose(disposing: true); -// GC.SuppressFinalize(this); -// } - -// //all code here and below is a custom data reader implementation to support the benchmark -// private sealed class EnumerableDataReaderFactoryBuilder: IDisposable -// { -// private readonly List _expressions = new List(); -// private readonly List> _objExpressions = new List>(); -// private readonly DataTable _schemaTable; - -// public EnumerableDataReaderFactoryBuilder(string tableName) -// { -// Name = tableName; -// _schemaTable = new DataTable(); -// } - -// private static readonly HashSet _validTypes = new HashSet -// { -// typeof(decimal), -// typeof(decimal?), -// typeof(string), -// typeof(int), -// typeof(int?), -// typeof(double), -// typeof(bool), -// typeof(bool?), -// typeof(Guid), -// typeof(DateTime), -// }; -// private bool _disposedValue; - -// public EnumerableDataReaderFactoryBuilder Add(string column, Expression> expression) -// { -// var t = typeof(TColumn); - -// var func = expression.Compile(); - -// // don't do any optimizations for boxing bools here to detect boxing occurring properly. -// Expression> objExpression = o => func(o); - -// _objExpressions.Add(objExpression.Compile()); - -// if (_validTypes.Contains(t)) -// { -// t = Nullable.GetUnderlyingType(t) ?? t; // data table doesn't accept nullable. -// _schemaTable.Columns.Add(column, t); -// _expressions.Add(expression); -// } -// else -// { -// Console.WriteLine($"Could not matching return type for {Name}.{column} of: {t.Name}"); -// _schemaTable.Columns.Add(column); //add w/o type to force using GetValue - -// _expressions.Add(objExpression); -// } - -// return this; -// } - -// public EnumerableDataReaderFactory BuildFactory() => new EnumerableDataReaderFactory(_schemaTable, _expressions, _objExpressions); - -// public string Name { get; } - -// private void Dispose(bool disposing) -// { -// if (!_disposedValue) -// { -// if (disposing) -// { -// _schemaTable.Dispose(); -// } - -// _disposedValue = true; -// } -// } - -// public void Dispose() -// { -// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method -// Dispose(disposing: true); -// GC.SuppressFinalize(this); -// } -// } - -// public class EnumerableDataReaderFactory -// { -// public DataTable SchemaTable { get; } -// public Func[] ObjectGetters { get; } -// public Func[] DecimalGetters { get; } -// public Func[] NullableDecimalGetters { get; } -// public Func[] StringGetters { get; } -// public Func[] DoubleGetters { get; } -// public Func[] IntGetters { get; } -// public Func[] NullableIntGetters { get; } -// public Func[] BoolGetters { get; } - -// public Func[] NullableBoolGetters { get; } - -// public Func[] GuidGetters { get; } -// public Func[] DateTimeGetters { get; } -// public bool[] NullableIndexes { get; } - -// public EnumerableDataReaderFactory(DataTable schemaTable, List expressions, List> objectGetters) -// { -// SchemaTable = schemaTable; -// DecimalGetters = new Func[expressions.Count]; -// NullableDecimalGetters = new Func[expressions.Count]; -// StringGetters = new Func[expressions.Count]; -// DoubleGetters = new Func[expressions.Count]; -// IntGetters = new Func[expressions.Count]; -// NullableIntGetters = new Func[expressions.Count]; -// BoolGetters = new Func[expressions.Count]; -// NullableBoolGetters = new Func[expressions.Count]; -// GuidGetters = new Func[expressions.Count]; -// DateTimeGetters = new Func[expressions.Count]; -// NullableIndexes = new bool[expressions.Count]; - -// ObjectGetters = objectGetters.ToArray(); - -// for (int i = 0; i < expressions.Count; i++) -// { -// var expression = expressions[i]; - -// NullableIndexes[i] = !expression.ReturnType.IsValueType || Nullable.GetUnderlyingType(expression.ReturnType) != null; - -// switch (expression) -// { -// case Expression> e: -// break; // do nothing -// case Expression> e: -// DecimalGetters[i] = e.Compile(); -// break; -// case Expression> e: -// NullableDecimalGetters[i] = e.Compile(); -// break; -// case Expression> e: -// StringGetters[i] = e.Compile(); -// break; -// case Expression> e: -// DoubleGetters[i] = e.Compile(); -// break; -// case Expression> e: -// IntGetters[i] = e.Compile(); -// break; -// case Expression> e: -// NullableIntGetters[i] = e.Compile(); -// break; -// case Expression> e: -// BoolGetters[i] = e.Compile(); -// break; -// case Expression> e: -// NullableBoolGetters[i] = e.Compile(); -// break; -// case Expression> e: -// GuidGetters[i] = e.Compile(); -// break; -// case Expression> e: -// DateTimeGetters[i] = e.Compile(); -// break; -// default: -// throw new Exception($"Type missing: {expression.GetType().FullName}"); -// } -// } -// } - -// public IDataReader CreateReader(IEnumerable items) => new EnumerableDataReader(this, items.GetEnumerator()); -// } - -// public sealed class EnumerableDataReader : IDataReader -// { -// private readonly IEnumerator _source; -// private readonly EnumerableDataReaderFactory _context; -// private bool _disposedValue; - -// public EnumerableDataReader(EnumerableDataReaderFactory context, IEnumerator source) -// { -// _source = source; -// _context = context; -// } - -// public object GetValue(int i) -// { -// var v = _context.ObjectGetters[i](_source.Current); -// return v; -// } - -// public int FieldCount => _context.ObjectGetters.Length; - -// public bool Read() => _source.MoveNext(); - -// public void Close() => _source.Reset(); - -// public bool NextResult() => throw new NotImplementedException(); - -// public int Depth => 0; - -// public bool IsClosed => false; - -// public int RecordsAffected => -1; - -// public DataTable GetSchemaTable() => _context.SchemaTable; - -// public object this[string name] => throw new NotImplementedException(); - -// public object this[int i] => GetValue(i); - -// public bool GetBoolean(int i) -// { -// var g = _context.BoolGetters[i]; - -// if (g != null) -// return g(_source.Current); - -// return _context.NullableBoolGetters[i](_source.Current).Value; -// } - -// public byte GetByte(int i) => throw new NotImplementedException(); - -// public long GetBytes(int i, long fieldOffset, byte[] buffer, int bufferoffset, int length) => throw new NotImplementedException(); - -// public char GetChar(int i) => throw new NotImplementedException(); -// public long GetChars(int i, long fieldoffset, char[] buffer, int bufferoffset, int length) => -1; - -// public IDataReader GetData(int i) => throw new NotImplementedException(); - -// public string GetDataTypeName(int i) => throw new NotImplementedException(); - -// public DateTime GetDateTime(int i) => _context.DateTimeGetters[i](_source.Current); - -// public decimal GetDecimal(int i) -// { -// var g = _context.DecimalGetters[i]; - -// if (g != null) -// return g(_source.Current); - -// return _context.NullableDecimalGetters[i](_source.Current).Value; -// } - -// public double GetDouble(int i) => _context.DoubleGetters[i](_source.Current); - -// public Type GetFieldType(int i) => _context.SchemaTable.Columns[i].DataType; - -// public float GetFloat(int i) => throw new NotImplementedException(); - -// public Guid GetGuid(int i) => _context.GuidGetters[i](_source.Current); - -// public short GetInt16(int i) => throw new NotImplementedException(); - -// public int GetInt32(int i) -// { -// var g = _context.IntGetters[i]; - -// if (g != null) -// return g(_source.Current); - -// return _context.NullableIntGetters[i](_source.Current).Value; -// } - -// public long GetInt64(int i) => throw new NotImplementedException(); - -// public string GetName(int i) -// { -// if (_context.SchemaTable.Columns.Count > i) -// { -// return _context.SchemaTable.Columns[i].ColumnName; -// } -// throw new IndexOutOfRangeException($"No column for index {i}"); -// } - -// public int GetOrdinal(string name) -// { -// if (_context.SchemaTable.Columns.Count == 0) -// { -// throw new Exception("Schema table is empty"); -// } -// return _context.SchemaTable.Columns.IndexOf(name); -// } - -// public string GetString(int i) => _context.StringGetters[i](_source.Current); - -// public int GetValues(object[] values) => throw new NotImplementedException(); - -// public bool IsDBNull(int i) -// { -// // short circuit for non-nullable types -// if (!_context.NullableIndexes[i]) -// { -// return false; -// } - -// // otherwise find the first one -- starting w/ most occurring to least - -// var ig = _context.NullableIntGetters[i]; -// if (ig != null) -// { -// return ig(_source.Current) == null; -// } - -// var sg = _context.StringGetters[i]; -// if (sg != null) -// { -// return sg(_source.Current) == null; -// } - -// var bg = _context.NullableBoolGetters[i]; -// if (bg != null) -// { -// return bg(_source.Current) == null; -// } - -// var dg = _context.NullableDecimalGetters[i]; -// if (dg != null) -// { -// return dg(_source.Current) == null; -// } - -// return false; -// } - -// private void Dispose(bool disposing) -// { -// if (!_disposedValue) -// { -// if (disposing) -// { -// this.Close(); -// } - -// _disposedValue = true; -// } -// } - -// public void Dispose() -// { -// // Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method -// Dispose(disposing: true); -// GC.SuppressFinalize(this); -// } -// } -// } -//} From bd390a36f1eeca551275cdcc3fc1af7019fff7b8 Mon Sep 17 00:00:00 2001 From: Carl Meyertons Date: Tue, 27 Apr 2021 11:51:07 -0500 Subject: [PATCH 5/6] don't invoke GetType --- .../netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs index 7fd2dec220..d75c27e9b7 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlParameter.cs @@ -2099,7 +2099,9 @@ internal static bool CoerceValueIfNeeded(T value, MetaType destinationType, o objValue = null; coercedToDataFeed = false; var typeChanged = false; - Type currentType = value.GetType(); + Type currentType = typeof(T) == typeof(object) + ? value.GetType() + : typeof(T); if ( (destinationType.ClassType != typeof(object)) && From f32394a05a3a9369e5ede4f77b51c19f90b15daf Mon Sep 17 00:00:00 2001 From: Carl Meyertons Date: Wed, 28 Apr 2021 08:49:39 -0500 Subject: [PATCH 6/6] method consolidation --- .../Microsoft/Data/SqlClient/SqlBulkCopy.cs | 88 +++++++++---------- 1 file changed, 40 insertions(+), 48 deletions(-) diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 8a5dae6785..843967781a 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -9,6 +9,7 @@ using System.Data.Common; using System.Data.SqlTypes; using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -1558,23 +1559,12 @@ private bool ValidateBulkCopyVariantIfNeeded(T value, out object variantValue } } - private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, bool isNull, bool isSqlType) + private bool ConvertValueIfNeeded(T value, _SqlMetaData metadata, ref bool isSqlType, out bool coercedToDataFeed, out object convertedValue) { - bool coercedToDataFeed = false; - - if (isNull) - { - if (!metadata.IsNullable) - { - throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column); - } - - return DoWriteValueAsync(value, col, isSqlType, coercedToDataFeed, isNull, metadata); - } - MetaType type = metadata.metaType; bool typeChanged = false; - object objValue = null; + coercedToDataFeed = false; + convertedValue = null; // If the column is encrypted then we are going to transparently encrypt this column // (based on connection string setting)- Use the metaType for the underlying @@ -1642,7 +1632,7 @@ private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, typeChanged = false; // Setting this to false as SqlParameter.CoerceValue will only set it to true when converting to a CLR type // returning here to avoid unnecessary decValue initialization for all types - return WriteConvertedValue(sqlValue, col, isSqlType, isNull, coercedToDataFeed, metadata); + return typeChanged; case TdsEnums.SQLINTN: case TdsEnums.SQLFLTN: @@ -1666,17 +1656,17 @@ private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, case TdsEnums.SQLDATETIME2: case TdsEnums.SQLDATETIMEOFFSET: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out convertedValue, out coercedToDataFeed); break; case TdsEnums.SQLNCHAR: case TdsEnums.SQLNVARCHAR: case TdsEnums.SQLNTEXT: mt = MetaType.GetMetaTypeFromSqlDbType(type.SqlDbType, false); - typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out objValue, out coercedToDataFeed, false); + typeChanged = SqlParameter.CoerceValueIfNeeded(value, mt, out convertedValue, out coercedToDataFeed, false); if (!coercedToDataFeed) { // We do not need to test for TextDataFeed as it is only assigned to (N)VARCHAR(MAX) string str = typeChanged - ? (string)objValue + ? (string)convertedValue : isSqlType ? value.GenericCast().Value : value.GenericCast() @@ -1700,7 +1690,7 @@ private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, } break; case TdsEnums.SQLVARIANT: - typeChanged = ValidateBulkCopyVariantIfNeeded(value, out objValue); + typeChanged = ValidateBulkCopyVariantIfNeeded(value, out convertedValue); break; case TdsEnums.SQLUDT: // UDTs are sent as varbinary so we need to get the raw bytes @@ -1711,7 +1701,7 @@ private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, // in byte[] form. if (!(value is byte[])) { - objValue = _connection.GetBytes(value); + convertedValue = _connection.GetBytes(value); typeChanged = true; } break; @@ -1720,7 +1710,7 @@ private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, Debug.Assert((value is XmlReader) || (value is SqlCachedBuffer) || (value is string) || (value is SqlString) || (value is XmlDataFeed), "Invalid value type of Xml datatype"); if (value is XmlReader xmlReader) { - objValue = new XmlDataFeed(xmlReader); + convertedValue = new XmlDataFeed(xmlReader); typeChanged = true; coercedToDataFeed = true; } @@ -1740,16 +1730,7 @@ private Task ConvertWriteValueAsync(T value, int col, _SqlMetaData metadata, throw SQL.BulkLoadCannotConvertValue(value.GetType(), type, metadata.ordinal, RowNumber, metadata.isEncrypted, metadata.column, value.ToString(), e); } - if (typeChanged) - { - // All type changes change to CLR types - isSqlType = false; - return WriteConvertedValue(objValue, col, isSqlType, isNull, coercedToDataFeed, metadata); - } - else - { - return WriteConvertedValue(value, col, isSqlType, isNull, coercedToDataFeed, metadata); - } + return typeChanged; } /// @@ -2276,33 +2257,44 @@ private bool FireRowsCopiedEvent(long rowsCopied) private Task WriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull) { _SqlMetaData metadata = _sortedColumnMappings[col]._metadata; + object convertedValue = null; + bool isTypeChanged = false; if (isDataFeed) { //nothing to convert, skip straight to write - return DoWriteValueAsync(value, col, isSqlType, isDataFeed, isNull, metadata); } - else + else if (isNull) { - return ConvertWriteValueAsync(value, col, metadata, isNull, isSqlType); - } - } - - private Task WriteConvertedValue(T value, int col, bool isSqlType, bool isNull, bool isDatafeed, _SqlMetaData metadata) - { - // If column encryption is requested via connection string option, perform encryption here - if (!isNull && // if value is not NULL - metadata.isEncrypted) - { // If we are transparently encrypting - Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); - var bytesValue = _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDatafeed, isSqlType); - isSqlType = false; // Its not a sql type anymore + if (!metadata.IsNullable) + { + throw SQL.BulkLoadBulkLoadNotAllowDBNull(metadata.column); + } - return DoWriteValueAsync(bytesValue, col, isSqlType, isDatafeed, isNull, metadata); + // don't need to convert nulls } else { - return DoWriteValueAsync(value, col, isSqlType, isDatafeed, isNull, metadata); + isTypeChanged = ConvertValueIfNeeded(value, metadata, ref isSqlType, out isDataFeed, out convertedValue); + + // If column encryption is requested via connection string option, perform encryption here + if (metadata.isEncrypted) // If we are transparently encrypting + { + Debug.Assert(_parser.ShouldEncryptValuesForBulkCopy()); + + convertedValue = isTypeChanged + ? _parser.EncryptColumnValue(convertedValue, metadata, metadata.column, _stateObj, isDataFeed, isSqlType) + : _parser.EncryptColumnValue(value, metadata, metadata.column, _stateObj, isDataFeed, isSqlType) + ; + + isTypeChanged = true; // we should use converted value from here on. + isSqlType = false; // Its not a sql type anymore + } } + + return isTypeChanged + ? DoWriteValueAsync(convertedValue, col, isSqlType, isDataFeed, isNull, metadata) + : DoWriteValueAsync(value, col, isSqlType, isDataFeed, isNull, metadata) + ; } private Task DoWriteValueAsync(T value, int col, bool isSqlType, bool isDataFeed, bool isNull, _SqlMetaData metadata)