Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved decimal scale conversion #470

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions BUILDGUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,9 @@ Tests can be built and run with custom Target Frameworks. See the below examples
Managed SNI can be enabled on Windows by enabling the below AppContext switch:

**"Microsoft.Data.SqlClient.UseManagedNetworkingOnWindows"**

## Set truncation on for scaled decimal parameters

Scaled decimal parameter truncation can be enabled by enabling the below AppContext switch:

**"Switch.Microsoft.Data.SqlClient.TruncateScaledDecimal"**
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ internal sealed partial class TdsParser
// Constants
private const int constBinBufferSize = 4096; // Size of the buffer used to read input parameter of type Stream
private const int constTextBufferSize = 4096; // Size of the buffer (in chars) user to read input parameter of type TextReader
private const string enableTruncateSwitch = "Switch.Microsoft.Data.SqlClient.TruncateScaledDecimal"; // for applications that need to maintain backwards compatibility with the previous behavior

// State variables
internal TdsParserState _state = TdsParserState.Closed; // status flag for connection
Expand Down Expand Up @@ -183,6 +184,16 @@ internal SqlInternalConnectionTds Connection
}
}

private static bool EnableTruncateSwitch
{
get
{
bool value;
value = AppContext.TryGetSwitch(enableTruncateSwitch, out value) ? value : false;
DavoudEshtehari marked this conversation as resolved.
Show resolved Hide resolved
return value;
}
}

