Skip to content

Commit

Permalink
Update FromArrowRecordBatches for dotnet-spark (dotnet#2978)
Browse files Browse the repository at this point in the history
* Support for RecordBatches with StructArrays

* Sq

* Address comments

* Nits

* Nits
  • Loading branch information
Prashanth Govindarajan authored Oct 23, 2020
1 parent db5c49e commit cb7ab00
Show file tree
Hide file tree
Showing 4 changed files with 208 additions and 109 deletions.
213 changes: 114 additions & 99 deletions src/Microsoft.Data.Analysis/DataFrame.Arrow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,127 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
using System.Threading.Tasks;
using Apache.Arrow;
using Apache.Arrow.Types;

namespace Microsoft.Data.Analysis
{
public partial class DataFrame
{
private static void AppendDataFrameColumnFromArrowArray(Field field, IArrowArray arrowArray, DataFrame ret, string fieldNamePrefix = "")
{
IArrowType fieldType = field.DataType;
DataFrameColumn dataFrameColumn = null;
string fieldName = fieldNamePrefix + field.Name;
switch (fieldType.TypeId)
{
case ArrowTypeId.Boolean:
BooleanArray arrowBooleanArray = (BooleanArray)arrowArray;
ReadOnlyMemory<byte> valueBuffer = arrowBooleanArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> nullBitMapBuffer = arrowBooleanArray.NullBitmapBuffer.Memory;
dataFrameColumn = new BooleanDataFrameColumn(fieldName, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Double:
PrimitiveArray<double> arrowDoubleArray = (PrimitiveArray<double>)arrowArray;
ReadOnlyMemory<byte> doubleValueBuffer = arrowDoubleArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> doubleNullBitMapBuffer = arrowDoubleArray.NullBitmapBuffer.Memory;
dataFrameColumn = new DoubleDataFrameColumn(fieldName, doubleValueBuffer, doubleNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Float:
PrimitiveArray<float> arrowFloatArray = (PrimitiveArray<float>)arrowArray;
ReadOnlyMemory<byte> floatValueBuffer = arrowFloatArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> floatNullBitMapBuffer = arrowFloatArray.NullBitmapBuffer.Memory;
dataFrameColumn = new SingleDataFrameColumn(fieldName, floatValueBuffer, floatNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Int8:
PrimitiveArray<sbyte> arrowsbyteArray = (PrimitiveArray<sbyte>)arrowArray;
ReadOnlyMemory<byte> sbyteValueBuffer = arrowsbyteArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> sbyteNullBitMapBuffer = arrowsbyteArray.NullBitmapBuffer.Memory;
dataFrameColumn = new SByteDataFrameColumn(fieldName, sbyteValueBuffer, sbyteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Int16:
PrimitiveArray<short> arrowshortArray = (PrimitiveArray<short>)arrowArray;
ReadOnlyMemory<byte> shortValueBuffer = arrowshortArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> shortNullBitMapBuffer = arrowshortArray.NullBitmapBuffer.Memory;
dataFrameColumn = new Int16DataFrameColumn(fieldName, shortValueBuffer, shortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Int32:
PrimitiveArray<int> arrowIntArray = (PrimitiveArray<int>)arrowArray;
ReadOnlyMemory<byte> intValueBuffer = arrowIntArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> intNullBitMapBuffer = arrowIntArray.NullBitmapBuffer.Memory;
dataFrameColumn = new Int32DataFrameColumn(fieldName, intValueBuffer, intNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Int64:
PrimitiveArray<long> arrowLongArray = (PrimitiveArray<long>)arrowArray;
ReadOnlyMemory<byte> longValueBuffer = arrowLongArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> longNullBitMapBuffer = arrowLongArray.NullBitmapBuffer.Memory;
dataFrameColumn = new Int64DataFrameColumn(fieldName, longValueBuffer, longNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.String:
StringArray stringArray = (StringArray)arrowArray;
ReadOnlyMemory<byte> dataMemory = stringArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> offsetsMemory = stringArray.ValueOffsetsBuffer.Memory;
ReadOnlyMemory<byte> nullMemory = stringArray.NullBitmapBuffer.Memory;
dataFrameColumn = new ArrowStringDataFrameColumn(fieldName, dataMemory, offsetsMemory, nullMemory, stringArray.Length, stringArray.NullCount);
break;
case ArrowTypeId.UInt8:
PrimitiveArray<byte> arrowbyteArray = (PrimitiveArray<byte>)arrowArray;
ReadOnlyMemory<byte> byteValueBuffer = arrowbyteArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> byteNullBitMapBuffer = arrowbyteArray.NullBitmapBuffer.Memory;
dataFrameColumn = new ByteDataFrameColumn(fieldName, byteValueBuffer, byteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.UInt16:
PrimitiveArray<ushort> arrowUshortArray = (PrimitiveArray<ushort>)arrowArray;
ReadOnlyMemory<byte> ushortValueBuffer = arrowUshortArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> ushortNullBitMapBuffer = arrowUshortArray.NullBitmapBuffer.Memory;
dataFrameColumn = new UInt16DataFrameColumn(fieldName, ushortValueBuffer, ushortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.UInt32:
PrimitiveArray<uint> arrowUintArray = (PrimitiveArray<uint>)arrowArray;
ReadOnlyMemory<byte> uintValueBuffer = arrowUintArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> uintNullBitMapBuffer = arrowUintArray.NullBitmapBuffer.Memory;
dataFrameColumn = new UInt32DataFrameColumn(fieldName, uintValueBuffer, uintNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.UInt64:
PrimitiveArray<ulong> arrowUlongArray = (PrimitiveArray<ulong>)arrowArray;
ReadOnlyMemory<byte> ulongValueBuffer = arrowUlongArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> ulongNullBitMapBuffer = arrowUlongArray.NullBitmapBuffer.Memory;
dataFrameColumn = new UInt64DataFrameColumn(fieldName, ulongValueBuffer, ulongNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Struct:
StructArray structArray = (StructArray)arrowArray;
StructType structType = (StructType)field.DataType;
IEnumerator<Field> fieldsEnumerator = structType.Fields.GetEnumerator();
IEnumerator<IArrowArray> structArrayEnumerator = structArray.Fields.GetEnumerator();
while (fieldsEnumerator.MoveNext() && structArrayEnumerator.MoveNext())
{
AppendDataFrameColumnFromArrowArray(fieldsEnumerator.Current, structArrayEnumerator.Current, ret, field.Name + "_");
}
break;
case ArrowTypeId.Decimal:
case ArrowTypeId.Binary:
case ArrowTypeId.Date32:
case ArrowTypeId.Date64:
case ArrowTypeId.Dictionary:
case ArrowTypeId.FixedSizedBinary:
case ArrowTypeId.HalfFloat:
case ArrowTypeId.Interval:
case ArrowTypeId.List:
case ArrowTypeId.Map:
case ArrowTypeId.Null:
case ArrowTypeId.Time32:
case ArrowTypeId.Time64:
default:
throw new NotImplementedException(nameof(fieldType.Name));
}

if (dataFrameColumn != null)
{
ret.Columns.Insert(ret.Columns.Count, dataFrameColumn);
}
}

/// <summary>
/// Wraps a <see cref="DataFrame"/> around an Arrow <see cref="RecordBatch"/> without copying data
/// </summary>
Expand All @@ -29,101 +138,7 @@ public static DataFrame FromArrowRecordBatch(RecordBatch recordBatch)
foreach (IArrowArray arrowArray in arrowArrays)
{
Field field = arrowSchema.GetFieldByIndex(fieldIndex);
IArrowType fieldType = field.DataType;
DataFrameColumn dataFrameColumn = null;
switch (fieldType.TypeId)
{
case ArrowTypeId.Boolean:
BooleanArray arrowBooleanArray = (BooleanArray)arrowArray;
ReadOnlyMemory<byte> valueBuffer = arrowBooleanArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> nullBitMapBuffer = arrowBooleanArray.NullBitmapBuffer.Memory;
dataFrameColumn = new BooleanDataFrameColumn(field.Name, valueBuffer, nullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Double:
PrimitiveArray<double> arrowDoubleArray = (PrimitiveArray<double>)arrowArray;
ReadOnlyMemory<byte> doubleValueBuffer = arrowDoubleArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> doubleNullBitMapBuffer = arrowDoubleArray.NullBitmapBuffer.Memory;
dataFrameColumn = new DoubleDataFrameColumn(field.Name, doubleValueBuffer, doubleNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Float:
PrimitiveArray<float> arrowFloatArray = (PrimitiveArray<float>)arrowArray;
ReadOnlyMemory<byte> floatValueBuffer = arrowFloatArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> floatNullBitMapBuffer = arrowFloatArray.NullBitmapBuffer.Memory;
dataFrameColumn = new SingleDataFrameColumn(field.Name, floatValueBuffer, floatNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Int8:
PrimitiveArray<sbyte> arrowsbyteArray = (PrimitiveArray<sbyte>)arrowArray;
ReadOnlyMemory<byte> sbyteValueBuffer = arrowsbyteArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> sbyteNullBitMapBuffer = arrowsbyteArray.NullBitmapBuffer.Memory;
dataFrameColumn = new SByteDataFrameColumn(field.Name, sbyteValueBuffer, sbyteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Int16:
PrimitiveArray<short> arrowshortArray = (PrimitiveArray<short>)arrowArray;
ReadOnlyMemory<byte> shortValueBuffer = arrowshortArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> shortNullBitMapBuffer = arrowshortArray.NullBitmapBuffer.Memory;
dataFrameColumn = new Int16DataFrameColumn(field.Name, shortValueBuffer, shortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Int32:
PrimitiveArray<int> arrowIntArray = (PrimitiveArray<int>)arrowArray;
ReadOnlyMemory<byte> intValueBuffer = arrowIntArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> intNullBitMapBuffer = arrowIntArray.NullBitmapBuffer.Memory;
dataFrameColumn = new Int32DataFrameColumn(field.Name, intValueBuffer, intNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Int64:
PrimitiveArray<long> arrowLongArray = (PrimitiveArray<long>)arrowArray;
ReadOnlyMemory<byte> longValueBuffer = arrowLongArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> longNullBitMapBuffer = arrowLongArray.NullBitmapBuffer.Memory;
dataFrameColumn = new Int64DataFrameColumn(field.Name, longValueBuffer, longNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.String:
StringArray stringArray = (StringArray)arrowArray;
ReadOnlyMemory<byte> dataMemory = stringArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> offsetsMemory = stringArray.ValueOffsetsBuffer.Memory;
ReadOnlyMemory<byte> nullMemory = stringArray.NullBitmapBuffer.Memory;
dataFrameColumn = new ArrowStringDataFrameColumn(field.Name, dataMemory, offsetsMemory, nullMemory, stringArray.Length, stringArray.NullCount);
break;
case ArrowTypeId.UInt8:
PrimitiveArray<byte> arrowbyteArray = (PrimitiveArray<byte>)arrowArray;
ReadOnlyMemory<byte> byteValueBuffer = arrowbyteArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> byteNullBitMapBuffer = arrowbyteArray.NullBitmapBuffer.Memory;
dataFrameColumn = new ByteDataFrameColumn(field.Name, byteValueBuffer, byteNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.UInt16:
PrimitiveArray<ushort> arrowUshortArray = (PrimitiveArray<ushort>)arrowArray;
ReadOnlyMemory<byte> ushortValueBuffer = arrowUshortArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> ushortNullBitMapBuffer = arrowUshortArray.NullBitmapBuffer.Memory;
dataFrameColumn = new UInt16DataFrameColumn(field.Name, ushortValueBuffer, ushortNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.UInt32:
PrimitiveArray<uint> arrowUintArray = (PrimitiveArray<uint>)arrowArray;
ReadOnlyMemory<byte> uintValueBuffer = arrowUintArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> uintNullBitMapBuffer = arrowUintArray.NullBitmapBuffer.Memory;
dataFrameColumn = new UInt32DataFrameColumn(field.Name, uintValueBuffer, uintNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.UInt64:
PrimitiveArray<ulong> arrowUlongArray = (PrimitiveArray<ulong>)arrowArray;
ReadOnlyMemory<byte> ulongValueBuffer = arrowUlongArray.ValueBuffer.Memory;
ReadOnlyMemory<byte> ulongNullBitMapBuffer = arrowUlongArray.NullBitmapBuffer.Memory;
dataFrameColumn = new UInt64DataFrameColumn(field.Name, ulongValueBuffer, ulongNullBitMapBuffer, arrowArray.Length, arrowArray.NullCount);
break;
case ArrowTypeId.Decimal:
case ArrowTypeId.Binary:
case ArrowTypeId.Date32:
case ArrowTypeId.Date64:
case ArrowTypeId.Dictionary:
case ArrowTypeId.FixedSizedBinary:
case ArrowTypeId.HalfFloat:
case ArrowTypeId.Interval:
case ArrowTypeId.List:
case ArrowTypeId.Map:
case ArrowTypeId.Null:
case ArrowTypeId.Struct:
case ArrowTypeId.Time32:
case ArrowTypeId.Time64:
default:
throw new NotImplementedException(nameof(fieldType.Name));
}
ret.Columns.Insert(ret.Columns.Count, dataFrameColumn);
AppendDataFrameColumnFromArrowArray(field, arrowArray, ret);
fieldIndex++;
}
return ret;
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.Data.Analysis/Microsoft.Data.Analysis.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Apache.Arrow" Version="0.14.1" />
<PackageReference Include="Apache.Arrow" Version="2.0.0" />
<PackageReference Include="System.Memory" Version="4.5.3" />
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="4.5.2" />
<PackageReference Include="System.Buffers" Version="4.5.0" />
Expand Down
36 changes: 27 additions & 9 deletions tests/Microsoft.Data.Analysis.Tests/ArrayComparer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ public class ArrayComparer :
IArrowArrayVisitor<Date64Array>,
IArrowArrayVisitor<ListArray>,
IArrowArrayVisitor<StringArray>,
IArrowArrayVisitor<BinaryArray>
IArrowArrayVisitor<BinaryArray>,
IArrowArrayVisitor<StructArray>
{
private readonly IArrowArray _expectedArray;

Expand Down Expand Up @@ -54,6 +55,23 @@ public ArrayComparer(IArrowArray expectedArray)
public void Visit(BinaryArray array) => throw new NotImplementedException();
public void Visit(IArrowArray array) => throw new NotImplementedException();

public void Visit(StructArray array)
{
Assert.IsAssignableFrom<StructArray>(_expectedArray);
StructArray expectedArray = (StructArray)_expectedArray;

Assert.Equal(expectedArray.Length, array.Length);
Assert.Equal(expectedArray.NullCount, array.NullCount);
Assert.Equal(expectedArray.Offset, array.Offset);
Assert.Equal(expectedArray.Data.Children.Length, array.Data.Children.Length);
Assert.Equal(expectedArray.Fields.Count, array.Fields.Count);

for (int i = 0; i < array.Fields.Count; i++)
{
array.Fields[i].Accept(new ArrayComparer(expectedArray.Fields[i]));
}
}

private void CompareArrays<T>(PrimitiveArray<T> actualArray)
where T : struct, IEquatable<T>
{
Expand All @@ -68,15 +86,15 @@ private void CompareArrays<T>(PrimitiveArray<T> actualArray)
{
Assert.True(expectedArray.NullBitmapBuffer.Span.SequenceEqual(actualArray.NullBitmapBuffer.Span));
}
else
{
else
{
// expectedArray may have passed in a null bitmap. DataFrame might have populated it with Length set bits
Assert.Equal(0, expectedArray.NullCount);
Assert.Equal(0, actualArray.NullCount);
for (int i = 0; i < actualArray.Length; i++)
{
Assert.True(actualArray.IsValid(i));
}
Assert.Equal(0, expectedArray.NullCount);
Assert.Equal(0, actualArray.NullCount);
for (int i = 0; i < actualArray.Length; i++)
{
Assert.True(actualArray.IsValid(i));
}
}
Assert.True(expectedArray.Values.Slice(0, expectedArray.Length).SequenceEqual(actualArray.Values.Slice(0, actualArray.Length)));
}
Expand Down
Loading

0 comments on commit cb7ab00

Please sign in to comment.