Merge pull request apache#55 from marmbrus/mutableRows
Add a framework for dealing with mutable rows.
marmbrus committed Mar 6, 2014
2 parents c9f8fb3 + ba28849 commit c2a658d
Showing 16 changed files with 218 additions and 54 deletions.
@@ -103,7 +103,7 @@ class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
kryo.readClassAndObject(input).asInstanceOf[T]
} catch {
// DeserializationStream uses the EOF exception to indicate stopping condition.
case _: KryoException => throw new EOFException
case e: KryoException if e.getMessage == "Buffer underflow." => throw new EOFException
}
}

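The narrowed catch is the point of this hunk: with the old blanket case, any KryoException, including a genuine failure such as an unregistered class under setRegistrationRequired(true), was silently reported as end-of-stream. A minimal sketch of the resulting pattern, assuming "Buffer underflow." is how this Kryo version signals stream exhaustion:

    import java.io.EOFException
    import com.esotericsoftware.kryo.KryoException

    // Translate only Kryo's end-of-stream signal into EOFException;
    // every other KryoException now propagates to the caller.
    def readOrEof[T](read: => T): T =
      try read catch {
        case e: KryoException if e.getMessage == "Buffer underflow." =>
          throw new EOFException
      }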
12 changes: 11 additions & 1 deletion core/src/main/scala/org/apache/spark/util/MutablePair.scala
@@ -25,10 +25,20 @@ package org.apache.spark.util
* @param _2 Element 2 of this MutablePair
*/
case class MutablePair[@specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T1,
                       @specialized(Int, Long, Double, Char, Boolean/*, AnyRef*/) T2]
(var _1: T1, var _2: T2)
extends Product2[T1, T2]
{
/** No-arg constructor for serialization */
def this() = this(null.asInstanceOf[T1], null.asInstanceOf[T2])

/** Updates this pair with new values and returns itself */
def apply(n1: T1, n2: T2): MutablePair[T1, T2] = {
_1 = n1
_2 = n2
this
}

override def toString = "(" + _1 + "," + _2 + ")"

override def canEqual(that: Any): Boolean = that.isInstanceOf[MutablePair[_,_]]
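A small usage sketch of the new apply method (values are illustrative, not from the patch): one pair is allocated up front and every element of an iterator is threaded through it, the allocation-saving pattern the Exchange operator below relies on.

    import org.apache.spark.util.MutablePair

    val pair = new MutablePair[Int, String]()

    // The same object is returned on every call, so each element must be
    // consumed (or serialized) before the next mutation.
    val keyed = Iterator("a", "bb", "ccc").map(s => pair(s.length, s))
    keyed.foreach(println)  // prints (1,a), (2,bb), (3,ccc)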
@@ -30,6 +30,7 @@ abstract class Expression extends TreeNode[Expression] {
type EvaluatedType <: Any

def dataType: DataType

/**
* Returns true when an expression is a candidate for static evaluation before the query is
* executed.
@@ -53,14 +54,6 @@ abstract class Expression extends TreeNode[Expression] {
def apply(input: Row = null): EvaluatedType =
throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")

// Primitive Accessor functions that avoid boxing for performance.
// Note this is an Unstable API as it doesn't correctly handle null values yet.

def applyBoolean(input: Row): Boolean = apply(input).asInstanceOf[Boolean]
def applyInt(input: Row): Int = apply(input).asInstanceOf[Int]
def applyDouble(input: Row): Double = apply(input).asInstanceOf[Double]
def applyString(input: Row): String = apply(input).asInstanceOf[String]

/**
* Returns `true` if this expression and all its children have been resolved to a specific schema
* and `false` if it still contains any unresolved placeholders. Implementations of expressions
@@ -2,10 +2,9 @@ package org.apache.spark.sql.catalyst
package expressions

/**
* Converts a Row to another Row given a set of expressions.
*
* If the schema of the input row is specified, then the given expression will be bound to that
* schema.
* Converts a [[Row]] to another Row given a sequence of expressions that define each column of
* the new row. If the schema of the input row is specified, then the given expressions will be
* bound to that schema.
*/
class Projection(expressions: Seq[Expression]) extends (Row => Row) {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
@@ -23,6 +22,33 @@ class Projection(expressions: Seq[Expression]) extends (Row => Row) {
}
}

/**
* Converts a [[Row]] to another Row given a sequence of expressions that define each column of
* the new row. If the schema of the input row is specified, then the given expressions will be
* bound to that schema.
*
* In contrast to a normal projection, a MutableProjection reuses the same underlying row object
* each time an input row is added. This significantly reduces the cost of calculating the
* projection, but means that it is not safe to hold on to a reference to a [[Row]] after next()
* has been called on the [[Iterator]] that produced it.
*/
case class MutableProjection(expressions: Seq[Expression]) extends (Row => Row) {
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))

private[this] val exprArray = expressions.toArray
private[this] val mutableRow = new GenericMutableRow(exprArray.size)
def currentValue: Row = mutableRow

def apply(input: Row): Row = {
var i = 0
while (i < exprArray.size) {
mutableRow(i) = exprArray(i).apply(input)
i += 1
}
mutableRow
}
}
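A hedged sketch of the reuse semantics, kept self-contained by projecting constant Literal expressions (a real plan would bind attribute references against an input schema instead; Literal's single-argument factory is assumed here):

    import org.apache.spark.sql.catalyst.expressions._

    val project = new MutableProjection(Seq(Literal(1), Literal("a")))

    val first  = project(EmptyRow)
    val second = project(EmptyRow)

    // Every call returns the same underlying GenericMutableRow, which is
    // also reachable between calls via currentValue.
    assert(first eq second)
    assert(project.currentValue eq first)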

/**
* A mutable wrapper that makes two rows appear as a single concatenated row. Designed to
* be instantiated once per thread and reused.
@@ -68,4 +94,17 @@ class JoinedRow extends Row {
def getFloat(i: Int): Float =
if (i < row1.size) row1.getFloat(i) else row2.getFloat(i - row1.size)

def getString(i: Int): String =
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)

def copy() = {
val totalSize = row1.size + row2.size
val copiedValues = new Array[Any](totalSize)
var i = 0
while (i < totalSize) {
copiedValues(i) = apply(i)
i += 1
}
new GenericRow(copiedValues)
}
}
@@ -28,7 +28,7 @@ import types._
* It is invalid to use the native primitive interface to retrieve a value that is null; instead, a
* user must check [[isNullAt]] before attempting to retrieve a value that might be null.
*/
abstract class Row extends Seq[Any] with Serializable {
trait Row extends Seq[Any] with Serializable {
def apply(i: Int): Any

def isNullAt(i: Int): Boolean
@@ -40,9 +40,39 @@ abstract class Row extends Seq[Any] with Serializable {
def getBoolean(i: Int): Boolean
def getShort(i: Int): Short
def getByte(i: Int): Byte
def getString(i: Int): String

override def toString() =
s"[${this.mkString(",")}]"

def copy(): Row
}

/**
* An extended interface to [[Row]] that allows the values for each column to be updated. Setting
* a value through a primitive function implicitly marks that column as not null.
*/
trait MutableRow extends Row {
def setNullAt(i: Int): Unit

def update(ordinal: Int, value: Any)

def setInt(ordinal: Int, value: Int)
def setLong(ordinal: Int, value: Long)
def setDouble(ordinal: Int, value: Double)
def setBoolean(ordinal: Int, value: Boolean)
def setShort(ordinal: Int, value: Short)
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)

/**
* EXPERIMENTAL
*
* Returns a mutable string builder for the specified column. A given row should return the
* result of any mutations made to the returned buffer the next time getString is called for the
* same column.
*/
def getStringBuilder(ordinal: Int): StringBuilder
}

@@ -62,12 +92,22 @@ object EmptyRow extends Row {
def getBoolean(i: Int): Boolean = throw new UnsupportedOperationException
def getShort(i: Int): Short = throw new UnsupportedOperationException
def getByte(i: Int): Byte = throw new UnsupportedOperationException
def getString(i: Int): String = throw new UnsupportedOperationException

def copy() = this
}

/**
* A row implementation that uses an array of objects as the underlying storage.
* A row implementation that uses an array of objects as the underlying storage. Note that the
* backing array is not copied, so while it could technically be mutated after creation, doing so
* is not allowed.
*/
class GenericRow(val values: Array[Any]) extends Row {
class GenericRow(protected[catalyst] val values: Array[Any]) extends Row {
/** No-arg constructor for serialization. */
def this() = this(null)

def this(size: Int) = this(new Array[Any](size))

def iterator = values.iterator

def length = values.length
@@ -80,32 +120,68 @@ class GenericRow(val values: Array[Any]) extends Row {
if (values(i) == null) sys.error("Failed to check null bit for primitive int value.")
values(i).asInstanceOf[Int]
}

def getLong(i: Int): Long = {
if (values(i) == null) sys.error("Failed to check null bit for primitive long value.")
values(i).asInstanceOf[Long]
}

def getDouble(i: Int): Double = {
if (values(i) == null) sys.error("Failed to check null bit for primitive double value.")
values(i).asInstanceOf[Double]
}

def getFloat(i: Int): Float = {
if (values(i) == null) sys.error("Failed to check null bit for primitive float value.")
values(i).asInstanceOf[Float]
}

def getBoolean(i: Int): Boolean = {
if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.")
values(i).asInstanceOf[Boolean]
}

def getShort(i: Int): Short = {
if (values(i) == null) sys.error("Failed to check null bit for primitive short value.")
values(i).asInstanceOf[Short]
}

def getByte(i: Int): Byte = {
if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
values(i).asInstanceOf[Byte]
}

def getString(i: Int): String = {
if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.")
values(i).asInstanceOf[String]
}

def copy() = this
}

class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(0)

def getStringBuilder(ordinal: Int): StringBuilder = ???

override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }

override def setNullAt(i: Int): Unit = { values(i) = null }

override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }

override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }

override def copy() = new GenericRow(values.clone())
}


class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[Row] {
def compare(a: Row, b: Row): Int = {
var i = 0
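The copy() methods added throughout this file exist for one reason: operators that reuse a single mutable row must not let references to it escape. A distilled sketch of the pitfall, using only classes defined above:

    import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}

    val shared = new GenericMutableRow(1)
    val rows: Iterator[Row] = Iterator(1, 2, 3).map { i =>
      shared.setInt(0, i)
      shared                 // every element aliases the same object
    }

    // Buffering the raw iterator would observe [3], [3], [3]; copying
    // detaches each row from the shared buffer:
    val safe = rows.map(_.copy()).toSeq  // [1], [2], [3]

The same reasoning explains the toRdd change below, which maps _.copy() over the executed plan's output: rows must be detached before they are handed to user code.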
@@ -170,12 +170,15 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
def this() = this(null, null) // Required for serialization.

private var count: Long = _
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(null))
private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).apply(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)

private val addFunction = Add(sum, expr)

override def apply(input: Row): Any = sumAsDouble.applyDouble(null) / count.toDouble
override def apply(input: Row): Any =
sumAsDouble.apply(EmptyRow).asInstanceOf[Double] / count.toDouble

def update(input: Row): Unit = {
count += 1
@@ -112,7 +112,7 @@ class SparkSqlContext(val sparkContext: SparkContext) extends Logging {
lazy val executedPlan: SparkPlan = PrepareForExecution(sparkPlan)

// TODO: We are losing schema here.
lazy val toRdd: RDD[Row] = executedPlan.execute()
lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())

protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
@@ -18,39 +18,79 @@
package org.apache.spark.sql
package execution

import java.nio.ByteBuffer

import com.esotericsoftware.kryo.{Kryo, Serializer}
import com.esotericsoftware.kryo.io.{Output, Input}

import org.apache.spark.{SparkConf, RangePartitioner, HashPartitioner}
import org.apache.spark.rdd.ShuffledRDD
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair

import catalyst.rules.Rule
import catalyst.errors._
import catalyst.expressions._
import catalyst.plans.physical._

import org.apache.spark.{RangePartitioner, HashPartitioner}
import org.apache.spark.rdd.ShuffledRDD
class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) {
override def newKryo(): Kryo = {
val kryo = new Kryo
kryo.setRegistrationRequired(true)
kryo.register(classOf[MutablePair[_,_]])
kryo.register(classOf[Array[Any]])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
kryo.setReferences(false)
kryo.setClassLoader(this.getClass.getClassLoader)
kryo
}
}

class BigDecimalSerializer extends Serializer[BigDecimal] {
def write(kryo: Kryo, output: Output, bd: math.BigDecimal) {
// TODO: There are probably more efficient representations than strings...
output.writeString(bd.toString)
}

def read(kryo: Kryo, input: Input, tpe: Class[BigDecimal]): BigDecimal = {
BigDecimal(input.readString())
}
}
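
A hedged round-trip sketch for the new serializer, through the standard Serializer API (it assumes Kryo's default registrations cover String and boxed primitives). With setRegistrationRequired(true), a row containing any unregistered class fails fast instead of falling back to slow generic serialization:

    import org.apache.spark.SparkConf
    import org.apache.spark.sql.catalyst.expressions.GenericRow
    import org.apache.spark.sql.execution.SparkSqlSerializer

    val ser = new SparkSqlSerializer(new SparkConf()).newInstance()

    val row   = new GenericRow(Array[Any](1, "a", BigDecimal(10)))
    val bytes = ser.serialize(row)
    val back  = ser.deserialize[GenericRow](bytes)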

case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends UnaryNode {

override def outputPartitioning = newPartitioning

def output = child.output

def execute() = attachTree(this, "execute") {
newPartitioning match {
case HashPartitioning(expressions, numPartitions) => {
// TODO: Eliminate redundant expressions in grouping key and value.
val rdd = child.execute().mapPartitions { iter =>
val hashExpressions = new Projection(expressions)
iter.map(r => (hashExpressions(r), r))
val hashExpressions = new MutableProjection(expressions)
val mutablePair = new MutablePair[Row, Row]()
iter.map(r => mutablePair(hashExpressions(r), r))
}
val part = new HashPartitioner(numPartitions)
val shuffled = new ShuffledRDD[Row, Row, (Row, Row)](rdd, part)

val shuffled = new ShuffledRDD[Row, Row, MutablePair[Row, Row]](rdd, part)
shuffled.setSerializer(classOf[SparkSqlSerializer].getName)
shuffled.map(_._2)
}
case RangePartitioning(sortingExpressions, numPartitions) => {
// TODO: ShuffledRDD should take an Ordering.
// TODO: RangePartitioner should take an Ordering.
implicit val ordering = new RowOrdering(sortingExpressions)

val rdd = child.execute().map(row => (row, null))
val rdd = child.execute().mapPartitions { iter =>
val mutablePair = new MutablePair[Row, Null](null, null)
iter.map(row => mutablePair(row, null))
}
val part = new RangePartitioner(numPartitions, rdd, ascending = true)
val shuffled = new ShuffledRDD[Row, Null, (Row, Null)](rdd, part)
val shuffled = new ShuffledRDD[Row, Null, MutablePair[Row, Null]](rdd, part)
shuffled.setSerializer(classOf[SparkSqlSerializer].getName)

shuffled.map(_._1)
}
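Both branches above follow the same map-side recipe: allocate one MutablePair per partition and funnel every record through it. This is safe only because the shuffle writer serializes each pair before the next call mutates it. A distilled, hedged sketch of the recipe as a generic helper (not part of the patch):

    import org.apache.spark.util.MutablePair

    // One reusable pair per partition; no per-record pair allocations.
    def keyedIterator[T, K](iter: Iterator[T], key: T => K): Iterator[MutablePair[K, T]] = {
      val pair = new MutablePair[K, T]()
      iter.map(t => pair(key(t), t))
    }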
@@ -154,7 +154,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}

protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array()): Row), 1)
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)

def convertToCatalyst(a: Any): Any = a match {
case s: Seq[Any] => s.map(convertToCatalyst)
@@ -90,7 +90,7 @@ case class Aggregate(
// in the [[catalyst.execution.Exchange]].
val grouped = child.execute().mapPartitions { iter =>
val buildGrouping = new Projection(groupingExpressions)
iter.map(row => (buildGrouping(row), row))
iter.map(row => (buildGrouping(row), row.copy()))
}.groupByKeyLocally()

val result = grouped.map { case (group, rows) =>
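The row.copy() here is the same aliasing fix sketched after the Row changes above: groupByKeyLocally() buffers its input rows, and the child operator may back every row it emits with a single mutable object, so each row must be detached before it is retained.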