Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-38351: [C#] Add SqlDecimal support to Decimal128Array #38481

Merged
merged 6 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions csharp/src/Apache.Arrow/Arrays/Decimal128Array.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@

using System;
using System.Collections.Generic;
#if !NETSTANDARD1_3
using System.Data.SqlTypes;
#endif
using System.Diagnostics;
using System.Numerics;
using Apache.Arrow.Arrays;
Expand Down Expand Up @@ -61,6 +64,31 @@ public Builder AppendRange(IEnumerable<decimal> values)
return Instance;
}

#if !NETSTANDARD1_3
public Builder Append(SqlDecimal value)
{
Span<byte> bytes = stackalloc byte[DataType.ByteWidth];
DecimalUtility.GetBytes(value, DataType.Precision, DataType.Scale, bytes);

return Append(bytes);
}

public Builder AppendRange(IEnumerable<SqlDecimal> values)
{
if (values == null)
{
throw new ArgumentNullException(nameof(values));
}

foreach (decimal d in values)
CurtHagenlocher marked this conversation as resolved.
Show resolved Hide resolved
{
Append(d);
}

return Instance;
}
#endif

public Builder Set(int index, decimal value)
{
Span<byte> bytes = stackalloc byte[DataType.ByteWidth];
Expand Down Expand Up @@ -91,5 +119,17 @@ public Decimal128Array(ArrayData data)
}
return DecimalUtility.GetDecimal(ValueBuffer, index, Scale, ByteWidth);
}

#if !NETSTANDARD1_3
public SqlDecimal? GetSqlDecimal(int index)
{
if (IsNull(index))
{
return null;
}

return DecimalUtility.GetSqlDecimal128(ValueBuffer, index, Precision, Scale);
}
#endif
}
}
45 changes: 45 additions & 0 deletions csharp/src/Apache.Arrow/DecimalUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
// limitations under the License.

using System;
#if !NETSTANDARD1_3
using System.Data.SqlTypes;
#endif
using System.Numerics;

namespace Apache.Arrow
Expand Down Expand Up @@ -73,6 +76,27 @@ internal static decimal GetDecimal(in ArrowBuffer valueBuffer, int index, int sc
}
}

#if !NETSTANDARD1_3
internal static SqlDecimal GetSqlDecimal128(in ArrowBuffer valueBuffer, int index, int precision, int scale)
{
const int byteWidth = 16;
const int intWidth = byteWidth / 4;

byte mostSignificantByte = valueBuffer.Span[(index + 1) * byteWidth - 1];
bool isPositive = (mostSignificantByte & 0x80) == 0;

ReadOnlySpan<int> value = valueBuffer.Span.CastTo<int>().Slice(index * intWidth, intWidth);
if (isPositive)
{
return new SqlDecimal((byte)precision, (byte)scale, true, value[0], value[1], value[2], value[3]);
}
else
{
return new SqlDecimal((byte)precision, (byte)scale, false, -value[0], ~value[1], ~value[2], ~value[3]);
}
}
#endif

private static decimal DivideByScale(BigInteger integerValue, int scale)
{
decimal result = (decimal)integerValue; // this cast is safe here
Expand Down Expand Up @@ -169,5 +193,26 @@ internal static void GetBytes(decimal value, int precision, int scale, int byteW
}
}
}

#if !NETSTANDARD1_3
internal static void GetBytes(SqlDecimal value, int precision, int scale, Span<byte> bytes)
{
if (value.Precision != precision || value.Scale != scale)
{
value = SqlDecimal.ConvertToPrecScale(value, precision, scale);
}

// TODO: Consider groveling in the internals to avoid the probable allocation
Span<int> span = bytes.CastTo<int>();
value.Data.AsSpan().CopyTo(span);
if (!value.IsPositive)
{
span[0] = -span[0];
span[1] = ~span[1];
span[2] = ~span[2];
span[3] = ~span[3];
}
}
#endif
}
}
117 changes: 109 additions & 8 deletions csharp/test/Apache.Arrow.Tests/Decimal128ArrayTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,28 @@
// limitations under the License.

using System;
using System.Collections.Generic;
#if !NETSTANDARD1_3
using System.Data.SqlTypes;
#endif
using Apache.Arrow.Types;
using Xunit;

namespace Apache.Arrow.Tests
{
public class Decimal128ArrayTests
{
#if !NETSTANDARD1_3
static SqlDecimal? Convert(decimal? value)
{
return value == null ? null : new SqlDecimal(value.Value);
}

static decimal? Convert(SqlDecimal? value)
{
return value == null ? null : value.Value.Value;
}
#endif

public class Builder
{
public class AppendNull
Expand All @@ -30,7 +44,7 @@ public class AppendNull
public void AppendThenGetGivesNull()
{
// Arrange
var builder = new Decimal128Array.Builder(new Decimal128Type(8,2));
var builder = new Decimal128Array.Builder(new Decimal128Type(8, 2));

// Act

Expand All @@ -45,6 +59,12 @@ public void AppendThenGetGivesNull()
Assert.Null(array.GetValue(0));
Assert.Null(array.GetValue(1));
Assert.Null(array.GetValue(2));

#if !NETSTANDARD1_3
Assert.Null(array.GetSqlDecimal(0));
Assert.Null(array.GetSqlDecimal(1));
Assert.Null(array.GetSqlDecimal(2));
#endif
}
}

Expand All @@ -67,7 +87,7 @@ public void AppendDecimal(int count)
testData[i] = null;
continue;
}
decimal rnd = i * (decimal)Math.Round(new Random().NextDouble(),10);
decimal rnd = i * (decimal)Math.Round(new Random().NextDouble(), 10);
testData[i] = rnd;
builder.Append(rnd);
}
Expand All @@ -78,6 +98,9 @@ public void AppendDecimal(int count)
for (int i = 0; i < count; i++)
{
Assert.Equal(testData[i], array.GetValue(i));
#if !NETSTANDARD1_3
Assert.Equal(Convert(testData[i]), array.GetSqlDecimal(i));
#endif
}
}

