From 42a32b2956ef49f95af47f3f821774f1a11a8a4f Mon Sep 17 00:00:00 2001 From: Curt Hagenlocher Date: Fri, 27 Oct 2023 15:13:33 -0700 Subject: [PATCH] GH-38351: [C#] Add SqlDecimal support to Decimal128Array (#38481) ### What changes are included in this PR? Adds support for reading and writing System.Data.SqlTypes.SqlDecimal against Decimal128Array. ### Are these changes tested? Yes. ### Are there any user-facing changes? Adds functions to the API. * Closes: #38351 Authored-by: Curt Hagenlocher Signed-off-by: Curt Hagenlocher --- .../Apache.Arrow/Arrays/Decimal128Array.cs | 40 +++++ csharp/src/Apache.Arrow/DecimalUtility.cs | 49 ++++++ .../Decimal128ArrayTests.cs | 139 +++++++++++++++++- .../Apache.Arrow.Tests/DecimalUtilityTests.cs | 61 +++++++- 4 files changed, 278 insertions(+), 11 deletions(-) diff --git a/csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs b/csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs index 128e9e5f0818e..7b147f5124c1d 100644 --- a/csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs +++ b/csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs @@ -15,6 +15,9 @@ using System; using System.Collections.Generic; +#if !NETSTANDARD1_3 +using System.Data.SqlTypes; +#endif using System.Diagnostics; using System.Numerics; using Apache.Arrow.Arrays; @@ -61,6 +64,31 @@ public Builder AppendRange(IEnumerable values) return Instance; } +#if !NETSTANDARD1_3 + public Builder Append(SqlDecimal value) + { + Span bytes = stackalloc byte[DataType.ByteWidth]; + DecimalUtility.GetBytes(value, DataType.Precision, DataType.Scale, bytes); + + return Append(bytes); + } + + public Builder AppendRange(IEnumerable values) + { + if (values == null) + { + throw new ArgumentNullException(nameof(values)); + } + + foreach (SqlDecimal d in values) + { + Append(d); + } + + return Instance; + } +#endif + public Builder Set(int index, decimal value) { Span bytes = stackalloc byte[DataType.ByteWidth]; @@ -91,5 +119,17 @@ public Decimal128Array(ArrayData data) } return DecimalUtility.GetDecimal(ValueBuffer, index, Scale, ByteWidth); } + +#if !NETSTANDARD1_3 + public SqlDecimal? GetSqlDecimal(int index) + { + if (IsNull(index)) + { + return null; + } + + return DecimalUtility.GetSqlDecimal128(ValueBuffer, index, Precision, Scale); + } +#endif } } diff --git a/csharp/src/Apache.Arrow/DecimalUtility.cs b/csharp/src/Apache.Arrow/DecimalUtility.cs index 4a29d068c6eff..35e56ff65e2ed 100644 --- a/csharp/src/Apache.Arrow/DecimalUtility.cs +++ b/csharp/src/Apache.Arrow/DecimalUtility.cs @@ -14,6 +14,9 @@ // limitations under the License. using System; +#if !NETSTANDARD1_3 +using System.Data.SqlTypes; +#endif using System.Numerics; namespace Apache.Arrow @@ -73,6 +76,32 @@ internal static decimal GetDecimal(in ArrowBuffer valueBuffer, int index, int sc } } +#if !NETSTANDARD1_3 + internal static SqlDecimal GetSqlDecimal128(in ArrowBuffer valueBuffer, int index, int precision, int scale) + { + const int byteWidth = 16; + const int intWidth = byteWidth / 4; + const int longWidth = byteWidth / 8; + + byte mostSignificantByte = valueBuffer.Span[(index + 1) * byteWidth - 1]; + bool isPositive = (mostSignificantByte & 0x80) == 0; + + if (isPositive) + { + ReadOnlySpan value = valueBuffer.Span.CastTo().Slice(index * intWidth, intWidth); + return new SqlDecimal((byte)precision, (byte)scale, true, value[0], value[1], value[2], value[3]); + } + else + { + ReadOnlySpan value = valueBuffer.Span.CastTo().Slice(index * longWidth, longWidth); + long data1 = -value[0]; + long data2 = (data1 == 0) ? -value[1] : ~value[1]; + + return new SqlDecimal((byte)precision, (byte)scale, false, (int)(data1 & 0xffffffff), (int)(data1 >> 32), (int)(data2 & 0xffffffff), (int)(data2 >> 32)); + } + } +#endif + private static decimal DivideByScale(BigInteger integerValue, int scale) { decimal result = (decimal)integerValue; // this cast is safe here @@ -169,5 +198,25 @@ internal static void GetBytes(decimal value, int precision, int scale, int byteW } } } + +#if !NETSTANDARD1_3 + internal static void GetBytes(SqlDecimal value, int precision, int scale, Span bytes) + { + if (value.Precision != precision || value.Scale != scale) + { + value = SqlDecimal.ConvertToPrecScale(value, precision, scale); + } + + // TODO: Consider groveling in the internals to avoid the probable allocation + Span span = bytes.CastTo(); + value.Data.AsSpan().CopyTo(span); + if (!value.IsPositive) + { + Span longSpan = bytes.CastTo(); + longSpan[0] = -longSpan[0]; + longSpan[1] = (longSpan[0] == 0) ? -longSpan[1] : ~longSpan[1]; + } + } +#endif } } diff --git a/csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs b/csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs index 4c4e6537269a4..8d7adfef42b54 100644 --- a/csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs +++ b/csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs @@ -14,7 +14,9 @@ // limitations under the License. using System; -using System.Collections.Generic; +#if !NETSTANDARD1_3 +using System.Data.SqlTypes; +#endif using Apache.Arrow.Types; using Xunit; @@ -22,6 +24,18 @@ namespace Apache.Arrow.Tests { public class Decimal128ArrayTests { +#if !NETSTANDARD1_3 + static SqlDecimal? Convert(decimal? value) + { + return value == null ? null : new SqlDecimal(value.Value); + } + + static decimal? Convert(SqlDecimal? value) + { + return value == null ? null : value.Value.Value; + } +#endif + public class Builder { public class AppendNull @@ -30,7 +44,7 @@ public class AppendNull public void AppendThenGetGivesNull() { // Arrange - var builder = new Decimal128Array.Builder(new Decimal128Type(8,2)); + var builder = new Decimal128Array.Builder(new Decimal128Type(8, 2)); // Act @@ -45,6 +59,12 @@ public void AppendThenGetGivesNull() Assert.Null(array.GetValue(0)); Assert.Null(array.GetValue(1)); Assert.Null(array.GetValue(2)); + +#if !NETSTANDARD1_3 + Assert.Null(array.GetSqlDecimal(0)); + Assert.Null(array.GetSqlDecimal(1)); + Assert.Null(array.GetSqlDecimal(2)); +#endif } } @@ -67,7 +87,7 @@ public void AppendDecimal(int count) testData[i] = null; continue; } - decimal rnd = i * (decimal)Math.Round(new Random().NextDouble(),10); + decimal rnd = i * (decimal)Math.Round(new Random().NextDouble(), 10); testData[i] = rnd; builder.Append(rnd); } @@ -78,6 +98,9 @@ public void AppendDecimal(int count) for (int i = 0; i < count; i++) { Assert.Equal(testData[i], array.GetValue(i)); +#if !NETSTANDARD1_3 + Assert.Equal(Convert(testData[i]), array.GetSqlDecimal(i)); +#endif } } @@ -95,6 +118,11 @@ public void AppendLargeDecimal() var array = builder.Build(); Assert.Equal(large, array.GetValue(0)); Assert.Equal(-large, array.GetValue(1)); + +#if !NETSTANDARD1_3 + Assert.Equal(Convert(large), array.GetSqlDecimal(0)); + Assert.Equal(Convert(-large), array.GetSqlDecimal(1)); +#endif } [Fact] @@ -115,6 +143,13 @@ public void AppendMaxAndMinDecimal() Assert.Equal(Decimal.MinValue, array.GetValue(1)); Assert.Equal(Decimal.MaxValue - 10, array.GetValue(2)); Assert.Equal(Decimal.MinValue + 10, array.GetValue(3)); + +#if !NETSTANDARD1_3 + Assert.Equal(Convert(Decimal.MaxValue), array.GetSqlDecimal(0)); + Assert.Equal(Convert(Decimal.MinValue), array.GetSqlDecimal(1)); + Assert.Equal(Convert(Decimal.MaxValue) - 10, array.GetSqlDecimal(2)); + Assert.Equal(Convert(Decimal.MinValue) + 10, array.GetSqlDecimal(3)); +#endif } [Fact] @@ -131,6 +166,11 @@ public void AppendFractionalDecimal() var array = builder.Build(); Assert.Equal(fraction, array.GetValue(0)); Assert.Equal(-fraction, array.GetValue(1)); + +#if !NETSTANDARD1_3 + Assert.Equal(Convert(fraction), array.GetSqlDecimal(0)); + Assert.Equal(Convert(-fraction), array.GetSqlDecimal(1)); +#endif } [Fact] @@ -138,7 +178,7 @@ public void AppendRangeDecimal() { // Arrange var builder = new Decimal128Array.Builder(new Decimal128Type(24, 8)); - var range = new decimal[] {2.123M, 1.5984M, -0.0000001M, 9878987987987987.1235407M}; + var range = new decimal[] { 2.123M, 1.5984M, -0.0000001M, 9878987987987987.1235407M }; // Act builder.AppendRange(range); @@ -146,12 +186,15 @@ public void AppendRangeDecimal() // Assert var array = builder.Build(); - for(int i = 0; i < range.Length; i ++) + for (int i = 0; i < range.Length; i++) { Assert.Equal(range[i], array.GetValue(i)); +#if !NETSTANDARD1_3 + Assert.Equal(Convert(range[i]), array.GetSqlDecimal(i)); +#endif } - - Assert.Null( array.GetValue(range.Length)); + + Assert.Null(array.GetValue(range.Length)); } [Fact] @@ -159,7 +202,7 @@ public void AppendClearAppendDecimal() { // Arrange var builder = new Decimal128Array.Builder(new Decimal128Type(24, 8)); - + // Act builder.Append(1); builder.Clear(); @@ -256,6 +299,86 @@ public void SwapNull() Assert.Equal(123.456M, array.GetValue(1)); } } + +#if !NETSTANDARD1_3 + public class SqlDecimals + { + [Theory] + [InlineData(200)] + public void AppendSqlDecimal(int count) + { + // Arrange + const int precision = 10; + var builder = new Decimal128Array.Builder(new Decimal128Type(14, precision)); + + // Act + SqlDecimal?[] testData = new SqlDecimal?[count]; + for (int i = 0; i < count; i++) + { + if (i == count - 2) + { + builder.AppendNull(); + testData[i] = null; + continue; + } + SqlDecimal rnd = i * (SqlDecimal)Math.Round(new Random().NextDouble(), 10); + builder.Append(rnd); + testData[i] = SqlDecimal.Round(rnd, precision); + } + + // Assert + var array = builder.Build(); + Assert.Equal(count, array.Length); + for (int i = 0; i < count; i++) + { + Assert.Equal(testData[i], array.GetSqlDecimal(i)); + Assert.Equal(Convert(testData[i]), array.GetValue(i)); + } + } + + [Fact] + public void AppendMaxAndMinSqlDecimal() + { + // Arrange + var builder = new Decimal128Array.Builder(new Decimal128Type(38, 0)); + + // Act + builder.Append(SqlDecimal.MaxValue); + builder.Append(SqlDecimal.MinValue); + builder.Append(SqlDecimal.MaxValue - 10); + builder.Append(SqlDecimal.MinValue + 10); + + // Assert + var array = builder.Build(); + Assert.Equal(SqlDecimal.MaxValue, array.GetSqlDecimal(0)); + Assert.Equal(SqlDecimal.MinValue, array.GetSqlDecimal(1)); + Assert.Equal(SqlDecimal.MaxValue - 10, array.GetSqlDecimal(2)); + Assert.Equal(SqlDecimal.MinValue + 10, array.GetSqlDecimal(3)); + } + + [Fact] + public void AppendRangeSqlDecimal() + { + // Arrange + var builder = new Decimal128Array.Builder(new Decimal128Type(24, 8)); + var range = new SqlDecimal[] { 2.123M, 1.5984M, -0.0000001M, 9878987987987987.1235407M }; + + // Act + builder.AppendRange(range); + builder.AppendNull(); + + // Assert + var array = builder.Build(); + for (int i = 0; i < range.Length; i++) + { + Assert.Equal(range[i], array.GetSqlDecimal(i)); + Assert.Equal(Convert(range[i]), array.GetValue(i)); + } + + Assert.Null(array.GetValue(range.Length)); + } + } +#endif } } } diff --git a/csharp/test/Apache.Arrow.Tests/DecimalUtilityTests.cs b/csharp/test/Apache.Arrow.Tests/DecimalUtilityTests.cs index 9c7e5b587cb9d..dd5f7b9d3f67f 100644 --- a/csharp/test/Apache.Arrow.Tests/DecimalUtilityTests.cs +++ b/csharp/test/Apache.Arrow.Tests/DecimalUtilityTests.cs @@ -14,6 +14,9 @@ // limitations under the License. using System; +#if !NETSTANDARD1_3 +using System.Data.SqlTypes; +#endif using Apache.Arrow.Types; using Xunit; @@ -31,13 +34,13 @@ public class Overflow [InlineData(100.123, 5, 2, true)] [InlineData(100.123, 5, 3, true)] [InlineData(100.123, 6, 3, false)] - public void HasExpectedResultOrThrows(decimal d, int precision , int scale, bool shouldThrow) + public void HasExpectedResultOrThrows(decimal d, int precision, int scale, bool shouldThrow) { var builder = new Decimal128Array.Builder(new Decimal128Type(precision, scale)); if (shouldThrow) { - Assert.Throws(() => builder.Append(d)); + Assert.Throws(() => builder.Append(d)); } else { @@ -55,7 +58,7 @@ public void Decimal256HasExpectedResultOrThrows(decimal d, int precision, int sc var builder = new Decimal256Array.Builder(new Decimal256Type(precision, scale)); builder.Append(d); Decimal256Array result = builder.Build(new TestMemoryAllocator()); ; - + if (shouldThrow) { Assert.Throws(() => result.GetValue(0)); @@ -66,5 +69,57 @@ public void Decimal256HasExpectedResultOrThrows(decimal d, int precision, int sc } } } + + public class SqlDecimals + { + +#if !NETSTANDARD1_3 + [Fact] + public void NegativeSqlDecimal() + { + const int precision = 38; + const int scale = 0; + const int bitWidth = 16; + + var negative = new SqlDecimal(precision, scale, false, 0, 0, 1, 0); + var bytes = new byte[16]; + DecimalUtility.GetBytes(negative.Value, precision, scale, bitWidth, bytes); + var sqlNegative = DecimalUtility.GetSqlDecimal128(new ArrowBuffer(bytes), 0, precision, scale); + Assert.Equal(negative, sqlNegative); + + DecimalUtility.GetBytes(sqlNegative, precision, scale, bytes); + var decimalNegative = DecimalUtility.GetDecimal(new ArrowBuffer(bytes), 0, scale, bitWidth); + Assert.Equal(negative.Value, decimalNegative); + } + + [Fact] + public void LargeScale() + { + string digits = "1.2345678901234567890123456789012345678"; + + var positive = SqlDecimal.Parse(digits); + Assert.Equal(38, positive.Precision); + Assert.Equal(37, positive.Scale); + + var bytes = new byte[16]; + DecimalUtility.GetBytes(positive, positive.Precision, positive.Scale, bytes); + var sqlPositive = DecimalUtility.GetSqlDecimal128(new ArrowBuffer(bytes), 0, positive.Precision, positive.Scale); + + Assert.Equal(positive, sqlPositive); + Assert.Equal(digits, sqlPositive.ToString()); + + digits = "-" + digits; + var negative = SqlDecimal.Parse(digits); + Assert.Equal(38, positive.Precision); + Assert.Equal(37, positive.Scale); + + DecimalUtility.GetBytes(negative, negative.Precision, negative.Scale, bytes); + var sqlNegative = DecimalUtility.GetSqlDecimal128(new ArrowBuffer(bytes), 0, negative.Precision, negative.Scale); + + Assert.Equal(negative, sqlNegative); + Assert.Equal(digits, sqlNegative.ToString()); + } +#endif + } } }