
[SPARK-16792][SQL] Dataset containing a Case Class with a List type causes a CompileException (converting sequence to list) #16240

Closed
wants to merge 12 commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -312,12 +312,50 @@ object ScalaReflection extends ScalaReflection {
"array",
ObjectType(classOf[Array[Any]]))

- StaticInvoke(
+ val wrappedArray = StaticInvoke(
scala.collection.mutable.WrappedArray.getClass,
ObjectType(classOf[Seq[_]]),
"make",
array :: Nil)

if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
wrappedArray
} else {
// Convert to another type using `to`
val cls = mirror.runtimeClass(t.typeSymbol.asClass)
import scala.collection.generic.CanBuildFrom
import scala.reflect.ClassTag

// Some canBuildFrom methods take an implicit ClassTag parameter
val cbfParams = try {
cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
StaticInvoke(
ClassTag.getClass,
ObjectType(classOf[ClassTag[_]]),
"apply",
StaticInvoke(
cls,
ObjectType(classOf[Class[_]]),
"getClass"
) :: Nil
) :: Nil
} catch {
case _: NoSuchMethodException => Nil
}

Invoke(
wrappedArray,
"to",
ObjectType(cls),
StaticInvoke(
cls,
ObjectType(classOf[CanBuildFrom[_, _, _]]),
"canBuildFrom",
cbfParams
) :: Nil
)
}

case t if t <:< localTypeOf[Map[_, _]] =>
// TODO: add walked type path for map
val TypeRef(_, _, Seq(keyType, valueType)) = t
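For intuition, the expression tree built above amounts to roughly the following at runtime: a minimal sketch against the Scala 2.11/2.12 collections API, where `elements` is a hypothetical stand-in for the decoded array.

```scala
import scala.collection.mutable.WrappedArray

// Hypothetical stand-in for the decoded element array.
val elements: Array[Int] = Array(1, 2, 3)

// The `wrappedArray` expression: wrap the array without copying.
val wrapped: WrappedArray[Int] = WrappedArray.make(elements)

// The `Invoke` of "to": convert to the requested collection using its
// companion's CanBuildFrom (here, List for a List[Int] target).
val asList: List[Int] = wrapped.to(List.canBuildFrom[Int])

// Companions such as WrappedArray's declare canBuildFrom with an implicit
// ClassTag parameter; the `cbfParams` branch above supplies that argument.
```

When the requested type is already a supertype of WrappedArray (such as Seq), the first branch skips the conversion entirely, which the ScalaReflectionSuite test below verifies.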
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -291,6 +291,37 @@ class ScalaReflectionSuite extends SparkFunSuite {
.cls.isAssignableFrom(classOf[org.apache.spark.sql.catalyst.util.GenericArrayData]))
}

test("SPARK 16792: Get correct deserializer for List[_]") {
val listDeserializer = deserializerFor[List[Int]]
assert(listDeserializer.dataType == ObjectType(classOf[List[_]]))
}

test("serialize and deserialize arbitrary sequence types") {
import scala.collection.immutable.Queue
val queueSerializer = serializerFor[Queue[Int]](BoundReference(
0, ObjectType(classOf[Queue[Int]]), nullable = false))
assert(queueSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val queueDeserializer = deserializerFor[Queue[Int]]
assert(queueDeserializer.dataType == ObjectType(classOf[Queue[_]]))

import scala.collection.mutable.ArrayBuffer
val arrayBufferSerializer = serializerFor[ArrayBuffer[Int]](BoundReference(
0, ObjectType(classOf[ArrayBuffer[Int]]), nullable = false))
assert(arrayBufferSerializer.dataType.head.dataType ==
ArrayType(IntegerType, containsNull = false))
val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))

// Check whether conversion is skipped when using WrappedArray[_] supertype
// (would otherwise needlessly add overhead)
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
val seqDeserializer = deserializerFor[Seq[Int]]
assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
scala.collection.mutable.WrappedArray.getClass)
assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
}

private val dataTypeForComplexData = dataTypeFor[ComplexData]
private val typeOfComplexData = typeOf[ComplexData]

sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala (94 additions, 21 deletions)
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
* @since 1.6.0
*/
@InterfaceStability.Evolving
- abstract class SQLImplicits {
+ abstract class SQLImplicits extends LowPrioritySQLImplicits {

protected def _sqlContext: SQLContext

@@ -45,9 +45,6 @@ abstract class SQLImplicits {
}
}

- /** @since 1.6.0 */
- implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

// Primitives

/** @since 1.6.0 */
@@ -99,33 +96,96 @@ abstract class SQLImplicits {

// Seqs

- /** @since 1.6.1 */
- implicit def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newIntSequenceEncoder]]
+  */
+ def newIntSeqEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
Contributor: oh you don't need to do this, just update project/MimaExcludes.scala to add something like ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.newDoubleSeqEncoder"). Please see the history of MimaExcludes.scala to see how others update this file.

Contributor: Wait, I'm not sure I agree... Do we want to break binary compatibility for libraries that might be using this function? That could have even been resolved implicitly, so it would be confusing when it breaks.

Contributor: ah, I see, makes sense
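
For reference, the exclusion described in the first comment would look roughly like the sketch below in project/MimaExcludes.scala. This illustrates the suggestion only; the thread settles on keeping the deprecated methods instead, precisely to preserve binary compatibility.

```scala
import com.typesafe.tools.mima.core._

// Sketch only: this shape of entry was suggested and then deemed unnecessary,
// since the old encoder methods are kept (deprecated) for binary compatibility.
object SuggestedMimaExcludes {
  lazy val excludes = Seq(
    ProblemFilters.exclude[DirectMissingMethodProblem](
      "org.apache.spark.sql.SQLImplicits.newDoubleSeqEncoder")
  )
}
```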


- /** @since 1.6.1 */
- implicit def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newLongSequenceEncoder]]
+  */
+ def newLongSeqEncoder: Encoder[Seq[Long]] = ExpressionEncoder()

- /** @since 1.6.1 */
- implicit def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newDoubleSequenceEncoder]]
+  */
+ def newDoubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()

- /** @since 1.6.1 */
- implicit def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newFloatSequenceEncoder]]
+  */
+ def newFloatSeqEncoder: Encoder[Seq[Float]] = ExpressionEncoder()

- /** @since 1.6.1 */
- implicit def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newByteSequenceEncoder]]
+  */
+ def newByteSeqEncoder: Encoder[Seq[Byte]] = ExpressionEncoder()

- /** @since 1.6.1 */
- implicit def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newShortSequenceEncoder]]
+  */
+ def newShortSeqEncoder: Encoder[Seq[Short]] = ExpressionEncoder()

- /** @since 1.6.1 */
- implicit def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newBooleanSequenceEncoder]]
+  */
+ def newBooleanSeqEncoder: Encoder[Seq[Boolean]] = ExpressionEncoder()

- /** @since 1.6.1 */
- implicit def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newStringSequenceEncoder]]
+  */
+ def newStringSeqEncoder: Encoder[Seq[String]] = ExpressionEncoder()

- /** @since 1.6.1 */
+ /**
+  * @since 1.6.1
+  * @deprecated use [[newProductSequenceEncoder]]
+  */
implicit def newProductSeqEncoder[A <: Product : TypeTag]: Encoder[Seq[A]] = ExpressionEncoder()

/** @since 2.2.0 */
implicit def newIntSequenceEncoder[T <: Seq[Int] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newLongSequenceEncoder[T <: Seq[Long] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newDoubleSequenceEncoder[T <: Seq[Double] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newFloatSequenceEncoder[T <: Seq[Float] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newByteSequenceEncoder[T <: Seq[Byte] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newShortSequenceEncoder[T <: Seq[Short] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newBooleanSequenceEncoder[T <: Seq[Boolean] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newStringSequenceEncoder[T <: Seq[String] : TypeTag]: Encoder[T] =
ExpressionEncoder()

/** @since 2.2.0 */
implicit def newProductSequenceEncoder[T <: Seq[Product] : TypeTag]: Encoder[T] =
ExpressionEncoder()

// Arrays

/** @since 1.6.1 */
@@ -180,3 +240,16 @@ abstract class SQLImplicits {
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)

}
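
A usage sketch for the new encoders, assuming a live SparkSession bound to `spark`: each Seq subtype now resolves its own encoder, so the concrete collection type survives the round trip.

```scala
import scala.collection.immutable.Queue
import spark.implicits._  // `spark` is an assumed SparkSession

// Resolves newIntSequenceEncoder[List[Int]]: a Dataset[List[Int]], not Dataset[Seq[Int]].
val listDS = Seq(List(1, 2, 3)).toDS()

// Resolves newStringSequenceEncoder[Queue[String]].
val queueDS = Seq(Queue("a", "b")).toDS()

listDS.collect()  // Array(List(1, 2, 3)): still a List after deserialization
```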

/**
* Lower priority implicit methods for converting Scala objects into [[Dataset]]s.
* Conflicting implicits are placed here to disambiguate resolution.
*
* Reasons for including specific implicits:
* newProductEncoder - to disambiguate for [[List]]s which are both [[Seq]] and [[Product]]
*/
trait LowPrioritySQLImplicits {
/** @since 1.6.0 */
implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T]

}
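
The trait split works because implicit resolution prefers a definition in a subclass over one inherited from a supertrait. A self-contained toy sketch of the mechanism (illustrative names, not Spark API):

```scala
trait Enc[T] { def name: String }

trait LowPriority {
  // Analogue of newProductEncoder: inherited, so lower priority.
  implicit def fromProduct[T <: Product]: Enc[T] = new Enc[T] { val name = "product" }
}

object Impls extends LowPriority {
  // Analogue of the Seq encoders: defined in the subclass, so preferred.
  implicit def fromIntSeq[T <: Seq[Int]]: Enc[T] = new Enc[T] { val name = "seq" }
}

object Demo extends App {
  import Impls._
  // List[Int] is both a Seq[Int] and a Product, so both implicits apply;
  // the subclass definition wins instead of causing an ambiguity error.
  println(implicitly[Enc[List[Int]]].name)  // prints "seq"
}
```

Had newProductEncoder stayed in SQLImplicits alongside the new Seq encoders, `Seq(List(1)).toDS()` would have hit an "ambiguous implicit values" error.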
sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -17,10 +17,21 @@

package org.apache.spark.sql

import scala.collection.immutable.Queue
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.test.SharedSQLContext

case class IntClass(value: Int)

case class SeqClass(s: Seq[Int])

case class ListClass(l: List[Int])

case class QueueClass(q: Queue[Int])

case class ComplexClass(seq: SeqClass, list: ListClass, queue: QueueClass)

package object packageobject {
case class PackageClass(value: Int)
}
@@ -130,6 +141,62 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
checkDataset(Seq(Array(Tuple1(1))).toDS(), Array(Tuple1(1)))
}

test("arbitrary sequences") {
Contributor: let's also test nested sequences, e.g. List(Queue(1)), and sequences inside product, e.g. List(1) -> Queue(1)

Contributor Author (@michalsenkyr, Jan 3, 2017): I added some sequence-product combination tests. Nested sequences were never supported (tried on master and 2.0.2). That would probably be worthy of another ticket.
checkDataset(Seq(Queue(1)).toDS(), Queue(1))
checkDataset(Seq(Queue(1.toLong)).toDS(), Queue(1.toLong))
checkDataset(Seq(Queue(1.toDouble)).toDS(), Queue(1.toDouble))
checkDataset(Seq(Queue(1.toFloat)).toDS(), Queue(1.toFloat))
checkDataset(Seq(Queue(1.toByte)).toDS(), Queue(1.toByte))
checkDataset(Seq(Queue(1.toShort)).toDS(), Queue(1.toShort))
checkDataset(Seq(Queue(true)).toDS(), Queue(true))
checkDataset(Seq(Queue("test")).toDS(), Queue("test"))
checkDataset(Seq(Queue(Tuple1(1))).toDS(), Queue(Tuple1(1)))

checkDataset(Seq(ArrayBuffer(1)).toDS(), ArrayBuffer(1))
checkDataset(Seq(ArrayBuffer(1.toLong)).toDS(), ArrayBuffer(1.toLong))
checkDataset(Seq(ArrayBuffer(1.toDouble)).toDS(), ArrayBuffer(1.toDouble))
checkDataset(Seq(ArrayBuffer(1.toFloat)).toDS(), ArrayBuffer(1.toFloat))
checkDataset(Seq(ArrayBuffer(1.toByte)).toDS(), ArrayBuffer(1.toByte))
checkDataset(Seq(ArrayBuffer(1.toShort)).toDS(), ArrayBuffer(1.toShort))
checkDataset(Seq(ArrayBuffer(true)).toDS(), ArrayBuffer(true))
checkDataset(Seq(ArrayBuffer("test")).toDS(), ArrayBuffer("test"))
checkDataset(Seq(ArrayBuffer(Tuple1(1))).toDS(), ArrayBuffer(Tuple1(1)))
}

test("sequence and product combinations") {
// Case classes
checkDataset(Seq(SeqClass(Seq(1))).toDS(), SeqClass(Seq(1)))
checkDataset(Seq(Seq(SeqClass(Seq(1)))).toDS(), Seq(SeqClass(Seq(1))))
checkDataset(Seq(List(SeqClass(Seq(1)))).toDS(), List(SeqClass(Seq(1))))
checkDataset(Seq(Queue(SeqClass(Seq(1)))).toDS(), Queue(SeqClass(Seq(1))))

checkDataset(Seq(ListClass(List(1))).toDS(), ListClass(List(1)))
checkDataset(Seq(Seq(ListClass(List(1)))).toDS(), Seq(ListClass(List(1))))
checkDataset(Seq(List(ListClass(List(1)))).toDS(), List(ListClass(List(1))))
checkDataset(Seq(Queue(ListClass(List(1)))).toDS(), Queue(ListClass(List(1))))

checkDataset(Seq(QueueClass(Queue(1))).toDS(), QueueClass(Queue(1)))
checkDataset(Seq(Seq(QueueClass(Queue(1)))).toDS(), Seq(QueueClass(Queue(1))))
checkDataset(Seq(List(QueueClass(Queue(1)))).toDS(), List(QueueClass(Queue(1))))
checkDataset(Seq(Queue(QueueClass(Queue(1)))).toDS(), Queue(QueueClass(Queue(1))))

val complex = ComplexClass(SeqClass(Seq(1)), ListClass(List(2)), QueueClass(Queue(3)))
checkDataset(Seq(complex).toDS(), complex)
checkDataset(Seq(Seq(complex)).toDS(), Seq(complex))
checkDataset(Seq(List(complex)).toDS(), List(complex))
checkDataset(Seq(Queue(complex)).toDS(), Queue(complex))

// Tuples
checkDataset(Seq(Seq(1) -> Seq(2)).toDS(), Seq(1) -> Seq(2))
checkDataset(Seq(List(1) -> Queue(2)).toDS(), List(1) -> Queue(2))
checkDataset(Seq(List(Seq("test1") -> List(Queue("test2")))).toDS(),
List(Seq("test1") -> List(Queue("test2"))))

// Complex
checkDataset(Seq(ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2)))).toDS(),
ListClass(List(1)) -> Queue("test" -> SeqClass(Seq(2))))
}

test("package objects") {
import packageobject._
checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1))