diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 920490f9d0d61..d2cf9efceef25 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -103,7 +103,7 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends Deser kryo.readClassAndObject(input).asInstanceOf[T] } catch { // DeserializationStream uses the EOF exception to indicate stopping condition. - case _: KryoException => throw new EOFException + case e: KryoException if e.getMessage == "Buffer underflow." => throw new EOFException } } diff --git a/core/src/main/scala/org/apache/spark/util/MutablePair.scala b/core/src/main/scala/org/apache/spark/util/MutablePair.scala index b053266f12748..e317882f6e856 100644 --- a/core/src/main/scala/org/apache/spark/util/MutablePair.scala +++ b/core/src/main/scala/org/apache/spark/util/MutablePair.scala @@ -25,10 +25,20 @@ package org.apache.spark.util * @param _2 Element 2 of this MutablePair */ case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1, - @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] + @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2] (var _1: T1, var _2: T2) extends Product2[T1, T2] { + /** No-arg constructor for serialization */ + def this() = this(null.asInstanceOf[T1], null.asInstanceOf[T2]) + + /** Updates this pair with new values and returns itself */ + def apply(n1: T1, n2: T2): MutablePair[T1, T2] = { + _1 = n1 + _2 = n2 + this + } + override def toString = "(" + _1 + "," + _2 + ")" override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a07045fed79f9..78aaaeebbd631 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -30,6 +30,7 @@ abstract class Expression extends TreeNode[Expression] { type EvaluatedType <: Any def dataType: DataType + /** * Returns true when an expression is a candidate for static evaluation before the query is * executed. @@ -53,14 +54,6 @@ abstract class Expression extends TreeNode[Expression] { def apply(input: Row = null): EvaluatedType = throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - // Primitive Accessor functions that avoid boxing for performance. - // Note this is an Unstable API as it doesn't correctly handle null values yet. - - def applyBoolean(input: Row): Boolean = apply(input).asInstanceOf[Boolean] - def applyInt(input: Row): Int = apply(input).asInstanceOf[Int] - def applyDouble(input: Row): Double = apply(input).asInstanceOf[Double] - def applyString(input: Row): String = apply(input).asInstanceOf[String] - /** * Returns `true` if this expression and all its children have been resolved to a specific schema * and `false` if it is still contains any unresolved placeholders. 
Implementations of expressions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 4eea200c23a1a..ae04be2fb0d1c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -2,10 +2,9 @@ package org.apache.spark.sql.catalyst package expressions /** - * Converts a Row to another Row given a set of expressions. - * - * If the schema of the input row is specified, then the given expression will be bound to that - * schema. + * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound to + * that schema. */ class Projection(expressions: Seq[Expression]) extends (Row => Row) { def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = @@ -23,6 +22,33 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) { } } +/** + * Converts a [[Row]] to another Row given a sequence of expressions that define each column of the + * new row. If the schema of the input row is specified, then the given expression will be bound to + * that schema. + * + * In contrast to a normal projection, a MutableProjection reuses the same underlying row object + * each time an input row is added. This significantly reduces the cost of calculating the + * projection, but means that it is not safe to hold on to a reference to a [[Row]] after `next()` + * has been called on the [[Iterator]] that produced it. + */ +case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) { + def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = + this(expressions.map(BindReferences.bindReference(_, inputSchema))) + + private[this] val exprArray = expressions.toArray + private[this] val mutableRow = new GenericMutableRow(exprArray.size) + def currentValue: Row = mutableRow + + def apply(input: Row): Row = { + var i = 0 + while (i < exprArray.size) { + mutableRow(i) = exprArray(i).apply(input) + i += 1 + } + mutableRow + } +} + /** * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. @@ -68,4 +94,17 @@ class JoinedRow extends Row { def getFloat(i: Int): Float = if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size) + def getString(i: Int): String = + if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size) + + def copy() = { + val totalSize = row1.size + row2.size + val copiedValues = new Array[Any](totalSize) + var i = 0 + while (i < totalSize) { + copiedValues(i) = apply(i) + i += 1 + } + new GenericRow(copiedValues) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 6bb0decc0e1cf..6d12e6c606373 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -28,7 +28,7 @@ import types._ * It is invalid to use the native primitive interface to retrieve a value that is null, instead a * user must check [[isNullAt]] before attempting to retrieve a value that might be null.
*/ -abstract class Row extends Seq[Any] with Serializable { +trait Row extends Seq[Any] with Serializable { def apply(i: Int): Any def isNullAt(i: Int): Boolean @@ -40,9 +40,39 @@ abstract class Row extends Seq[Any] with Serializable { def getBoolean(i: Int): Boolean def getShort(i: Int): Short def getByte(i: Int): Byte + def getString(i: Int): String override def toString() = s"[${this.mkString(",")}]" + + def copy(): Row +} + +/** + * An extended interface to [[Row]] that allows the values for each column to be updated. Setting + * a value through a primitive function implicitly marks that column as not null. + */ +trait MutableRow extends Row { + def setNullAt(i: Int): Unit + + def update(ordinal: Int, value: Any) + + def setInt(ordinal: Int, value: Int) + def setLong(ordinal: Int, value: Long) + def setDouble(ordinal: Int, value: Double) + def setBoolean(ordinal: Int, value: Boolean) + def setShort(ordinal: Int, value: Short) + def setByte(ordinal: Int, value: Byte) + def setFloat(ordinal: Int, value: Float) + + /** + * EXPERIMENTAL + * + * Returns a mutable string builder for the specified column. A given row should return the + * result of any mutations made to the returned buffer next time getString is called for the same + * column. + */ + def getStringBuilder(ordinal: Int): StringBuilder } /** @@ -62,12 +92,22 @@ object EmptyRow extends Row { def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException def getShort(i: Int): Short = throw new UnsupportedOperationException def getByte(i: Int): Byte = throw new UnsupportedOperationException + def getString(i: Int): String = throw new UnsupportedOperationException + + def copy() = this } /** - * A row implementation that uses an array of objects as the underlying storage. + * A row implementation that uses an array of objects as the underlying storage. Note that, while + * the array is not copied and thus could technically be mutated after creation, this is not + * allowed. */ -class GenericRow(val values: Array[Any]) extends Row { +class GenericRow(protected[catalyst] val values: Array[Any]) extends Row { + /** No-arg constructor for serialization.
*/ + def this() = this(null) + + def this(size: Int) = this(new Array[Any](size)) + def iterator = values.iterator def length = values.length @@ -80,32 +120,68 @@ class GenericRow(val values: Array[Any]) extends Row { if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") values(i).asInstanceOf[Int] } + def getLong(i: Int): Long = { if (values(i) == null) sys.error("Failed to check null bit for primitive long value.") values(i).asInstanceOf[Long] } + def getDouble(i: Int): Double = { if (values(i) == null) sys.error("Failed to check null bit for primitive double value.") values(i).asInstanceOf[Double] } + def getFloat(i: Int): Float = { if (values(i) == null) sys.error("Failed to check null bit for primitive float value.") values(i).asInstanceOf[Float] } + def getBoolean(i: Int): Boolean = { if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.") values(i).asInstanceOf[Boolean] } + def getShort(i: Int): Short = { if (values(i) == null) sys.error("Failed to check null bit for primitive short value.") values(i).asInstanceOf[Short] } + def getByte(i: Int): Byte = { if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.") values(i).asInstanceOf[Byte] } + + def getString(i: Int): String = { + if (values(i) == null) sys.error("Failed to check null bit for string value.") + values(i).asInstanceOf[String] + } + + def copy() = this } +class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow { + /** No-arg constructor for serialization. */ + def this() = this(0) + + def getStringBuilder(ordinal: Int): StringBuilder = ??? + + override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } + override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } + override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } + override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } + override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } + override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } + + override def setNullAt(i: Int): Unit = { values(i) = null } + + override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } + + override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } + + override def copy() = new GenericRow(values.clone()) +} + + class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] { def compare(a: Row, b: Row): Int = { var i = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 010064489c421..2287a849e6831 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -170,12 +170,15 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) def this() = this(null, null) // Required for serialization.
private var count: Long = _ - private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(null)) + private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(EmptyRow)) private val sumAsDouble = Cast(sum, DoubleType) + + private val addFunction = Add(sum, expr) - override def apply(input: Row): Any = sumAsDouble.applyDouble(null) / count.toDouble + override def apply(input: Row): Any = + sumAsDouble.apply(EmptyRow).asInstanceOf[Double] / count.toDouble def update(input: Row): Unit = { count += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSqlContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSqlContext.scala index f25de3827855c..edf912bd8fa92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSqlContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSqlContext.scala @@ -112,7 +112,7 @@ class SparkSqlContext(val sparkContext: SparkContext) extends Logging { lazy val executedPlan: SparkPlan = PrepareForExecution(sparkPlan) // TODO: We are loosing schema here. - lazy val toRdd: RDD[Row] = executedPlan.execute() + lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy()) protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 4226a1a85509c..e7ed1a5e1658d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -18,17 +18,52 @@ package org.apache.spark.sql package execution +import java.nio.ByteBuffer + +import com.esotericsoftware.kryo.{Kryo, Serializer} +import com.esotericsoftware.kryo.io.{Output, Input} + +import org.apache.spark.{SparkConf, RangePartitioner, HashPartitioner} +import org.apache.spark.rdd.ShuffledRDD +import org.apache.spark.serializer.KryoSerializer +import org.apache.spark.util.MutablePair + import catalyst.rules.Rule import catalyst.errors._ import catalyst.expressions._ import catalyst.plans.physical._ -import org.apache.spark.{RangePartitioner, HashPartitioner} -import org.apache.spark.rdd.ShuffledRDD +class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { + override def newKryo(): Kryo = { + val kryo = new Kryo + kryo.setRegistrationRequired(true) + kryo.register(classOf[MutablePair[_,_]]) + kryo.register(classOf[Array[Any]]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) + kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]]) + kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer) + kryo.setReferences(false) + kryo.setClassLoader(this.getClass.getClassLoader) + kryo + } +} + +class BigDecimalSerializer extends Serializer[BigDecimal] { + def write(kryo: Kryo, output: Output, bd: math.BigDecimal) { + // TODO: There are probably more efficient representations than strings... 
+ output.writeString(bd.toString) + } + + def read(kryo: Kryo, input: Input, tpe: Class[BigDecimal]): BigDecimal = { + BigDecimal(input.readString()) + } +} case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode { override def outputPartitioning = newPartitioning + def output = child.output def execute() = attachTree(this , "execute") { @@ -36,21 +71,26 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una case HashPartitioning(expressions, numPartitions) => { // TODO: Eliminate redundant expressions in grouping key and value. val rdd = child.execute().mapPartitions { iter => - val hashExpressions = new Projection(expressions) - iter.map(r => (hashExpressions(r), r)) + val hashExpressions = new MutableProjection(expressions) + val mutablePair = new MutablePair[Row, Row]() + iter.map(r => mutablePair(hashExpressions(r), r)) } val part = new HashPartitioner(numPartitions) - val shuffled = new ShuffledRDD[Row, Row, (Row, Row)](rdd, part) - + val shuffled = new ShuffledRDD[Row, Row, MutablePair[Row, Row]](rdd, part) + shuffled.setSerializer(classOf[SparkSqlSerializer].getName) shuffled.map(_._2) } case RangePartitioning(sortingExpressions, numPartitions) => { - // TODO: ShuffledRDD should take an Ordering. + // TODO: RangePartitioner should take an Ordering. implicit val ordering = new RowOrdering(sortingExpressions) - val rdd = child.execute().map(row => (row, null)) + val rdd = child.execute().mapPartitions { iter => + val mutablePair = new MutablePair[Row, Null](null, null) + iter.map(row => mutablePair(row, null)) + } val part = new RangePartitioner(numPartitions, rdd, ascending = true) - val shuffled = new ShuffledRDD[Row, Null, (Row, Null)](rdd, part) + val shuffled = new ShuffledRDD[Row, Null, MutablePair[Row, Null]](rdd, part) + shuffled.setSerializer(classOf[SparkSqlSerializer].getName) shuffled.map(_._1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 6c10a537f225e..828618ebd4d4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -154,7 +154,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } protected lazy val singleRowRdd = - sparkContext.parallelize(Seq(new GenericRow(Array()): Row), 1) + sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) def convertToCatalyst(a: Any): Any = a match { case s: Seq[Any] => s.map(convertToCatalyst) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala index 132576cbc4fae..51889c1988680 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregates.scala @@ -90,7 +90,7 @@ case class Aggregate( // in the [[catalyst.execution.Exchange]]. 
val grouped = child.execute().mapPartitions { iter => val buildGrouping = new Projection(groupingExpressions) - iter.map(row => (buildGrouping(row), row)) + iter.map(row => (buildGrouping(row), row.copy())) }.groupByKeyLocally() val result = grouped.map { case (group, rows) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 9a61c021fceea..40d3ef8128756 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -30,15 +30,16 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends def output = projectList.map(_.toAttribute) def execute() = child.execute().mapPartitions { iter => - val buildProjection = new Projection(projectList) - iter.map(buildProjection) + @transient val reusableProjection = new MutableProjection(projectList) + iter.map(reusableProjection) } } case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { def output = child.output + def execute() = child.execute().mapPartitions { iter => - iter.filter(condition.applyBoolean) + iter.filter(condition.apply(_).asInstanceOf[Boolean]) } } @@ -64,7 +65,7 @@ case class StopAfter(limit: Int, child: SparkPlan)(@transient sc: SparkContext) def output = child.output - override def executeCollect() = child.execute().take(limit) + override def executeCollect() = child.execute().map(_.copy()).take(limit) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. @@ -80,7 +81,7 @@ case class TopK(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) @transient lazy val ordering = new RowOrdering(sortOrder) - override def executeCollect() = child.execute().takeOrdered(limit)(ordering) + override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ordering) // TODO: Terminal split should be implemented differently from non-terminal split. // TODO: Pick num splits based on |limit|. @@ -103,7 +104,7 @@ case class Sort( // TODO: Optimize sorting operation? child.execute() .mapPartitions( - iterator => iterator.toArray.sorted(ordering).iterator, + iterator => iterator.map(_.copy()).toArray.sorted(ordering).iterator, preservesPartitioning = true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index d44905dc84260..5934fd1b03bfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -46,12 +46,12 @@ case class SparkEquiInnerJoin( def execute() = attachTree(this, "execute") { val leftWithKeys = left.execute().mapPartitions { iter => val generateLeftKeys = new Projection(leftKeys, left.output) - iter.map(row => (generateLeftKeys(row), row)) + iter.map(row => (generateLeftKeys(row), row.copy())) } val rightWithKeys = right.execute().mapPartitions { iter => val generateRightKeys = new Projection(rightKeys, right.output) - iter.map(row => (generateRightKeys(row), row)) + iter.map(row => (generateRightKeys(row), row.copy())) } // Do the join.
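The copy() calls added above and in the join below all follow from the same reuse contract: MutableProjection and GenericMutableRow hand back the same underlying object on every call, so any operator that buffers, sorts, collects, or shuffles rows must snapshot them first. The sketch below is not part of the patch; it is a minimal illustration of that contract using only classes introduced in this diff, with made-up values.

import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}

val reused = new GenericMutableRow(1)

// Buffering the reused row directly keeps three references to one object,
// so every element observes the last value written: 3, 3, 3.
val aliased: Seq[Row] = (1 to 3).map { i => reused.setInt(0, i); reused }

// copy() materializes an immutable GenericRow snapshot, so the buffered
// values stay 1, 2, 3. This is exactly why Aggregate, Sort, StopAfter,
// TopK and the joins now call copy() before holding on to a row.
val snapshots: Seq[Row] = (1 to 3).map { i => reused.setInt(0, i); reused.copy() }

The MutablePair used in Exchange follows the same convention: mutablePair(key, row) updates the pair in place and returns it, which is safe there because the shuffle serializes each record as soon as it is produced.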
@@ -73,7 +73,7 @@ case class SparkEquiInnerJoin( case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { def output = left.output ++ right.output - def execute() = left.execute().cartesian(right.execute()).map { + def execute() = left.execute().map(_.copy()).cartesian(right.execute().map(_.copy())).map { case (l: Row, r: Row) => buildRow(l ++ r) } } @@ -95,17 +95,19 @@ case class BroadcastNestedLoopJoin( /** The Broadcast relation */ def right = broadcast + @transient lazy val boundCondition = + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true)) + + def execute() = { - val broadcastedRelation = sc.broadcast(broadcast.execute().collect().toIndexedSeq) + val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter => val matchedRows = new mutable.ArrayBuffer[Row] val includedBroadcastTuples = new mutable.BitSet(broadcastedRelation.value.size) val joinedRow = new JoinedRow - val boundCondition = - condition - .map(c => BindReferences.bindReference(c, left.output ++ right.output)) - .getOrElse(Literal(true)) streamedIter.foreach { streamedRow => var i = 0 @@ -114,7 +116,7 @@ case class BroadcastNestedLoopJoin( while (i < broadcastedRelation.value.size) { // TODO: One bitset per partition instead of per row. val broadcastedRow = broadcastedRelation.value(i) - if (boundCondition.applyBoolean(joinedRow(streamedRow, broadcastedRow))) { + if (boundCondition(joinedRow(streamedRow, broadcastedRow)).asInstanceOf[Boolean]) { matchedRows += buildRow(streamedRow ++ broadcastedRow) matched = true includedBroadcastTuples += i diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala index df1a95c69c5bd..c790243082825 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TgfSuite.scala @@ -45,11 +45,11 @@ case class ExampleTGF(input: Seq[Attribute] = Seq('name, 'age)) extends Generato override def apply(input: Row): TraversableOnce[Row] = { val name = nameAttr.apply(input) - val age = ageAttr.applyInt(input) + val age = ageAttr.apply(input).asInstanceOf[Int] Iterator( - new GenericRow(Array(s"$name is $age years old")), - new GenericRow(Array(s"Next year, $name will be ${age + 1} years old"))) + new GenericRow(Array[Any](s"$name is $age years old")), + new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old"))) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index e269c8c846700..3e8e491e8758a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -205,7 +205,7 @@ abstract class HiveContext(sc: SparkContext) extends SparkSqlContext(sc) { override val planner = HivePlanner - protected lazy val emptyResult = sparkContext.parallelize(Seq(new GenericRow(Array()): Row), 1) + protected lazy val emptyResult = sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) /** Extends QueryExecution with hive specific features. 
*/ abstract class QueryExecution extends super.QueryExecution { @@ -226,7 +226,7 @@ abstract class HiveContext(sc: SparkContext) extends SparkSqlContext(sc) { sparkContext.parallelize(asRows, 1) } case _ => - executedPlan.execute() + executedPlan.execute.map(_.copy()) } protected val primitiveTypes = diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala index d1ae4c4111159..d20fd87f34f48 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala @@ -141,7 +141,7 @@ case class HiveTableScan( // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. val row = new GenericRow(castedValues.toArray) - shouldKeep.applyBoolean(row) + shouldKeep.apply(row).asInstanceOf[Boolean] } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index c04bbe8d41785..5e775d6a048de 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -80,14 +80,14 @@ object HiveFunctionRegistry case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType case c: Class[_] if c == classOf[java.lang.String] => StringType case c: Class[_] if c == java.lang.Short.TYPE => ShortType - case c: Class[_] if c == java.lang.Integer.TYPE => ShortType + case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType case c: Class[_] if c == java.lang.Long.TYPE => LongType case c: Class[_] if c == java.lang.Double.TYPE => DoubleType case c: Class[_] if c == java.lang.Byte.TYPE => ByteType case c: Class[_] if c == java.lang.Float.TYPE => FloatType case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType case c: Class[_] if c == classOf[java.lang.Short] => ShortType - case c: Class[_] if c == classOf[java.lang.Integer] => ShortType + case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType case c: Class[_] if c == classOf[java.lang.Long] => LongType case c: Class[_] if c == classOf[java.lang.Double] => DoubleType case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
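For context on the serializer wired into the shuffles above via shuffled.setSerializer(classOf[SparkSqlSerializer].getName), the sketch below is not part of the patch, just a hypothetical local round-trip with arbitrary values. It assumes Kryo's default registrations already cover String and boxed primitives, so only the row classes, Array[Any], and scala.math.BigDecimal (through BigDecimalSerializer) need the explicit registrations added in Exchange.scala.

import java.nio.ByteBuffer

import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Row}
import org.apache.spark.sql.execution.SparkSqlSerializer

// Round-trip one row through the registration-required Kryo instance.
// The no-arg GenericRow constructor added in this patch is what allows
// Kryo to instantiate the row on the read side.
val ser = new SparkSqlSerializer(new SparkConf()).newInstance()
val row: Row = new GenericRow(Array[Any]("name", 42, BigDecimal("1.5")))
val bytes: ByteBuffer = ser.serialize(row)
val restored = ser.deserialize[Row](bytes)
// restored(0) == "name", restored(1) == 42, restored(2) == BigDecimal("1.5")

Nothing else needs to be configured for the query path: the class name passed to setSerializer is enough for the shuffle to pick this serializer up, so only Exchange's shuffles use it.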