Bug Fix for Spark 3.x - Avoid converting converted Row values #868

Merged · 10 commits · Mar 27, 2021
112 changes: 112 additions & 0 deletions src/csharp/Microsoft.Spark.E2ETest/IpcTests/Sql/RowTests.cs
@@ -0,0 +1,112 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Linq;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Xunit;
using static Microsoft.Spark.Sql.Functions;

namespace Microsoft.Spark.E2ETest.IpcTests
{
[Collection("Spark E2E Tests")]
public class RowTests
{
private readonly SparkSession _spark;

public RowTests(SparkFixture fixture)
{
_spark = fixture.Spark;
}

[Fact]
public void TestWithDuplicatedRows()
{
var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
var schema = new StructType(new StructField[]
{
new StructField("ts", new TimestampType())
});
var data = new GenericRow[]
{
new GenericRow(new object[] { timestamp })
};

DataFrame df = _spark.CreateDataFrame(data, schema);
Row[] rows = df
.WithColumn("tsRow", Struct("ts"))
.WithColumn("tsRowRow", Struct("tsRow"))
.Collect()
.ToArray();

Assert.Single(rows);

Row row = rows[0];
Assert.Equal(3, row.Values.Length);
Assert.Equal(timestamp, row.Values[0]);

Row tsRow = row.Values[1] as Row;
Assert.Single(tsRow.Values);
Assert.Equal(timestamp, tsRow.Values[0]);

Row tsRowRow = row.Values[2] as Row;
Assert.Single(tsRowRow.Values);
Assert.Equal(tsRowRow.Values[0], tsRow);
}

[Fact]
public void TestWithDuplicateTimestamps()
{
var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
var schema = new StructType(new StructField[]
{
new StructField("ts", new TimestampType())
});
var data = new GenericRow[]
{
new GenericRow(new object[] { timestamp }),
new GenericRow(new object[] { timestamp }),
new GenericRow(new object[] { timestamp })
};

DataFrame df = _spark.CreateDataFrame(data, schema);
Row[] rows = df.Collect().ToArray();

Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Single(row.Values);
Assert.Equal(timestamp, row.GetAs<Timestamp>(0));
}
}

[Fact]
public void TestWithDuplicateDates()
{
var date = new Date(2020, 1, 1);
var schema = new StructType(new StructField[]
{
new StructField("date", new DateType())
});
var data = new GenericRow[]
{
new GenericRow(new object[] { date }),
new GenericRow(new object[] { date }),
new GenericRow(new object[] { date })
};

DataFrame df = _spark.CreateDataFrame(data, schema);

Row[] rows = df.Collect().ToArray();

Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Single(row.Values);
Assert.Equal(date, row.GetAs<Date>(0));
}
}
}
}
@@ -6,7 +6,6 @@
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using Microsoft.Spark.E2ETest.Utils;
using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Streaming;
40 changes: 40 additions & 0 deletions src/csharp/Microsoft.Spark.E2ETest/UdfTests/UdfSimpleTypesTests.cs
@@ -125,6 +125,46 @@ public void TestUdfWithTimestampType()
Assert.Equal(expected, rowsToArray);
}

/// <summary>
/// UDF that returns a timestamp string.
/// </summary>
[Fact]
public void TestUdfWithDuplicateTimestamps()
{
var timestamp = new Timestamp(2020, 1, 1, 0, 0, 0, 0);
var schema = new StructType(new StructField[]
{
new StructField("ts", new TimestampType())
});
var data = new GenericRow[]
{
new GenericRow(new object[] { timestamp }),
new GenericRow(new object[] { timestamp }),
new GenericRow(new object[] { timestamp })
};

var expectedTimestamp = new Timestamp(1970, 1, 2, 0, 0, 0, 0);
string tsString = expectedTimestamp.ToString();
var returnType = new StructType(new[] { new StructField("tsString", new StringType()) });
Func<Column, Column> udf =
Udf<Row>(row => new GenericRow(new string[] { tsString }), returnType);

DataFrame df = _spark.CreateDataFrame(data, schema);
Column newCol = udf(Struct(df.Col("ts")))
.GetField("tsString")
.Cast("timestamp")
.Alias("tsStringCastToTs");

Row[] rows = df.Select(newCol).Collect().ToArray();

Assert.Equal(3, rows.Length);
foreach (Row row in rows)
{
Assert.Single(row.Values);
Assert.Equal(expectedTimestamp, row.Values[0]);
}
}

/// <summary>
/// UDF that returns Timestamp type.
/// </summary>
4 changes: 2 additions & 2 deletions src/csharp/Microsoft.Spark.UnitTest/Sql/RowTests.cs
@@ -91,8 +91,8 @@ public void RowConstructorTest()
pickledBytes.Length);

Assert.Equal(2, unpickledData.Length);
Assert.Equal(row1, (unpickledData[0] as RowConstructor).GetRow());
Assert.Equal(row2, (unpickledData[1] as RowConstructor).GetRow());
Assert.Equal(row1, unpickledData[0]);
Assert.Equal(row2, unpickledData[1]);
}

[Fact]
@@ -143,9 +143,9 @@ protected internal override CommandExecutorStat ExecuteCore(
// The following can happen if an UDF takes Row object(s).
// The JVM Spark side sends a Row object that wraps all the columns used
// in the UDF, thus, it is normalized below (the extra layer is removed).
[Review thread on the comment above]
Contributor: Is this comment still relevant?
Member (author): Worker will crash without this, so I believe it is?
Contributor: Oh, I meant with respect to the code. I think "extra layer is removed" referred to the RowConstructor, but now that it's gone, is the comment up to date?
Member (author): "Extra layer" can also refer to the Row, since we take out Values from it?
Member (author): I'm okay with removing the parentheses, though, if things sound unclear.
Member (author): @elvaliuliuliu do we need to update the description, or does it still apply?

if (row is RowConstructor rowConstructor)
if (row is Row r)
{
row = rowConstructor.GetRow().Values;
row = r.Values;
}

// Split id is not used for SQL UDFs, so 0 is passed.
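For context, the normalization discussed in the review thread above can be sketched end to end. The snippet below is an illustrative reconstruction, not the worker's actual code: the class, method, and parameter names, and the stand-in runner, are invented for the example; only Row and its Values property come from the library.

using System.Collections.Generic;
using Microsoft.Spark.Sql;

internal static class RowNormalizationSketch
{
    // Stand-in for the worker's command runner; here it simply echoes the columns.
    private static object RunUdf(int splitId, object columns) => columns;

    internal static List<object> Execute(IEnumerable<object> unpickledInputs)
    {
        var outputs = new List<object>();
        foreach (object input in unpickledInputs)
        {
            object columns = input;
            if (columns is Row r)
            {
                // A UDF that takes Row argument(s) receives one wrapping Row per
                // input; flatten it so the runner sees the raw column values.
                columns = r.Values;
            }

            // Split id is not used for SQL UDFs, so 0 is passed.
            outputs.Add(RunUdf(0, columns));
        }

        return outputs;
    }
}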
5 changes: 1 addition & 4 deletions src/csharp/Microsoft.Spark/RDD/Collector.cs
@@ -98,10 +98,7 @@ public object Deserialize(Stream stream, int length)
{
// Refer to the AutoBatchedPickler class in spark/core/src/main/scala/org/apache/
// spark/api/python/SerDeUtil.scala regarding how the Rows may be batched.
return PythonSerDe.GetUnpickledObjects(stream, length)
.Cast<RowConstructor>()
.Select(rc => rc.GetRow())
.ToArray();
return PythonSerDe.GetUnpickledObjects(stream, length).Cast<Row>().ToArray();
}
}
}
2 changes: 1 addition & 1 deletion src/csharp/Microsoft.Spark/Sql/RowCollector.cs
@@ -35,7 +35,7 @@ public IEnumerable<Row> Collect(ISocketWrapper socket)

foreach (object unpickled in unpickledObjects)
{
yield return (unpickled as RowConstructor).GetRow();
yield return unpickled as Row;
}
}
}
116 changes: 37 additions & 79 deletions src/csharp/Microsoft.Spark/Sql/RowConstructor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,94 +22,39 @@ internal sealed class RowConstructor : IObjectConstructor
/// sent per batch if there are nested rows contained in the row. Note that
/// this is thread local variable because one RowConstructor object is
/// registered to the Unpickler and there could be multiple threads unpickling
/// the data using the same object registered.
/// the data using the same registered object.
/// </summary>
[ThreadStatic]
private static IDictionary<string, StructType> s_schemaCache;

/// <summary>
/// The RowConstructor that created this instance.
/// Used by Unpickler to pass unpickled schema for handling. The Unpickler
/// will reuse the <see cref="RowConstructor"/> object when
/// it needs to start constructing a <see cref="Row"/>. The schema is passed
/// to <see cref="construct(object[])"/> and the returned
/// <see cref="IObjectConstructor"/> is used to build the rest of the <see cref="Row"/>.
/// </summary>
private readonly RowConstructor _parent;

/// <summary>
/// Stores the args passed from construct().
/// </summary>
private readonly object[] _args;

public RowConstructor() : this(null, null)
{
}

public RowConstructor(RowConstructor parent, object[] args)
{
_parent = parent;
_args = args;
}

/// <summary>
/// Used by Unpickler to pass unpickled data for handling.
/// </summary>
/// <param name="args">Unpickled data</param>
/// <returns>New RowConstructor object capturing args data</returns>
/// <param name="args">Unpickled schema</param>
/// <returns>
/// New <see cref="RowWithSchemaConstructor"/> object capturing the schema.
/// </returns>
public object construct(object[] args)
{
// Every first call to construct() contains the schema data. When
// a new RowConstructor object is returned from this function,
// construct() is called on the returned object with the actual
// row data. The original RowConstructor object may be reused by the
// Unpickler and each subsequent construct() call can contain the
// schema data or a RowConstructor object that contains row data.
if (s_schemaCache is null)
{
s_schemaCache = new Dictionary<string, StructType>();
}

// Return a new RowConstructor where the args either represent the
// schema or the row data. The parent becomes important when calling
// GetRow() on the RowConstructor containing the row data.
//
// - When args is the schema, return a new RowConstructor where the
// parent is set to the calling RowConstructor object.
//
// - In the case where args is the row data, construct() is called on a
// RowConstructor object that contains the schema for the row data. A
// new RowConstructor is returned where the parent is set to the schema
// containing RowConstructor.
return new RowConstructor(this, args);
}

/// <summary>
/// Construct a Row object from unpickled data. This is only to be called
/// on a RowConstructor that contains the row data.
/// </summary>
/// <returns>A row object with unpickled data</returns>
public Row GetRow()
{
Debug.Assert(_parent != null);

// It is possible that an entry of a Row (row1) may itself be a Row (row2).
// If the entry is a RowConstructor then it will be a RowConstructor
[Review thread on lines -91 to -92]
Contributor: I guess we already have a test case handling this, right?
Member (author): Yeah, there are a few that have rows as column values.
// which contains the data for row2. Therefore we will call GetRow()
// on the RowConstructor to materialize row2 and replace the RowConstructor
// entry in row1.
for (int i = 0; i < _args.Length; ++i)
{
if (_args[i] is RowConstructor rowConstructor)
{
_args[i] = rowConstructor.GetRow();
}
}

return new Row(_args, _parent.GetSchema());
Debug.Assert((args != null) && (args.Length == 1) && (args[0] is string));
return new RowWithSchemaConstructor(GetSchema(s_schemaCache, (string)args[0]));
}

/// <summary>
/// Clears the schema cache. Spark sends rows in batches and for each
/// row there is an accompanying set of schemas and row entries. If the
/// schema was not cached, then it would need to be parsed and converted
/// to a StructType for every row in the batch. A new batch may contain
/// rows from a different table, so calling <c>Reset</c> after each
/// rows from a different table, so calling <see cref="Reset"/> after each
/// batch would aid in preventing the cache from growing too large.
/// Caching the schemas for each batch, ensures that each schema is
/// only parsed and converted to a StructType once per batch.
@@ -119,23 +64,36 @@ internal void Reset()
s_schemaCache?.Clear();
}

/// <summary>
/// Get or cache the schema string contained in args. Calling this
/// is only valid if the child args contain the row values.
/// </summary>
/// <returns></returns>
private StructType GetSchema()
private static StructType GetSchema(IDictionary<string, StructType> schemaCache, string schemaString)
{
Debug.Assert(s_schemaCache != null);
Debug.Assert((_args != null) && (_args.Length == 1) && (_args[0] is string));
var schemaString = (string)_args[0];
if (!s_schemaCache.TryGetValue(schemaString, out StructType schema))
if (!schemaCache.TryGetValue(schemaString, out StructType schema))
{
schema = (StructType)DataType.ParseDataType(schemaString);
s_schemaCache.Add(schemaString, schema);
schemaCache.Add(schemaString, schema);
}

return schema;
}
}

/// <summary>
/// Created from <see cref="RowConstructor"/> and subsequently used
/// by the Unpickler to construct a <see cref="Row"/>.
/// </summary>
internal sealed class RowWithSchemaConstructor : IObjectConstructor
{
private readonly StructType _schema;

internal RowWithSchemaConstructor(StructType schema)
{
_schema = schema;
}

/// <summary>
/// Used by Unpickler to pass unpickled row values for handling.
/// </summary>
/// <param name="args">Unpickled row values.</param>
/// <returns>Row object.</returns>
public object construct(object[] args) => new Row(args, _schema);
}
}
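To make the new two-step construction concrete, here is a rough sketch of the sequence of calls the Unpickler ends up making. Everything in it is illustrative: both constructor classes are internal to Microsoft.Spark, the Razorvine using directive and the schema JSON string are assumptions, and in the library these calls are driven by PythonSerDe during unpickling rather than written by hand.

using Microsoft.Spark.Sql;
using Microsoft.Spark.Sql.Types;
using Razorvine.Pickle.Objects;   // assumed namespace of IObjectConstructor

// Conceptual walk-through only; the types below are internal to Microsoft.Spark.
var rowCtor = new RowConstructor();

// Step 1: the first construct() call for a row carries its schema as a string.
// The JSON below is an illustrative guess at Spark's StructType JSON form.
string schemaJson =
    "{\"type\":\"struct\",\"fields\":[" +
    "{\"name\":\"ts\",\"type\":\"timestamp\",\"nullable\":true,\"metadata\":{}}]}";
var schemaBound = (IObjectConstructor)rowCtor.construct(new object[] { schemaJson });

// Step 2: the second call passes the row values; the schema-bound constructor
// returns a finished Row, so nothing is left to unwrap afterwards.
var row = (Row)schemaBound.construct(
    new object[] { new Timestamp(2020, 1, 1, 0, 0, 0, 0) });

Compared with the old design, no intermediate RowConstructor ever holds row data, which is why callers such as Collector and RowCollector in this PR can cast unpickled objects to Row directly.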
12 changes: 0 additions & 12 deletions src/csharp/Microsoft.Spark/Sql/Types/ComplexTypes.cs
@@ -309,17 +309,5 @@ private DataType FromJson(JObject json)
fieldJObject => new StructField(fieldJObject)).ToList();
return this;
}

internal override bool NeedConversion() => true;

internal override object FromInternal(object obj)
{
if (obj is RowConstructor rowConstructor)
{
return rowConstructor.GetRow();
}

return obj;
}
}
}
2 changes: 1 addition & 1 deletion src/csharp/Microsoft.Spark/Sql/Types/DataType.cs
@@ -53,7 +53,7 @@ public abstract class DataType
/// <summary>
/// Normalized type name.
/// </summary>
public string TypeName => _typeName ?? (_typeName = NormalizeTypeName(GetType().Name));
public string TypeName => _typeName ??= NormalizeTypeName(GetType().Name);

/// <summary>
/// Simple string version of the current data type.