internal SqlInternalTransaction CurrentTransaction
{
get
Expand Down Expand Up @@ -7003,7 +7014,8 @@ internal static SqlDecimal AdjustSqlDecimalScale(SqlDecimal d, int newScale)
{
if (d.Scale != newScale)
{
return SqlDecimal.AdjustScale(d, newScale - d.Scale, false /* Don't round, truncate. */);
bool round = !EnableTruncateSwitch;
return SqlDecimal.AdjustScale(d, newScale - d.Scale, round);
}

return d;
Expand All @@ -7015,9 +7027,9 @@ internal static decimal AdjustDecimalScale(decimal value, int newScale)

if (newScale != oldScale)
{
bool round = !EnableTruncateSwitch;
SqlDecimal num = new SqlDecimal(value);

num = SqlDecimal.AdjustScale(num, newScale - oldScale, false /* Don't round, truncate. */);
num = SqlDecimal.AdjustScale(num, newScale - oldScale, round);
return num.Value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ internal static void Assert(string message)
// Constants
const int constBinBufferSize = 4096; // Size of the buffer used to read input parameter of type Stream
const int constTextBufferSize = 4096; // Size of the buffer (in chars) user to read input parameter of type TextReader
private const string enableTruncateSwitch = "Switch.Microsoft.Data.SqlClient.TruncateScaledDecimal"; // for applications that need to maintain backwards compatibility with the previous behavior

// State variables
internal TdsParserState _state = TdsParserState.Closed; // status flag for connection
Expand Down Expand Up @@ -309,6 +310,16 @@ internal SqlInternalConnectionTds Connection
}
}

private static bool EnableTruncateSwitch
{
get
{
bool value;
value = AppContext.TryGetSwitch(enableTruncateSwitch, out value) ? value : false;
return value;
}
}

internal SqlInternalTransaction CurrentTransaction
{
get
Expand Down Expand Up @@ -2046,7 +2057,7 @@ internal bool RunReliably(RunBehavior runBehavior, SqlCommand cmdHandler, SqlDat
{
tdsReliabilitySection.Start();
#endif //DEBUG
return Run(runBehavior, cmdHandler, dataStream, bulkCopyHandler, stateObj);
return Run(runBehavior, cmdHandler, dataStream, bulkCopyHandler, stateObj);
#if DEBUG
}
finally
Expand Down Expand Up @@ -7726,7 +7737,8 @@ static internal SqlDecimal AdjustSqlDecimalScale(SqlDecimal d, int newScale)
{
if (d.Scale != newScale)
{
return SqlDecimal.AdjustScale(d, newScale - d.Scale, false /* Don't round, truncate. MDAC 69229 */);
bool round = !EnableTruncateSwitch;
return SqlDecimal.AdjustScale(d, newScale - d.Scale, round);
}

return d;
Expand All @@ -7738,9 +7750,10 @@ static internal decimal AdjustDecimalScale(decimal value, int newScale)

if (newScale != oldScale)
{
bool round = !EnableTruncateSwitch;
SqlDecimal num = new SqlDecimal(value);

num = SqlDecimal.AdjustScale(num, newScale - oldScale, false /* Don't round, truncate. MDAC 69229 */);
num = SqlDecimal.AdjustScale(num, newScale - oldScale, round);
return num.Value;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ public static class DataTestUtility

private static Dictionary<string, bool> AvailableDatabases;
private static TraceEventListener TraceListener;
public static IEnumerable<string> ConnectionStrings
{
get
{
if (!string.IsNullOrEmpty(TCPConnectionString))
{
yield return TCPConnectionString;
}
else if (!string.IsNullOrEmpty(NPConnectionString))
{
yield return NPConnectionString;
}
foreach (string connStrAE in AEConnStrings)
{
yield return connStrAE;
}
}
}

private class Config
{
Expand Down Expand Up @@ -309,6 +327,30 @@ public static string GetUniqueNameForSqlServer(string prefix)
return name;
}

public static void DropTable(SqlConnection sqlConnection, string tableName)
{
using (SqlCommand cmd = new SqlCommand(string.Format("IF (OBJECT_ID('{0}') IS NOT NULL) \n DROP TABLE {0}", tableName), sqlConnection))
{
cmd.ExecuteNonQuery();
}
}

public static void DropUserDefinedType(SqlConnection sqlConnection, string typeName)
{
using (SqlCommand cmd = new SqlCommand(string.Format("IF (TYPE_ID('{0}') IS NOT NULL) \n DROP TYPE {0}", typeName), sqlConnection))
{
cmd.ExecuteNonQuery();
}
}

public static void DropStoredProcedure(SqlConnection sqlConnection, string spName)
{
using (SqlCommand cmd = new SqlCommand(string.Format("IF (OBJECT_ID('{0}') IS NOT NULL) \n DROP PROCEDURE {0}", spName), sqlConnection))
{
cmd.ExecuteNonQuery();
}
}

public static bool IsLocalDBInstalled() => SupportsLocalDb;

public static bool IsIntegratedSecuritySetup() => SupportsIntegratedSecurity;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Data.SqlTypes;
using Xunit;
Expand Down Expand Up @@ -314,6 +315,202 @@ public static void TestParametersWithDatatablesTVPInsert()
}
}

#region Scaled Decimal Parameter & TVP Test
[Theory]
[ClassData(typeof(ConnectionStringsProvider))]
public static void TestScaledDecimalParameter_CommandInsert(string connectionString, bool truncateScaledDecimal)
{
string tableName = DataTestUtility.GetUniqueNameForSqlServer("TestDecimalParameterCMD");
using (SqlConnection connection = InitialDatabaseTable(connectionString, tableName))
{
try
{
using (SqlCommand cmd = connection.CreateCommand())
{
AppContext.SetSwitch(truncateDecimalSwitch, truncateScaledDecimal);
var p = new SqlParameter("@Value", null);
p.Precision = 18;
p.Scale = 2;
cmd.Parameters.Add(p);
for (int i = 0; i < _testValues.Length; i++)
{
p.Value = _testValues[i];
cmd.CommandText = $"INSERT INTO {tableName} (Id, [Value]) VALUES({i}, @Value)";
cmd.ExecuteNonQuery();
}
}
Assert.True(ValidateInsertedValues(connection, tableName, truncateScaledDecimal), $"Invalid test happened with connection string [{connection.ConnectionString}]");
}
finally
{
DataTestUtility.DropTable(connection, tableName);
}
}
}

[Theory]
[ClassData(typeof(ConnectionStringsProvider))]
public static void TestScaledDecimalParameter_BulkCopy(string connectionString, bool truncateScaledDecimal)
{
string tableName = DataTestUtility.GetUniqueNameForSqlServer("TestDecimalParameterBC");
using (SqlConnection connection = InitialDatabaseTable(connectionString, tableName))
{
try
{
using (SqlBulkCopy bulkCopy = new SqlBulkCopy(connection))
{
DataTable table = new DataTable(tableName);
table.Columns.Add("Id", typeof(int));
table.Columns.Add("Value", typeof(decimal));
for (int i = 0; i < _testValues.Length; i++)
{
var newRow = table.NewRow();
newRow["Id"] = i;
newRow["Value"] = _testValues[i];
table.Rows.Add(newRow);
}

bulkCopy.DestinationTableName = tableName;
AppContext.SetSwitch(truncateDecimalSwitch, truncateScaledDecimal);
bulkCopy.WriteToServer(table);
}
Assert.True(ValidateInsertedValues(connection, tableName, truncateScaledDecimal), $"Invalid test happened with connection string [{connection.ConnectionString}]");
}
finally
{
DataTestUtility.DropTable(connection, tableName);
}
}
}

[Theory]
[ClassData(typeof(ConnectionStringsProvider))]
public static void TestScaledDecimalTVP_CommandSP(string connectionString, bool truncateScaledDecimal)
{
string tableName = DataTestUtility.GetUniqueNameForSqlServer("TestDecimalParameterBC");
string tableTypeName = DataTestUtility.GetUniqueNameForSqlServer("UDTTTestDecimalParameterBC");
string spName = DataTestUtility.GetUniqueNameForSqlServer("spTestDecimalParameterBC");
using (SqlConnection connection = InitialDatabaseUDTT(connectionString, tableName, tableTypeName, spName))
{
try
{
using (SqlCommand cmd = connection.CreateCommand())
{
var p = new SqlParameter("@tvp", SqlDbType.Structured);
p.TypeName = $"dbo.{tableTypeName}";
cmd.CommandText = spName;
cmd.CommandType = CommandType.StoredProcedure;
cmd.Parameters.Add(p);

DataTable table = new DataTable(tableName);
table.Columns.Add("Id", typeof(int));
table.Columns.Add("Value", typeof(decimal));
for (int i = 0; i < _testValues.Length; i++)
{
var newRow = table.NewRow();
newRow["Id"] = i;
newRow["Value"] = _testValues[i];
table.Rows.Add(newRow);
}
p.Value = table;
AppContext.SetSwitch(truncateDecimalSwitch, truncateScaledDecimal);
cmd.ExecuteNonQuery();
}
// TVP always rounds data without attention to the configuration.
Assert.True(ValidateInsertedValues(connection, tableName, false && truncateScaledDecimal), $"Invalid test happened with connection string [{connection.ConnectionString}]");
}
finally
{
DataTestUtility.DropTable(connection, tableName);
DataTestUtility.DropStoredProcedure(connection, spName);
DataTestUtility.DropUserDefinedType(connection, tableTypeName);
}
}
}

#region Decimal parameter test setup
private static readonly decimal[] _testValues = new[] { 4210862852.8600000000_0000000000m, 19.1560m, 19.1550m, 19.1549m };
private static readonly decimal[] _expectedRoundedValues = new[] { 4210862852.86m, 19.16m, 19.16m, 19.15m };
private static readonly decimal[] _expectedTruncatedValues = new[] { 4210862852.86m, 19.15m, 19.15m, 19.15m };
private const string truncateDecimalSwitch = "Switch.Microsoft.Data.SqlClient.TruncateScaledDecimal";

private static SqlConnection InitialDatabaseUDTT(string cnnString, string tableName, string tableTypeName, string spName)
{
SqlConnection connection = new SqlConnection(cnnString);
connection.Open();
using (SqlCommand cmd = connection.CreateCommand())
{
cmd.CommandType = CommandType.Text;
cmd.CommandText = $"CREATE TABLE {tableName} (Id INT, Value Decimal(38, 2)) \n";
cmd.CommandText += $"CREATE TYPE {tableTypeName} AS TABLE (Id INT, Value Decimal(38, 2)) ";
cmd.ExecuteNonQuery();
cmd.CommandText = $"CREATE PROCEDURE {spName} (@tvp {tableTypeName} READONLY) AS \n INSERT INTO {tableName} (Id, Value) SELECT * FROM @tvp ORDER BY Id";
cmd.ExecuteNonQuery();
}
return connection;
}

private static SqlConnection InitialDatabaseTable(string cnnString, string tableName)
{
SqlConnection connection = new SqlConnection(cnnString);
connection.Open();
using (SqlCommand cmd = connection.CreateCommand())
{
cmd.CommandType = CommandType.Text;
cmd.CommandText = $"CREATE TABLE {tableName} (Id INT, Value Decimal(38, 2))";
cmd.ExecuteNonQuery();
}
return connection;
}

private static bool ValidateInsertedValues(SqlConnection connection, string tableName, bool truncateScaledDecimal)
{
bool exceptionHit;
decimal[] expectedValues = truncateScaledDecimal ? _expectedTruncatedValues : _expectedRoundedValues;

try
{
using (SqlCommand cmd = connection.CreateCommand())
{
// Verify if the data was as same as our expectation.
cmd.CommandText = $"SELECT [Value] FROM {tableName} ORDER BY Id ASC";
cmd.CommandType = CommandType.Text;
using (SqlDataReader reader = cmd.ExecuteReader())
{
DataTable dbData = new DataTable();
dbData.Load(reader);
Assert.Equal(expectedValues.Length, dbData.Rows.Count);
for (int i = 0; i < expectedValues.Length; i++)
{
Assert.Equal(expectedValues[i], dbData.Rows[i][0]);
}
}
}
exceptionHit = false;
}
catch
{
exceptionHit = true;
}
return !exceptionHit;
}

public class ConnectionStringsProvider : IEnumerable<object[]>
cheenamalhotra marked this conversation as resolved.
Show resolved Hide resolved
{
public IEnumerator<object[]> GetEnumerator()
{
foreach (var cnnString in DataTestUtility.ConnectionStrings)
{
yield return new object[] { cnnString, false };
yield return new object[] { cnnString, true };
}
}

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
#endregion
#endregion

private enum MyEnum
{
A = 1,
Expand Down
Loading