Expand All @@ -95,6 +118,11 @@ public void AppendLargeDecimal()
var array = builder.Build();
Assert.Equal(large, array.GetValue(0));
Assert.Equal(-large, array.GetValue(1));

#if !NETSTANDARD1_3
Assert.Equal(Convert(large), array.GetSqlDecimal(0));
Assert.Equal(Convert(-large), array.GetSqlDecimal(1));
#endif
}

[Fact]
Expand All @@ -115,6 +143,13 @@ public void AppendMaxAndMinDecimal()
Assert.Equal(Decimal.MinValue, array.GetValue(1));
Assert.Equal(Decimal.MaxValue - 10, array.GetValue(2));
Assert.Equal(Decimal.MinValue + 10, array.GetValue(3));

#if !NETSTANDARD1_3
Assert.Equal(Convert(Decimal.MaxValue), array.GetSqlDecimal(0));
Assert.Equal(Convert(Decimal.MinValue), array.GetSqlDecimal(1));
Assert.Equal(Convert(Decimal.MaxValue) - 10, array.GetSqlDecimal(2));
Assert.Equal(Convert(Decimal.MinValue) + 10, array.GetSqlDecimal(3));
#endif
}

[Fact]
Expand All @@ -131,35 +166,43 @@ public void AppendFractionalDecimal()
var array = builder.Build();
Assert.Equal(fraction, array.GetValue(0));
Assert.Equal(-fraction, array.GetValue(1));

#if !NETSTANDARD1_3
Assert.Equal(Convert(fraction), array.GetSqlDecimal(0));
Assert.Equal(Convert(-fraction), array.GetSqlDecimal(1));
#endif
}

[Fact]
public void AppendRangeDecimal()
{
// Arrange
var builder = new Decimal128Array.Builder(new Decimal128Type(24, 8));
var range = new decimal[] {2.123M, 1.5984M, -0.0000001M, 9878987987987987.1235407M};
var range = new decimal[] { 2.123M, 1.5984M, -0.0000001M, 9878987987987987.1235407M };

// Act
builder.AppendRange(range);
builder.AppendNull();

// Assert
var array = builder.Build();
for(int i = 0; i < range.Length; i ++)
for (int i = 0; i < range.Length; i++)
{
Assert.Equal(range[i], array.GetValue(i));
#if !NETSTANDARD1_3
Assert.Equal(Convert(range[i]), array.GetSqlDecimal(i));
#endif
}
Assert.Null( array.GetValue(range.Length));

Assert.Null(array.GetValue(range.Length));
}

[Fact]
public void AppendClearAppendDecimal()
{
// Arrange
var builder = new Decimal128Array.Builder(new Decimal128Type(24, 8));

// Act
builder.Append(1);
builder.Clear();
Expand Down Expand Up @@ -256,6 +299,64 @@ public void SwapNull()
Assert.Equal(123.456M, array.GetValue(1));
}
}

#if !NETSTANDARD1_3
public class SqlDecimals
{
[Theory]
[InlineData(200)]
public void AppendSqlDecimal(int count)
{
// Arrange
const int precision = 10;
var builder = new Decimal128Array.Builder(new Decimal128Type(14, precision));

// Act
SqlDecimal?[] testData = new SqlDecimal?[count];
for (int i = 0; i < count; i++)
{
if (i == count - 2)
{
builder.AppendNull();
testData[i] = null;
continue;
}
SqlDecimal rnd = i * (SqlDecimal)Math.Round(new Random().NextDouble(), 10);
builder.Append(rnd);
testData[i] = SqlDecimal.Round(rnd, precision);
}

// Assert
var array = builder.Build();
Assert.Equal(count, array.Length);
for (int i = 0; i < count; i++)
{
Assert.Equal(testData[i], array.GetSqlDecimal(i));
Assert.Equal(Convert(testData[i]), array.GetValue(i));
}
}

[Fact]
public void AppendMaxAndMinSqlDecimal()
{
// Arrange
var builder = new Decimal128Array.Builder(new Decimal128Type(38, 0));

// Act
builder.Append(SqlDecimal.MaxValue);
builder.Append(SqlDecimal.MinValue);
builder.Append(SqlDecimal.MaxValue - 10);
builder.Append(SqlDecimal.MinValue + 10);

// Assert
var array = builder.Build();
Assert.Equal(SqlDecimal.MaxValue, array.GetSqlDecimal(0));
Assert.Equal(SqlDecimal.MinValue, array.GetSqlDecimal(1));
Assert.Equal(SqlDecimal.MaxValue - 10, array.GetSqlDecimal(2));
Assert.Equal(SqlDecimal.MinValue + 10, array.GetSqlDecimal(3));
}
}
#endif
}
}
}
Loading