Bug Fix for Spark 3.x - Avoid converting converted Row values (#868)

suhsteve authored Mar 27, 2021
1 parent b9283eb commit 33299cf
Showing 11 changed files with 196 additions and 104 deletions.
8 changes: 6 additions & 2 deletions azure-pipelines.yml
@@ -33,9 +33,12 @@ variables:

# Filter DataFrameTests.TestDataFrameGroupedMapUdf and DataFrameTests.TestGroupedMapUdf backwardCompatible
# tests due to https://github.com/dotnet/spark/pull/711
# Filter UdfSimpleTypesTests.TestUdfWithDuplicateTimestamps 3.x backwardCompatible test due to a bug related
# to duplicate types that NeedConversion. Bug fixed in https://github.com/dotnet/spark/pull/868
backwardCompatibleTestOptions_Windows_3_0: "--filter \
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestDataFrameGroupedMapUdf)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestGroupedMapUdf)"
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestGroupedMapUdf)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfSimpleTypesTests.TestUdfWithDuplicateTimestamps)"
forwardCompatibleTestOptions_Windows_3_0: ""
backwardCompatibleTestOptions_Linux_3_0: $(backwardCompatibleTestOptions_Windows_3_0)
forwardCompatibleTestOptions_Linux_3_0: $(forwardCompatibleTestOptions_Linux_2_4)
@@ -85,7 +88,8 @@ variables:
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataFrameTests.TestUDF)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.SparkSessionExtensionsTests.TestVersion)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataStreamWriterTests.TestForeachBatch)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataStreamWriterTests.TestForeach)"
(FullyQualifiedName!=Microsoft.Spark.E2ETest.IpcTests.DataStreamWriterTests.TestForeach)&\
(FullyQualifiedName!=Microsoft.Spark.E2ETest.UdfTests.UdfSimpleTypesTests.TestUdfWithDuplicateTimestamps)"
# Skip all forwardCompatible tests since microsoft-spark-3-1 jar does not get built when
# building forwardCompatible repo.
forwardCompatibleTestOptions_Windows_3_1: "--filter FullyQualifiedName=NONE"
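The bug referenced in the filter comments above is easiest to see in isolation: when a pickler memoizes duplicate values, several rows can end up sharing one physical cell, and an in-place type conversion that does not check for already-converted values runs a second time on data it has already transformed. Below is a minimal, standalone C# sketch of that failure mode; the shared-cell setup and the epoch-seconds encoding are illustrative assumptions, not code from this repository.

using System;

class DoubleConversionSketch
{
    static void Main()
    {
        // Hypothetical stand-in for pickler memoization: three logical rows
        // end up sharing one physical cell because the duplicate value was
        // memoized during unpickling.
        object[] sharedCell = { 1577836800L }; // 2020-01-01T00:00:00Z as epoch seconds
        object[][] rows = { sharedCell, sharedCell, sharedCell };

        foreach (object[] row in rows)
        {
            // The type check below is the essence of "avoid converting
            // converted values": only cells still in raw (long) form are
            // converted. A blind (long)row[0] cast would instead throw on the
            // second visit, since the shared cell already holds a DateTimeOffset.
            if (row[0] is long seconds)
            {
                row[0] = DateTimeOffset.FromUnixTimeSeconds(seconds);
            }
        }

        Console.WriteLine(rows[2][0]); // 1/1/2020 12:00:00 AM +00:00
    }
}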
112 changes: 112 additions & 0 deletions src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/RowTests.cs
@@ -0,0 +1,112 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Linq;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Xunit;
using static Microsoft.Spark.Sql.Functions;

namespace Microsoft.Spark.E2ETest.IpcTests
{
[Collection("Spark E2E Tests")]
public class RowTests
{
private readonly SparkSession _spark;

public RowTests(SparkFixture fixture)
{
_spark = fixture.Spark;
}

[Fact]
public void TestWithDuplicatedRows()
{
var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
var schema = new StructType(new StructField[]
{
new StructField("ts", new TimestampType())
});
var data = new GenericRow[]
{
new GenericRow(new object[] { timestamp })
};

DataFrame df = _spark.CreateDataFrame(data, schema);
Row[] rows = df
.WithColumn("tsRow", Struct("ts"))
.WithColumn("tsRowRow", Struct("tsRow"))
.Collect()
.ToArray();

Assert.Single(rows);

Row row = rows[0];
Assert.Equal(3, row.Values.Length);
Assert.Equal(timestamp, row.Values[0]);

Row tsRow = row.Values[1] as Row;
Assert.Single(tsRow.Values);
Assert.Equal(timestamp, tsRow.Values[0]);

Row tsRowRow = row.Values[2] as Row;
Assert.Single(tsRowRow.Values);
Assert.Equal(tsRowRow.Values[0], tsRow);
}

[Fact]
public void TestWithDuplicateTimestamps()
{
var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
var schema = new StructType(new StructField[]
{
new StructField("ts", new TimestampType())
});
var data = new GenericRow[]
{
new GenericRow(new object[] { timestamp }),
new GenericRow(new object[] { timestamp }),
new GenericRow(new object[] { timestamp })
};

DataFrame df = _spark.CreateDataFrame(data, schema);
Row[] rows = df.Collect().ToArray();

Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Single(row.Values);
Assert.Equal(timestamp, row.GetAs<Timestamp>(0));
}
}

[Fact]
public void TestWithDuplicateDates()
{
var date = new Date(2020, 1, 1);
var schema = new StructType(new StructField[]
{
new StructField("date", new DateType())
});
var data = new GenericRow[]
{
new GenericRow(new object[] { date }),
new GenericRow(new object[] { date }),
new GenericRow(new object[] { date })
};

DataFrame df = _spark.CreateDataFrame(data, schema);

Row[] rows = df.Collect().ToArray();

Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Single(row.Values);
Assert.Equal(date, row.GetAs<Date>(0));
}
}
}
}
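The nested-struct assertions in TestWithDuplicatedRows above hinge on Functions.Struct wrapping existing columns in a single struct-typed column, which materializes client-side as a nested Row. A small helper sketch of that shape (it assumes a DataFrame whose only column is "ts", so the appended struct lands at index 1 after WithColumn; the helper name is illustrative):

using System.Linq;
using Microsoft.Spark.Sql;
using static Microsoft.Spark.Sql.Functions;

internal static class NestedStructExample
{
    // Wraps the "ts" column in a one-field struct and returns the nested Row
    // from the first collected row. Values[1] is the appended struct column.
    internal static Row InnerRow(DataFrame df) =>
        (Row)df.WithColumn("tsRow", Struct("ts"))
            .Collect()
            .First()
            .Values[1];
}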
@@ -6,7 +6,6 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Streaming;
34 changes: 34 additions & 0 deletions src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs
@@ -125,6 +125,40 @@ public void TestUdfWithTimestampType()
Assert.Equal(expected, rowsToArray);
}

/// <summary>
/// UDF that takes duplicate Timestamp values and returns a Timestamp.
/// </summary>
[Fact]
public void TestUdfWithDuplicateTimestamps()
{
var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
var schema = new StructType(new StructField[]
{
new StructField("ts", new TimestampType())
});
var data = new GenericRow[]
{
new GenericRow(new object[] { timestamp }),
new GenericRow(new object[] { timestamp }),
new GenericRow(new object[] { timestamp })
};

var expectedTimestamp = new Timestamp(1970, 1, 2, 0, 0, 0, 0);
Func<Column, Column> udf = Udf<Timestamp, Timestamp>(
ts => new Timestamp(1970, 1, 2, 0, 0, 0, 0));

DataFrame df = _spark.CreateDataFrame(data, schema);

Row[] rows = df.Select(udf(df["ts"])).Collect().ToArray();

Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Single(row.Values);
Assert.Equal(expectedTimestamp, row.Values[0]);
}
}

/// <summary>
/// UDF that returns Timestamp type.
/// </summary>
4 changes: 2 additions & 2 deletions src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs
@@ -91,8 +91,8 @@ public void RowConstructorTest()
pickledBytes.Length);

Assert.Equal(2, unpickledData.Length);
Assert.Equal(row1, (unpickledData[0] as RowConstructor).GetRow());
Assert.Equal(row2, (unpickledData[1] as RowConstructor).GetRow());
Assert.Equal(row1, unpickledData[0]);
Assert.Equal(row2, unpickledData[1]);
}

[Fact]
@@ -143,9 +143,9 @@ protected internal override CommandExecutorStat ExecuteCore(
// The following can happen if a UDF takes Row object(s).
// The JVM Spark side sends a Row object that wraps all the columns used
// in the UDF; it is therefore normalized below (the extra layer is removed).
if (row is RowConstructor rowConstructor)
if (row is Row r)
{
row = rowConstructor.GetRow().Values;
row = r.Values;
}

// Split id is not used for SQL UDFs, so 0 is passed.
5 changes: 1 addition & 4 deletions src/csharp/Microsoft.Spark/RDD/Collector.cs
@@ -98,10 +98,7 @@ public object Deserialize(Stream stream, int length)
{
// Refer to the AutoBatchedPickler class in spark/core/src/main/scala/org/apache/
// spark/api/python/SerDeUtil.scala regarding how the Rows may be batched.
return PythonSerDe.GetUnpickledObjects(stream, length)
.Cast<RowConstructor>()
.Select(rc => rc.GetRow())
.ToArray();
return PythonSerDe.GetUnpickledObjects(stream, length).Cast<Row>().ToArray();
}
}
}
2 changes: 1 addition & 1 deletion src/csharp/Microsoft.Spark/Sql/RowCollector.cs
@@ -35,7 +35,7 @@ public IEnumerable<Row> Collect(ISocketWrapper socket)

foreach (object unpickled in unpickledObjects)
{
yield return (unpickled as RowConstructor).GetRow();
yield return unpickled as Row;
}
}
}
116 changes: 37 additions & 79 deletions src/csharp/Microsoft.Spark/Sql/RowConstructor.cs
@@ -22,94 +22,39 @@ internal sealed class RowConstructor : IObjectConstructor
/// sent per batch if there are nested rows contained in the row. Note that
/// this is thread local variable because one RowConstructor object is
/// registered to the Unpickler and there could be multiple threads unpickling
/// the data using the same object registered.
/// the data using the same registered object.
/// </summary>
[ThreadStatic]
private static IDictionary<string, StructType> s_schemaCache;

/// <summary>
/// The RowConstructor that created this instance.
/// Used by Unpickler to pass unpickled schema for handling. The Unpickler
/// will reuse the <see cref="RowConstructor"/> object when
/// it needs to start constructing a <see cref="Row"/>. The schema is passed
/// to <see cref="construct(object[])"/> and the returned
/// <see cref="IObjectConstructor"/> is used to build the rest of the <see cref="Row"/>.
/// </summary>
private readonly RowConstructor _parent;

/// <summary>
/// Stores the args passed from construct().
/// </summary>
private readonly object[] _args;

public RowConstructor() : this(null, null)
{
}

public RowConstructor(RowConstructor parent, object[] args)
{
_parent = parent;
_args = args;
}

/// <summary>
/// Used by Unpickler to pass unpickled data for handling.
/// </summary>
/// <param name="args">Unpickled data</param>
/// <returns>New RowConstructor object capturing args data</returns>
/// <param name="args">Unpickled schema</param>
/// <returns>
/// New <see cref="RowWithSchemaConstructor"/> object capturing the schema.
/// </returns>
public object construct(object[] args)
{
// Every first call to construct() contains the schema data. When
// a new RowConstructor object is returned from this function,
// construct() is called on the returned object with the actual
// row data. The original RowConstructor object may be reused by the
// Unpickler and each subsequent construct() call can contain the
// schema data or a RowConstructor object that contains row data.
if (s_schemaCache is null)
{
s_schemaCache = new Dictionary<string, StructType>();
}

// Return a new RowConstructor where the args either represent the
// schema or the row data. The parent becomes important when calling
// GetRow() on the RowConstructor containing the row data.
//
// - When args is the schema, return a new RowConstructor where the
// parent is set to the calling RowConstructor object.
//
// - In the case where args is the row data, construct() is called on a
// RowConstructor object that contains the schema for the row data. A
// new RowConstructor is returned where the parent is set to the schema
// containing RowConstructor.
return new RowConstructor(this, args);
}

/// <summary>
/// Construct a Row object from unpickled data. This is only to be called
/// on a RowConstructor that contains the row data.
/// </summary>
/// <returns>A row object with unpickled data</returns>
public Row GetRow()
{
Debug.Assert(_parent != null);

// It is possible that an entry of a Row (row1) may itself be a Row (row2).
// If the entry is a RowConstructor then it will be a RowConstructor
// which contains the data for row2. Therefore we will call GetRow()
// on the RowConstructor to materialize row2 and replace the RowConstructor
// entry in row1.
for (int i = 0; i < _args.Length; ++i)
{
if (_args[i] is RowConstructor rowConstructor)
{
_args[i] = rowConstructor.GetRow();
}
}

return new Row(_args, _parent.GetSchema());
Debug.Assert((args != null) && (args.Length == 1) && (args[0] is string));
return new RowWithSchemaConstructor(GetSchema(s_schemaCache, (string)args[0]));
}

/// <summary>
/// Clears the schema cache. Spark sends rows in batches and for each
/// row there is an accompanying set of schemas and row entries. If the
/// schema was not cached, then it would need to be parsed and converted
/// to a StructType for every row in the batch. A new batch may contain
/// rows from a different table, so calling <c>Reset</c> after each
/// rows from a different table, so calling <see cref="Reset"/> after each
/// batch would aid in preventing the cache from growing too large.
/// Caching the schemas for each batch ensures that each schema is
/// only parsed and converted to a StructType once per batch.
@@ -119,23 +64,36 @@ internal void Reset()
s_schemaCache?.Clear();
}

/// <summary>
/// Get or cache the schema string contained in args. Calling this
/// is only valid if the child args contain the row values.
/// </summary>
/// <returns></returns>
private StructType GetSchema()
private static StructType GetSchema(IDictionary<string, StructType> schemaCache, string schemaString)
{
Debug.Assert(s_schemaCache != null);
Debug.Assert((_args != null) && (_args.Length == 1) && (_args[0] is string));
var schemaString = (string)_args[0];
if (!s_schemaCache.TryGetValue(schemaString, out StructType schema))
if (!schemaCache.TryGetValue(schemaString, out StructType schema))
{
schema = (StructType)DataType.ParseDataType(schemaString);
s_schemaCache.Add(schemaString, schema);
schemaCache.Add(schemaString, schema);
}

return schema;
}
}

/// <summary>
/// Created from <see cref="RowConstructor"/> and subsequently used
/// by the Unpickler to construct a <see cref="Row"/>.
/// </summary>
internal sealed class RowWithSchemaConstructor : IObjectConstructor
{
private readonly StructType _schema;

internal RowWithSchemaConstructor(StructType schema)
{
_schema = schema;
}

/// <summary>
/// Used by Unpickler to pass unpickled row values for handling.
/// </summary>
/// <param name="args">Unpickled row values.</param>
/// <returns>Row object.</returns>
public object construct(object[] args) => new Row(args, _schema);
}
}
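Taken together, the doc comments above describe a two-stage handshake between the Unpickler and these constructors. The following is a minimal sketch of that flow; it is a hypothetical driver (it would compile only with internal access to these types, and the schema JSON is assumed to match Spark's wire format for a single timestamp column), not code from this commit.

using Razorvine.Pickle;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;

internal static class TwoStageFlowSketch
{
    internal static Row Demo()
    {
        // Stage 1: the unpickled schema string goes to RowConstructor, which
        // parses it (or reuses the thread-local cached StructType) and returns
        // a RowWithSchemaConstructor bound to that schema.
        string schemaJson =
            "{\"type\":\"struct\",\"fields\":" +
            "[{\"name\":\"ts\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}}]}";
        var rowConstructor = new RowConstructor();
        var boundToSchema =
            (IObjectConstructor)rowConstructor.construct(new object[] { schemaJson });

        // Stage 2: the Unpickler hands the column values to the schema-bound
        // constructor, which yields the finished Row directly; there is no
        // GetRow() post-processing step left to run twice on shared values.
        return (Row)boundToSchema.construct(
            new object[] { new Timestamp(2020, 1, 1, 0, 0, 0, 0) });
    }
}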