[SPARK-16792][SQL] Dataset containing a Case Class with a List type causes a CompileException (converting sequence to list) #16240

Closed
wants to merge 12 commits
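For context, a minimal reproduction of the issue in the title looks roughly like this (hypothetical class and value names, assuming a SparkSession with spark.implicits._ in scope); before this change, collecting such a Dataset could fail with a CompileException while converting the deserialized sequence back into a List:

import spark.implicits._

// Hypothetical case class with a List field, as described in SPARK-16792.
case class ListHolder(values: List[Int])

val ds = Seq(ListHolder(List(1, 2, 3))).toDS()
ds.collect()  // previously failed with a CompileException in the generated deserializer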
@@ -312,12 +312,46 @@ object ScalaReflection extends ScalaReflection {
"array",
ObjectType(classOf[Array[Any]]))

- StaticInvoke(
+ val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag
import scala.util.{Try, Success}
Contributor:

Spark code style discourages the usage of Try and Success; can you refactor your code a little bit? I.e. move cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]]) out of the Invoke code block.

Contributor Author:

Done. I tried looking up the code style you mentioned, but only found Databricks' Scala Code Style Guide, and that is not mentioned in the Spark docs as far as I know.
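For illustration, one way to avoid Try/Success as requested above is to hoist the reflective lookup into a plain boolean check outside the Invoke arguments (a sketch under assumed names, not necessarily the exact refactor that was committed):

import scala.reflect.ClassTag

// Sketch: determine once, without Try, whether the collection class declares a
// canBuildFrom method that takes a ClassTag parameter.
def canBuildFromTakesClassTag(cls: Class[_]): Boolean =
  cls.getDeclaredMethods.exists { m =>
    m.getName == "canBuildFrom" && m.getParameterTypes.toList == List(classOf[ClassTag[_]])
  }

The result can then be used to decide whether to pass the ClassTag argument expression, keeping the Invoke construction itself free of exception handling.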

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
Try(cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])) match {
case Success(_) =>
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
case _ => Nil
}
) :: Nil
)
}

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
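As an aside for readers of the Seq-conversion expression built above: its runtime effect for, say, a List target is roughly the following (a minimal sketch against the Scala 2.11/2.12 collections API; the values are illustrative only):

import scala.collection.mutable.WrappedArray

// 1. Wrap the extracted Array[Any] into a Seq (the StaticInvoke of WrappedArray.make).
val array: Array[Any] = Array(1, 2, 3)
val wrapped: Seq[Any] = WrappedArray.make(array)

// 2. Convert to the requested collection via `to`, which resolves the companion's
//    CanBuildFrom (the Invoke of "to" with the StaticInvoke of canBuildFrom).
val asList: List[Any] = wrapped.to[List]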
@@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}

test("SPARK 16792: Get correct deserializer for List[_]") {
val listDeserializer = deserializerFor[List[Int]]
assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
}

test("serialize and deserialize arbitrary sequence types") {
import scala.collection.immutable.Queue
val queueSerializer = serializerFor[Queue[Int]](BoundReference(
0, ObjectType(classOf[Queue[Int]]), nullable = false))
assert(queueSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val queueDeserializer = deserializerFor[Queue[Int]]
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))

import scala.collection.mutable.ArrayBuffer
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
assert(arrayBufferSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))

// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala (41 changes: 28 additions & 13 deletions)
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
* @since 1.6.0
*/
@InterfaceStability.Evolving
- abstract class SQLImplicits {
+ abstract class SQLImplicits extends LowPrioritySQLImplicits {

protected def _sqlContext: SQLContext

@@ -45,9 +45,6 @@ abstract class SQLImplicits {
}
}

- /** @since 1.6.0 */
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

// Primitives

/** @since 1.6.0 */
@@ -100,31 +97,36 @@ abstract class SQLImplicits {
// Seqs

/** @since 1.6.1 */
- implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+ implicit def newIntSeqEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+ implicit def newLongSeqEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+ implicit def newDoubleSeqEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+ implicit def newFloatSeqEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+ implicit def newByteSeqEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+ implicit def newShortSeqEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+ implicit def newBooleanSeqEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+ implicit def newStringSeqEncoder[T <: Seq[String] : TypeTag]: Encoder[T] = ExpressionEncoder()

/** @since 1.6.1 */
- implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()
+ implicit def newProductSeqEncoder[A <: Product : TypeTag, T <: Seq[A] : TypeTag]: Encoder[T] =
Contributor:

This is my only concern now. Can you provide more details about it?

Contributor Author:

This one is the same as all the other ones, just with Product subclasses. If you were concerned about the TypeTag on A, it was actually not needed, as T's tag already contains all the information. I just tested it to be sure and removed it.

ExpressionEncoder()

// Workaround for implicit resolution problem for Seq.toDS (only supports Seq)
implicit def newProductSeqOnlyEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] =
newProductSeqEncoder[A, Seq[A]]
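Returning to the review point above about the TypeTag on A (an illustrative snippet using the Scala 2.11+ reflection API, not part of the patch): for standard collection types the element type can be read off the tag of T itself, which is why a separate context bound on A is not strictly needed.

import scala.reflect.runtime.universe._

// The element type is available as a type argument of T's own tag.
val tpe = typeOf[scala.collection.immutable.Queue[(Int, String)]]
val elementType = tpe.typeArgs.head  // (Int, String)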

// Arrays

@@ -180,3 +182,16 @@
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)

}

/**
* Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
* Conflicting implicits are placed here to disambiguate resolution.
*
* Reasons for including specific implicits:
* newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
*/
trait LowPrioritySQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

}
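To illustrate why the lower priority matters (a usage sketch assuming a SparkSession named spark; not part of the patch): List[Int] is both a Seq[Int] and a Product, so it matches newIntSeqEncoder as well as newProductEncoder. Because implicits defined in SQLImplicits itself take precedence over those inherited from LowPrioritySQLImplicits, the Seq encoder wins and the call resolves without ambiguity:

import spark.implicits._

// Compiles unambiguously: newIntSeqEncoder (defined in SQLImplicits) is preferred
// over newProductEncoder (inherited from the lower-priority trait).
val ds = Seq(List(1, 2, 3)).toDS()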
@@ -130,6 +130,34 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}

test("arbitrary sequences") {
Contributor:

Let's also test nested sequences, e.g. List(Queue(1)), and sequences inside products, e.g. List(1) -> Queue(1).

Contributor Author (@michalsenkyr, Jan 3, 2017):

I added some sequence-product combination tests.
Nested sequences were never supported (tried on master and 2.0.2). That would probably be worthy of another ticket.

import scala.collection.immutable.Queue
checkDataset(Seq(Queue(1)).toDS(), Queue(1))
checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
checkDataset(Seq(Queue(true)).toDS(), Queue(true))
checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
// Implicit resolution problem - encoder needs to be provided explicitly
implicit val queueEncoder = newProductSeqEncoder[Tuple1[Int], Queue[Tuple1[Int]]]
checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))

import scala.collection.mutable.ArrayBuffer
checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
// Implicit resolution problem - encoder needs to be provided explicitly
implicit val arrayBufferEncoder = newProductSeqEncoder[Tuple1[Int], ArrayBuffer[Tuple1[Int]]]
checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
}
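A sketch of the sequence-inside-product combination requested in the review above (illustrative only, written as it would appear inside this suite where checkDataset and the implicits are in scope; the combination tests the author refers to are not shown in this excerpt):

import scala.collection.immutable.Queue
import scala.collection.mutable.ArrayBuffer

// Sequences nested inside products (a Tuple2 and a Tuple1).
checkDataset(Seq(List(1) -> Queue(1)).toDS(), List(1) -> Queue(1))
checkDataset(Seq(Tuple1(ArrayBuffer(1, 2))).toDS(), Tuple1(ArrayBuffer(1, 2)))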

test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))