Addressed all PR comments by @marmbrus
liancheng committed Apr 2, 2014
1 parent d3a4fa9 commit ed71bbd
Showing 11 changed files with 156 additions and 146 deletions.
@@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.types._

private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable{
private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
/**
* Closed lower bound of this column.
*/
@@ -246,14 +246,25 @@ private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
}
}

object IntColumnStats {
private[sql] object IntColumnStats {
val UNINITIALIZED = 0
val INITIALIZED = 1
val ASCENDING = 2
val DESCENDING = 3
val UNORDERED = 4
}

/**
* Statistical information for `Int` columns. More information is collected since `Int` is
* frequently used. Extra information includes:
*
* - Ordering state (ascending/descending/unordered), which may be used to decide whether binary
* search is applicable when searching elements.
* - Maximum delta between adjacent elements, which may be used to guide the `IntDelta` compression
* scheme.
*
* (These two kinds of information are not used anywhere yet and might be removed later.)
*/
private[sql] class IntColumnStats extends BasicColumnStats(INT) {
import IntColumnStats._

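As a rough, self-contained sketch of how the ordering state and maximum delta described in the doc comment above could be gathered while `Int` values are appended (an illustration only, with hypothetical names; it mirrors, but is not, the committed `IntColumnStats` implementation):

object IntStatsSketch {
  // Mirrors the ordering states defined in the IntColumnStats companion object above.
  val UNINITIALIZED = 0
  val INITIALIZED = 1
  val ASCENDING = 2
  val DESCENDING = 3
  val UNORDERED = 4
}

class IntStatsSketch {
  import IntStatsSketch._

  private var orderedState = UNINITIALIZED
  private var lastValue = 0
  private var maxDelta = 0

  def gatherStats(value: Int): Unit = {
    orderedState match {
      case UNINITIALIZED =>
        // First value seen: nothing to compare against yet.
        orderedState = INITIALIZED
      case INITIALIZED =>
        // Second value seen: pick an initial direction and record the first delta.
        orderedState = if (value >= lastValue) ASCENDING else DESCENDING
        maxDelta = math.abs(value - lastValue)
      case ASCENDING =>
        if (value < lastValue) orderedState = UNORDERED
        maxDelta = math.max(maxDelta, math.abs(value - lastValue))
      case DESCENDING =>
        if (value > lastValue) orderedState = UNORDERED
        maxDelta = math.max(maxDelta, math.abs(value - lastValue))
      case UNORDERED =>
        maxDelta = math.max(maxDelta, math.abs(value - lastValue))
    }
    lastValue = value
  }

  def isAscending: Boolean = orderedState == ASCENDING
  def maxDeltaSoFar: Int = maxDelta
}

For example, feeding the values 1, 3, 10 leaves the state at ASCENDING with a maximum delta of 7, which is the kind of signal the `IntDelta` compression scheme could use.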
@@ -71,6 +71,8 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
* Creates a duplicated copy of the value.
*/
def clone(v: JvmType): JvmType = v

override def toString = getClass.getSimpleName.stripSuffix("$")
}

private[sql] abstract class NativeColumnType[T <: NativeType](
@@ -258,7 +260,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) {
}

private[sql] object ColumnType {
implicit def dataTypeToColumnType(dataType: DataType): ColumnType[_, _] = {
def apply(dataType: DataType): ColumnType[_, _] = {
dataType match {
case IntegerType => INT
case LongType => LONG
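For illustration, a hypothetical call site of the new explicit factory (this snippet is not part of the commit): the former implicit conversion now has to be invoked by name, which makes the dependency on the columnar type mapping visible where it is used.

import org.apache.spark.sql.catalyst.types.{IntegerType, LongType}
import org.apache.spark.sql.columnar.ColumnType

val intColumnType = ColumnType(IntegerType)   // resolves to the INT column type object
val longColumnType = ColumnType(LongType)     // resolves to the LONG column type object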
@@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute}
import org.apache.spark.sql.execution.{SparkPlan, LeafNode}
import org.apache.spark.sql.Row

/* Implicit conversions */
import org.apache.spark.sql.columnar.ColumnType._

private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], child: SparkPlan)
extends LeafNode {

@@ -33,7 +30,7 @@ private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], ch
val output = child.output
val cached = child.execute().mapPartitions { iterator =>
val columnBuilders = output.map { attribute =>
ColumnBuilder(attribute.dataType.typeId, 0, attribute.name)
ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name)
}.toArray

var row: Row = null
@@ -58,7 +58,9 @@ private[sql] trait WithCompressionSchemes {
}

private[sql] trait AllCompressionSchemes extends WithCompressionSchemes {
override val schemes = Seq(PassThrough, RunLengthEncoding, DictionaryEncoding)
override val schemes: Seq[CompressionScheme] = {
Seq(PassThrough, RunLengthEncoding, DictionaryEncoding)
}
}

private[sql] object CompressionScheme {
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.NativeType
import org.apache.spark.sql.columnar._

private[sql] object PassThrough extends CompressionScheme {
private[sql] case object PassThrough extends CompressionScheme {
override val typeId = 0

override def supports(columnType: ColumnType[_, _]) = true
@@ -63,7 +63,7 @@ private[sql] object PassThrough extends CompressionScheme {
}
}

private[sql] object RunLengthEncoding extends CompressionScheme {
private[sql] case object RunLengthEncoding extends CompressionScheme {
override def typeId = 1

override def encoder = new this.Encoder
@@ -171,7 +171,7 @@ private[sql] object RunLengthEncoding extends CompressionScheme {
}
}

private[sql] object DictionaryEncoding extends CompressionScheme {
private[sql] case object DictionaryEncoding extends CompressionScheme {
override def typeId: Int = 2

// 32K unique values allowed
@@ -270,6 +270,7 @@ private[sql] object DictionaryEncoding extends CompressionScheme {
extends compression.Decoder[T] {

private val dictionary = {
// TODO Can we clean up this mess? Maybe move this to `DataType`?
implicit val classTag = {
val mirror = runtimeMirror(getClass.getClassLoader)
ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe))
@@ -39,7 +39,9 @@ class ColumnStatsSuite extends FunSuite {

test(s"$columnStatsName: empty") {
val columnStats = columnStatsClass.newInstance()
assert((columnStats.lowerBound, columnStats.upperBound) === columnStats.initialBounds)
expectResult(columnStats.initialBounds, "Wrong initial bounds") {
(columnStats.lowerBound, columnStats.upperBound)
}
}

test(s"$columnStatsName: non-empty") {
@@ -52,8 +54,8 @@
val values = rows.map(_.head.asInstanceOf[T#JvmType])
val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]]

assert(columnStats.lowerBound === values.min(ordering))
assert(columnStats.upperBound === values.max(ordering))
expectResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound)
expectResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound)
}
}
}
@@ -28,35 +28,46 @@ import org.apache.spark.sql.execution.SparkSqlSerializer
class ColumnTypeSuite extends FunSuite {
val DEFAULT_BUFFER_SIZE = 512

val columnTypes = Seq(INT, SHORT, LONG, BYTE, DOUBLE, FLOAT, STRING, BINARY, GENERIC)

test("defaultSize") {
val defaultSize = Seq(4, 2, 8, 1, 8, 4, 8, 16, 16)
val checks = Map(
INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4,
BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16)

columnTypes.zip(defaultSize).foreach { case (columnType, size) =>
assert(columnType.defaultSize === size)
checks.foreach { case (columnType, expectedSize) =>
expectResult(expectedSize, s"Wrong defaultSize for $columnType") {
columnType.defaultSize
}
}
}

test("actualSize") {
val expectedSizes = Seq(4, 2, 8, 1, 8, 4, 4 + 5, 4 + 4, 4 + 11)
val actualSizes = Seq(
INT.actualSize(Int.MaxValue),
SHORT.actualSize(Short.MaxValue),
LONG.actualSize(Long.MaxValue),
BYTE.actualSize(Byte.MaxValue),
DOUBLE.actualSize(Double.MaxValue),
FLOAT.actualSize(Float.MaxValue),
STRING.actualSize("hello"),
BINARY.actualSize(new Array[Byte](4)),
GENERIC.actualSize(SparkSqlSerializer.serialize(Map(1 -> "a"))))

expectedSizes.zip(actualSizes).foreach { case (expected, actual) =>
assert(expected === actual)
def checkActualSize[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
value: JvmType,
expected: Int) {

expectResult(expected, s"Wrong actualSize for $columnType") {
columnType.actualSize(value)
}
}

checkActualSize(INT, Int.MaxValue, 4)
checkActualSize(SHORT, Short.MaxValue, 2)
checkActualSize(LONG, Long.MaxValue, 8)
checkActualSize(BYTE, Byte.MaxValue, 1)
checkActualSize(DOUBLE, Double.MaxValue, 8)
checkActualSize(FLOAT, Float.MaxValue, 4)
checkActualSize(BOOLEAN, true, 1)
checkActualSize(STRING, "hello", 4 + 5)

val binary = Array.fill[Byte](4)(0: Byte)
checkActualSize(BINARY, binary, 4 + 4)

val generic = Map(1 -> "a")
checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11)
}

testNativeColumnStats[BooleanType.type](
testNativeColumnType[BooleanType.type](
BOOLEAN,
(buffer: ByteBuffer, v: Boolean) => {
buffer.put((if (v) 1 else 0).toByte)
@@ -65,37 +76,19 @@ class ColumnTypeSuite extends FunSuite {
buffer.get() == 1
})

testNativeColumnStats[IntegerType.type](
INT,
(_: ByteBuffer).putInt(_),
(_: ByteBuffer).getInt)

testNativeColumnStats[ShortType.type](
SHORT,
(_: ByteBuffer).putShort(_),
(_: ByteBuffer).getShort)

testNativeColumnStats[LongType.type](
LONG,
(_: ByteBuffer).putLong(_),
(_: ByteBuffer).getLong)

testNativeColumnStats[ByteType.type](
BYTE,
(_: ByteBuffer).put(_),
(_: ByteBuffer).get)

testNativeColumnStats[DoubleType.type](
DOUBLE,
(_: ByteBuffer).putDouble(_),
(_: ByteBuffer).getDouble)

testNativeColumnStats[FloatType.type](
FLOAT,
(_: ByteBuffer).putFloat(_),
(_: ByteBuffer).getFloat)

testNativeColumnStats[StringType.type](
testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt)

testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort)

testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong)

testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get)

testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble)

testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat)

testNativeColumnType[StringType.type](
STRING,
(buffer: ByteBuffer, string: String) => {
val bytes = string.getBytes()
@@ -108,7 +101,7 @@
new String(bytes)
})

testColumnStats[BinaryType.type, Array[Byte]](
testColumnType[BinaryType.type, Array[Byte]](
BINARY,
(buffer: ByteBuffer, bytes: Array[Byte]) => {
buffer.putInt(bytes.length).put(bytes)
@@ -131,51 +124,58 @@
val length = buffer.getInt()
assert(length === serializedObj.length)

val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
assert(obj === SparkSqlSerializer.deserialize(bytes))
expectResult(obj, "Deserialized object didn't equal the original object") {
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
SparkSqlSerializer.deserialize(bytes)
}

buffer.rewind()
buffer.putInt(serializedObj.length).put(serializedObj)

buffer.rewind()
assert(obj === SparkSqlSerializer.deserialize(GENERIC.extract(buffer)))
expectResult(obj, "Deserialized object didn't equal the original object") {
buffer.rewind()
SparkSqlSerializer.deserialize(GENERIC.extract(buffer))
}
}

def testNativeColumnStats[T <: NativeType](
def testNativeColumnType[T <: NativeType](
columnType: NativeColumnType[T],
putter: (ByteBuffer, T#JvmType) => Unit,
getter: (ByteBuffer) => T#JvmType) {

testColumnStats[T, T#JvmType](columnType, putter, getter)
testColumnType[T, T#JvmType](columnType, putter, getter)
}

def testColumnStats[T <: DataType, JvmType](
def testColumnType[T <: DataType, JvmType](
columnType: ColumnType[T, JvmType],
putter: (ByteBuffer, JvmType) => Unit,
getter: (ByteBuffer) => JvmType) {

val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE)
val columnTypeName = columnType.getClass.getSimpleName.stripSuffix("$")
val seq = (0 until 4).map(_ => makeRandomValue(columnType))

test(s"$columnTypeName.extract") {
test(s"$columnType.extract") {
buffer.rewind()
seq.foreach(putter(buffer, _))

buffer.rewind()
seq.foreach { i =>
assert(i === columnType.extract(buffer))
seq.foreach { expected =>
assert(
expected === columnType.extract(buffer),
"Extracted value didn't equal the original one")
}
}

test(s"$columnTypeName.append") {
test(s"$columnType.append") {
buffer.rewind()
seq.foreach(columnType.append(_, buffer))

buffer.rewind()
seq.foreach { i =>
assert(i === getter(buffer))
seq.foreach { expected =>
assert(
expected === getter(buffer),
"Extracted value didn't equal the original one")
}
}
}
@@ -49,13 +49,13 @@
val typeName = columnType.getClass.getSimpleName.stripSuffix("$")
val nullRow = makeNullRow(1)

test(s"Nullable $typeName accessor: empty column") {
test(s"Nullable $typeName column accessor: empty column") {
val builder = TestNullableColumnBuilder(columnType)
val accessor = TestNullableColumnAccessor(builder.build(), columnType)
assert(!accessor.hasNext)
}

test(s"Nullable $typeName accessor: access null values") {
test(s"Nullable $typeName column accessor: access null values") {
val builder = TestNullableColumnBuilder(columnType)
val randomRow = makeRandomRow(columnType)

@@ -72,7 +72,7 @@
assert(row(0) === randomRow(0))

accessor.extractTo(row, 0)
assert(row(0) === null)
assert(row.isNullAt(0))
}
}
}