diff --git a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 273394395a..92c6d0075c 100644 --- a/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -1640,14 +1640,12 @@ public void WriteToServer(DbDataReader reader) try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowSource = reader; _dbDataReaderRowSource = reader; _sqlDataReaderRowSource = reader as SqlDataReader; - - _dataTableSource = null; _rowSourceType = ValueSourceType.DbDataReader; - _isAsyncBulkCopy = false; WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally @@ -1673,12 +1671,11 @@ public void WriteToServer(IDataReader reader) try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowSource = reader; _sqlDataReaderRowSource = _rowSource as SqlDataReader; _dbDataReaderRowSource = _rowSource as DbDataReader; - _dataTableSource = null; _rowSourceType = ValueSourceType.IDataReader; - _isAsyncBulkCopy = false; WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally @@ -1707,13 +1704,12 @@ public void WriteToServer(DataTable table, DataRowState rowState) try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowStateToSkip = ((rowState == 0) || (rowState == DataRowState.Deleted)) ? DataRowState.Deleted : ~rowState | DataRowState.Deleted; _rowSource = table; _dataTableSource = table; - _sqlDataReaderRowSource = null; _rowSourceType = ValueSourceType.DataTable; _rowEnumerator = table.Rows.GetEnumerator(); - _isAsyncBulkCopy = false; WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } @@ -1746,16 +1742,14 @@ public void WriteToServer(DataRow[] rows) try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); DataTable table = rows[0].Table; Debug.Assert(null != table, "How can we have rows without a table?"); _rowStateToSkip = DataRowState.Deleted; // Don't allow deleted rows _rowSource = rows; _dataTableSource = table; - _sqlDataReaderRowSource = null; _rowSourceType = ValueSourceType.RowArray; _rowEnumerator = rows.GetEnumerator(); - _isAsyncBulkCopy = false; WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } @@ -1787,7 +1781,7 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); if (rows.Length == 0) { return cancellationToken.IsCancellationRequested ? @@ -1800,7 +1794,6 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok _rowStateToSkip = DataRowState.Deleted; // Don't allow deleted rows _rowSource = rows; _dataTableSource = table; - _sqlDataReaderRowSource = null; _rowSourceType = ValueSourceType.RowArray; _rowEnumerator = rows.GetEnumerator(); _isAsyncBulkCopy = true; @@ -1834,10 +1827,10 @@ public Task WriteToServerAsync(DbDataReader reader, CancellationToken cancellati try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowSource = reader; _sqlDataReaderRowSource = reader as SqlDataReader; _dbDataReaderRowSource = reader; - _dataTableSource = null; _rowSourceType = ValueSourceType.DbDataReader; _isAsyncBulkCopy = true; resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); // It returns Task since _isAsyncBulkCopy = true; @@ -1871,10 +1864,10 @@ public Task WriteToServerAsync(IDataReader reader, CancellationToken cancellatio try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowSource = reader; _sqlDataReaderRowSource = _rowSource as SqlDataReader; _dbDataReaderRowSource = _rowSource as DbDataReader; - _dataTableSource = null; _rowSourceType = ValueSourceType.IDataReader; _isAsyncBulkCopy = true; resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); // It returns Task since _isAsyncBulkCopy = true; @@ -1914,9 +1907,9 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowStateToSkip = ((rowState == 0) || (rowState == DataRowState.Deleted)) ? DataRowState.Deleted : ~rowState | DataRowState.Deleted; _rowSource = table; - _sqlDataReaderRowSource = null; _dataTableSource = table; _rowSourceType = ValueSourceType.DataTable; _rowEnumerator = table.Rows.GetEnumerator(); @@ -3093,5 +3086,17 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) } return resultTask; } + + private void ResetWriteToServerGlobalVariables() + { + _dataTableSource = null; + _dbDataReaderRowSource = null; + _isAsyncBulkCopy = false; + _rowEnumerator = null; + _rowSource = null; + _rowSourceType = ValueSourceType.Unspecified; + _sqlDataReaderRowSource = null; + _sqlDataReaderRowSource = null; + } } } diff --git a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs index 6b65491499..54b03683b7 100644 --- a/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs +++ b/src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlBulkCopy.cs @@ -1689,6 +1689,7 @@ public void WriteToServer(DbDataReader reader) try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowSource = reader; _dbDataReaderRowSource = reader; _sqlDataReaderRowSource = reader as SqlDataReader; @@ -1697,10 +1698,8 @@ public void WriteToServer(DbDataReader reader) { _rowSourceIsSqlDataReaderSmi = _sqlDataReaderRowSource is SqlDataReaderSmi; } - _dataTableSource = null; _rowSourceType = ValueSourceType.DbDataReader; - _isAsyncBulkCopy = false; WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally @@ -1728,6 +1727,7 @@ public void WriteToServer(IDataReader reader) try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowSource = reader; _sqlDataReaderRowSource = _rowSource as SqlDataReader; if (_sqlDataReaderRowSource != null) @@ -1735,9 +1735,7 @@ public void WriteToServer(IDataReader reader) _rowSourceIsSqlDataReaderSmi = _sqlDataReaderRowSource is SqlDataReaderSmi; } _dbDataReaderRowSource = _rowSource as DbDataReader; - _dataTableSource = null; _rowSourceType = ValueSourceType.IDataReader; - _isAsyncBulkCopy = false; WriteRowSourceToServerAsync(reader.FieldCount, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } finally @@ -1768,13 +1766,12 @@ public void WriteToServer(DataTable table, DataRowState rowState) try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowStateToSkip = ((rowState == 0) || (rowState == DataRowState.Deleted)) ? DataRowState.Deleted : ~rowState | DataRowState.Deleted; _rowSource = table; _dataTableSource = table; - _sqlDataReaderRowSource = null; _rowSourceType = ValueSourceType.DataTable; _rowEnumerator = table.Rows.GetEnumerator(); - _isAsyncBulkCopy = false; WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } @@ -1809,16 +1806,14 @@ public void WriteToServer(DataRow[] rows) try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); DataTable table = rows[0].Table; Debug.Assert(null != table, "How can we have rows without a table?"); _rowStateToSkip = DataRowState.Deleted; // Don't allow deleted rows _rowSource = rows; _dataTableSource = table; - _sqlDataReaderRowSource = null; _rowSourceType = ValueSourceType.RowArray; _rowEnumerator = rows.GetEnumerator(); - _isAsyncBulkCopy = false; WriteRowSourceToServerAsync(table.Columns.Count, CancellationToken.None); //It returns null since _isAsyncBulkCopy = false; } @@ -1851,7 +1846,7 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok try { statistics = SqlStatistics.StartTimer(Statistics); - + ResetWriteToServerGlobalVariables(); if (rows.Length == 0) { TaskCompletionSource source = new TaskCompletionSource(); @@ -1872,7 +1867,6 @@ public Task WriteToServerAsync(DataRow[] rows, CancellationToken cancellationTok _rowStateToSkip = DataRowState.Deleted; // Don't allow deleted rows _rowSource = rows; _dataTableSource = table; - _sqlDataReaderRowSource = null; _rowSourceType = ValueSourceType.RowArray; _rowEnumerator = rows.GetEnumerator(); _isAsyncBulkCopy = true; @@ -1908,10 +1902,10 @@ public Task WriteToServerAsync(DbDataReader reader, CancellationToken cancellati try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowSource = reader; _sqlDataReaderRowSource = reader as SqlDataReader; _dbDataReaderRowSource = reader; - _dataTableSource = null; _rowSourceType = ValueSourceType.DbDataReader; _isAsyncBulkCopy = true; resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); // It returns Task since _isAsyncBulkCopy = true; @@ -1946,10 +1940,10 @@ public Task WriteToServerAsync(IDataReader reader, CancellationToken cancellatio try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowSource = reader; _sqlDataReaderRowSource = _rowSource as SqlDataReader; _dbDataReaderRowSource = _rowSource as DbDataReader; - _dataTableSource = null; _rowSourceType = ValueSourceType.IDataReader; _isAsyncBulkCopy = true; resultTask = WriteRowSourceToServerAsync(reader.FieldCount, cancellationToken); // It returns Task since _isAsyncBulkCopy = true; @@ -1990,9 +1984,9 @@ public Task WriteToServerAsync(DataTable table, DataRowState rowState, Cancellat try { statistics = SqlStatistics.StartTimer(Statistics); + ResetWriteToServerGlobalVariables(); _rowStateToSkip = ((rowState == 0) || (rowState == DataRowState.Deleted)) ? DataRowState.Deleted : ~rowState | DataRowState.Deleted; _rowSource = table; - _sqlDataReaderRowSource = null; _dataTableSource = table; _rowSourceType = ValueSourceType.DataTable; _rowEnumerator = table.Rows.GetEnumerator(); @@ -3212,5 +3206,17 @@ private Task WriteToServerInternalAsync(CancellationToken ctoken) } return resultTask; } + + private void ResetWriteToServerGlobalVariables() + { + _dataTableSource = null; + _dbDataReaderRowSource = null; + _isAsyncBulkCopy = false; + _rowEnumerator = null; + _rowSource = null; + _rowSourceType = ValueSourceType.Unspecified; + _sqlDataReaderRowSource = null; + _sqlDataReaderRowSource = null; + } } } diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj index da3f55dfc1..728d379f6a 100644 --- a/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/Microsoft.Data.SqlClient.ManualTesting.Tests.csproj @@ -163,6 +163,7 @@ + diff --git a/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/WriteToServerTest.cs b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/WriteToServerTest.cs new file mode 100644 index 0000000000..343a7bcfe2 --- /dev/null +++ b/src/Microsoft.Data.SqlClient/tests/ManualTests/SQL/SqlBulkCopyTest/WriteToServerTest.cs @@ -0,0 +1,138 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Data; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Data.SqlClient.ManualTesting.Tests +{ + public class WriteToServerTest + { + private readonly string _connectionString = null; + private readonly string _tableName1 = DataTestUtility.GetUniqueName("Bulk1"); + private readonly string _tableName2 = DataTestUtility.GetUniqueName("Bulk2"); + + public WriteToServerTest() + { + _connectionString = (new SqlConnectionStringBuilder(DataTestUtility.TCPConnectionString) { MultipleActiveResultSets = true }).ConnectionString; + } + + [ConditionalFact(typeof(DataTestUtility), nameof(DataTestUtility.AreConnStringsSetup), nameof(DataTestUtility.IsNotAzureServer), nameof(DataTestUtility.IsNotAzureSynapse))] + public async Task WriteToServerWithDbReaderFollowedByWriteToServerWithDataRowsShouldSucceed() + { + try + { + SetupTestTables(); + + DataRow[] dataRows = WriteToServerTest.CreateDataRows(); + Assert.Equal(4, dataRows.Length); // Verify the number of rows created + + DoBulkCopy(dataRows); + await DoBulkCopyAsync(dataRows); + } + finally + { + RemoveTestTables(); + } + } + + private void SetupTestTables() + { + // Create the source table and insert some data + using SqlConnection connection = new SqlConnection(DataTestUtility.TCPConnectionString); + connection.Open(); + + DataTestUtility.DropTable(connection, _tableName1); + DataTestUtility.DropTable(connection, _tableName2); + + using SqlCommand command = connection.CreateCommand(); + + Helpers.TryExecute(command, $"create table {_tableName1} (Id int identity primary key, FirstName nvarchar(50), LastName nvarchar(50))"); + Helpers.TryExecute(command, $"create table {_tableName2} (Id int identity primary key, FirstName nvarchar(50), LastName nvarchar(50))"); + + Helpers.TryExecute(command, $"insert into {_tableName1} (Firstname, LastName) values ('John', 'Doe')"); + Helpers.TryExecute(command, $"insert into {_tableName1} (Firstname, LastName) values ('Johnny', 'Smith')"); + Helpers.TryExecute(command, $"insert into {_tableName1} (Firstname, LastName) values ('Jenny', 'Doe')"); + Helpers.TryExecute(command, $"insert into {_tableName1} (Firstname, LastName) values ('Jane', 'Smith')"); + } + + private static DataRow[] CreateDataRows() + { + DataTable table = new DataTable(); + table.Columns.Add("Id", typeof(int)); + table.Columns.Add("FirstName", typeof(string)); + table.Columns.Add("LastName", typeof(string)); + + table.Rows.Add(null, "Aaron", "Washington"); + table.Rows.Add(null, "Barry", "Mannilow"); + table.Rows.Add(null, "Charles", "Babage"); + table.Rows.Add(null, "Dean", "Snipes"); + + return table.Select(); + } + + private void RemoveTestTables() + { + // Simplify the using statement in a small block of code + using SqlConnection connection = new SqlConnection(DataTestUtility.TCPConnectionString); + connection.Open(); + + DataTestUtility.DropTable(connection, _tableName1); + DataTestUtility.DropTable(connection, _tableName2); + } + + private void DoBulkCopy(DataRow[] dataRows) + { + using SqlConnection connection = new SqlConnection(_connectionString); + connection.Open(); + + using SqlCommand command = connection.CreateCommand(); + command.CommandText = $"select * from {_tableName1}"; + + using IDataReader reader = command.ExecuteReader(); + + using SqlBulkCopy bulkCopy = new SqlBulkCopy(connection); + + bulkCopy.DestinationTableName = _tableName2; + + BulkCopy(bulkCopy, reader, dataRows); + } + + private async Task DoBulkCopyAsync(DataRow[] dataRows) + { + // Test should be run with MARS enabled + using SqlConnection connection = new SqlConnection(_connectionString); + await connection.OpenAsync(); + + using SqlCommand command = connection.CreateCommand(); + command.CommandText = $"select * from {_tableName1}"; + + using IDataReader reader = await command.ExecuteReaderAsync(); + + using SqlBulkCopy bulkCopy = new SqlBulkCopy(connection); + + bulkCopy.DestinationTableName = _tableName2; + + await BulkCopyAsync(bulkCopy, reader, dataRows); + } + + private static void BulkCopy(SqlBulkCopy bulkCopy, IDataReader reader, DataRow[] dataRows) + { + bulkCopy.WriteToServer(reader); + Assert.Equal(dataRows.Length, bulkCopy.RowsCopied); // Verify the number of rows copied from the reader + bulkCopy.WriteToServer(dataRows); + Assert.Equal(dataRows.Length, bulkCopy.RowsCopied); // Verify the number of rows copied from the reader + } + + private static async Task BulkCopyAsync(SqlBulkCopy bulkCopy, IDataReader reader, DataRow[] dataRows) + { + await bulkCopy.WriteToServerAsync(reader); + Assert.Equal(dataRows.Length, bulkCopy.RowsCopied); // Verify the number of rows copied from the reader + await bulkCopy.WriteToServerAsync(dataRows); + Assert.Equal(dataRows.Length, bulkCopy.RowsCopied); // Verify the number of rows copied from the reader + } + } +}