diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 26d5667351fd9..4e5cd44acc9f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -41,7 +41,8 @@ private[sql] trait ColumnAccessor { } private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( - buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) + protected val buffer: ByteBuffer, + protected val columnType: ColumnType[T, JvmType]) extends ColumnAccessor { protected def initialize() {} @@ -49,20 +50,20 @@ private[sql] abstract class BasicColumnAccessor[T <: DataType, JvmType]( def hasNext = buffer.hasRemaining def extractTo(row: MutableRow, ordinal: Int) { - columnType.setField(row, ordinal, columnType.extract(buffer)) + columnType.setField(row, ordinal, extractSingle(buffer)) } + def extractSingle(buffer: ByteBuffer) = columnType.extract(buffer) + protected def underlyingBuffer = buffer } private[sql] abstract class NativeColumnAccessor[T <: NativeType]( buffer: ByteBuffer, - val columnType: NativeColumnType[T]) - extends BasicColumnAccessor[T, T#JvmType](buffer, columnType) - with NullableColumnAccessor { - - type JvmType = T#JvmType -} + columnType: NativeColumnType[T]) + extends BasicColumnAccessor(buffer, columnType) + with NullableColumnAccessor + with CompressedColumnAccessor[T] private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 9394dfa39f2f9..0c10a2b87308f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -29,10 +29,19 @@ private[sql] trait ColumnBuilder { */ def initialize(initialSize: Int, columnName: String = "") + /** + * Gathers statistics information from `row(ordinal)`. + */ def gatherStats(row: Row, ordinal: Int) {} + /** + * Appends `row(ordinal)` to the column builder. + */ def appendFrom(row: Row, ordinal: Int) + /** + * Returns the final columnar byte buffer. 
+ */ def build(): ByteBuffer } @@ -40,14 +49,16 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType]( val columnType: ColumnType[T, JvmType]) extends ColumnBuilder { - private var columnName: String = _ + protected var columnName: String = _ protected var buffer: ByteBuffer = _ override def initialize(initialSize: Int, columnName: String = "") = { val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize this.columnName = columnName - buffer = ByteBuffer.allocate(4 + 4 + size * columnType.defaultSize) + + // Reserves 4 bytes for column type ID + buffer = ByteBuffer.allocate(4 + size * columnType.defaultSize) buffer.order(ByteOrder.nativeOrder()).putInt(columnType.typeId) } @@ -66,8 +77,9 @@ private[sql] abstract class BasicColumnBuilder[T <: DataType, JvmType]( private[sql] abstract class NativeColumnBuilder[T <: NativeType]( protected val columnStats: ColumnStats[T], columnType: NativeColumnType[T]) - extends BasicColumnBuilder[T, T#JvmType](columnType) - with NullableColumnBuilder { + extends BasicColumnBuilder(columnType) + with NullableColumnBuilder + with CompressedColumnBuilder[T] { override def gatherStats(row: Row, ordinal: Int) { columnStats.gatherStats(row, ordinal) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala new file mode 100644 index 0000000000000..2fefd1a75fb72 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.types._ + +private[sql] sealed abstract class ColumnStats[T <: NativeType] extends Serializable{ + type JvmType = T#JvmType + + protected var (_lower, _upper) = initialBounds + + protected val ordering: Ordering[JvmType] + + protected def columnType: NativeColumnType[T] + + /** + * Closed lower bound of this column. + */ + def lowerBound = _lower + + /** + * Closed upper bound of this column. + */ + def upperBound = _upper + + /** + * Initial values for the closed lower/upper bounds, in the format of `(lower, upper)`. + */ + protected def initialBounds: (JvmType, JvmType) + + /** + * Gathers statistics information from `row(ordinal)`. + */ + @inline def gatherStats(row: Row, ordinal: Int) { + val field = columnType.getField(row, ordinal) + if (ordering.gt(field, upperBound)) _upper = field + if (ordering.lt(field, lowerBound)) _lower = field + } + + /** + * Returns `true` if `lower <= row(ordinal) <= upper`. 
+ */ + @inline def contains(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + ordering.lteq(lowerBound, field) && ordering.lteq(field, upperBound) + } + + /** + * Returns `true` if `row(ordinal) < upper` holds. + */ + @inline def isAbove(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + ordering.lt(field, upperBound) + } + + /** + * Returns `true` if `lower < row(ordinal)` holds. + */ + @inline def isBelow(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + ordering.lt(lowerBound, field) + } + + /** + * Returns `true` if `row(ordinal) <= upper` holds. + */ + @inline def isAtOrAbove(row: Row, ordinal: Int) = { + contains(row, ordinal) || isAbove(row, ordinal) + } + + /** + * Returns `true` if `lower <= row(ordinal)` holds. + */ + @inline def isAtOrBelow(row: Row, ordinal: Int) = { + contains(row, ordinal) || isBelow(row, ordinal) + } +} + +private[sql] abstract class BasicColumnStats[T <: NativeType]( + protected val columnType: NativeColumnType[T]) + extends ColumnStats[T] + +private[sql] class BooleanColumnStats extends BasicColumnStats(BOOLEAN) { + override protected val ordering = implicitly[Ordering[JvmType]] + override protected def initialBounds = (true, false) +} + +private[sql] class ByteColumnStats extends BasicColumnStats(BYTE) { + override protected val ordering = implicitly[Ordering[JvmType]] + override protected def initialBounds = (Byte.MaxValue, Byte.MinValue) +} + +private[sql] class ShortColumnStats extends BasicColumnStats(SHORT) { + override protected val ordering = implicitly[Ordering[JvmType]] + override protected def initialBounds = (Short.MaxValue, Short.MinValue) +} + +private[sql] class LongColumnStats extends BasicColumnStats(LONG) { + override protected val ordering = implicitly[Ordering[JvmType]] + override protected def initialBounds = (Long.MaxValue, Long.MinValue) +} + +private[sql] class DoubleColumnStats extends BasicColumnStats(DOUBLE) { + override protected val ordering = implicitly[Ordering[JvmType]] + override protected def initialBounds = (Double.MaxValue, Double.MinValue) +} + +private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) { + override protected val ordering = implicitly[Ordering[JvmType]] + override protected def initialBounds = (Float.MaxValue, Float.MinValue) +} + +private[sql] class IntColumnStats extends BasicColumnStats(INT) { + private object OrderedState extends Enumeration { + val Uninitialized, Initialized, Ascending, Descending, Unordered = Value + } + + import OrderedState._ + + private var orderedState = Uninitialized + private var lastValue: Int = _ + private var _maxDelta: Int = _ + + def isAscending = orderedState != Descending && orderedState != Unordered + def isDescending = orderedState != Ascending && orderedState != Unordered + def isOrdered = isAscending || isDescending + def maxDelta = _maxDelta + + override protected val ordering = implicitly[Ordering[JvmType]] + override protected def initialBounds = (Int.MaxValue, Int.MinValue) + + override def gatherStats(row: Row, ordinal: Int) = { + val field = columnType.getField(row, ordinal) + + if (field > upperBound) _upper = field + if (field < lowerBound) _lower = field + + orderedState = orderedState match { + case Uninitialized => + lastValue = field + Initialized + + case Initialized => + // If all the integers in the column are the same, ordered state is set to Ascending. + // TODO (lian) Confirm whether this is the standard behaviour. 
+ val nextState = if (field >= lastValue) Ascending else Descending + _maxDelta = math.abs(field - lastValue) + lastValue = field + nextState + + case Ascending if field < lastValue => + Unordered + + case Descending if field > lastValue => + Unordered + + case state @ (Ascending | Descending) => + _maxDelta = _maxDelta.max(field - lastValue) + lastValue = field + state + } + } +} + +private[sql] class StringColumnStates extends BasicColumnStats(STRING) { + override protected val ordering = implicitly[Ordering[JvmType]] + override protected def initialBounds = (null, null) + + override def contains(row: Row, ordinal: Int) = { + !(upperBound eq null) && super.contains(row, ordinal) + } + + override def isAbove(row: Row, ordinal: Int) = { + !(upperBound eq null) && super.isAbove(row, ordinal) + } + + override def isBelow(row: Row, ordinal: Int) = { + !(lowerBound eq null) && super.isBelow(row, ordinal) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressedColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressedColumnAccessor.scala new file mode 100644 index 0000000000000..b7c52de8a60d1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressedColumnAccessor.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import java.nio.ByteBuffer + +import org.apache.spark.sql.catalyst.types.NativeType +import org.apache.spark.sql.columnar.CompressionAlgorithm.NoopDecoder +import org.apache.spark.sql.columnar.CompressionType._ + +private[sql] trait CompressedColumnAccessor[T <: NativeType] extends ColumnAccessor { + this: BasicColumnAccessor[T, T#JvmType] => + + private var decoder: Iterator[T#JvmType] = _ + + abstract override protected def initialize() = { + super.initialize() + + decoder = underlyingBuffer.getInt() match { + case id if id == Noop.id => new NoopDecoder[T](buffer, columnType) + case _ => throw new UnsupportedOperationException() + } + } + + abstract override def extractSingle(buffer: ByteBuffer) = decoder.next() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressedColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressedColumnBuilder.scala new file mode 100644 index 0000000000000..fa31e48f35cd7 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressedColumnBuilder.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.apache.spark.sql.{Logging, Row} +import org.apache.spark.sql.catalyst.types.NativeType + +private[sql] trait CompressedColumnBuilder[T <: NativeType] extends ColumnBuilder with Logging { + this: BasicColumnBuilder[T, T#JvmType] => + + val compressionSchemes = Seq(new CompressionAlgorithm.Noop) + .filter(_.supports(columnType)) + + def isWorthCompressing(scheme: CompressionAlgorithm) = { + scheme.compressionRatio < 0.8 + } + + abstract override def gatherStats(row: Row, ordinal: Int) { + compressionSchemes.foreach { + val field = columnType.getField(row, ordinal) + _.gatherCompressibilityStats(field, columnType) + } + + super.gatherStats(row, ordinal) + } + + abstract override def build() = { + val rawBuffer = super.build() + + if (compressionSchemes.isEmpty) { + logger.info(s"Compression scheme chosen for [$columnName] is ${CompressionType.Noop}") + new CompressionAlgorithm.Noop().compress(rawBuffer, columnType) + } else { + val candidateScheme = compressionSchemes.minBy(_.compressionRatio) + + logger.info( + s"Compression scheme chosen for [$columnName] is ${candidateScheme.compressionType} " + + s"ration ${candidateScheme.compressionRatio}") + + if (isWorthCompressing(candidateScheme)) { + candidateScheme.compress(rawBuffer, columnType) + } else { + new CompressionAlgorithm.Noop().compress(rawBuffer, columnType) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressionAlgorithm.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressionAlgorithm.scala new file mode 100644 index 0000000000000..855767f69b4e0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/CompressionAlgorithm.scala @@ -0,0 +1,86 @@ +package org.apache.spark.sql.columnar + +import java.nio.{ByteOrder, ByteBuffer} + +import org.apache.spark.sql.catalyst.types.NativeType + +private[sql] object CompressionType extends Enumeration { + type CompressionType = Value + + val Default, Noop, RLE, Dictionary, BooleanBitSet, IntDelta, LongDelta = Value +} + +private[sql] trait CompressionAlgorithm { + def compressionType: CompressionType.Value + + def supports(columnType: ColumnType[_, _]): Boolean + + def gatherCompressibilityStats[T <: NativeType]( + value: T#JvmType, + columnType: ColumnType[T, T#JvmType]) {} + + def compressedSize: Int + + def uncompressedSize: Int + + def compressionRatio: Double = compressedSize.toDouble / uncompressedSize + + def compress[T <: NativeType](from: ByteBuffer, columnType: ColumnType[T, T#JvmType]): ByteBuffer +} + +private[sql] object CompressionAlgorithm { + def apply(typeId: Int) = typeId match { + case CompressionType.Noop => new CompressionAlgorithm.Noop + case _ => throw new UnsupportedOperationException() + } + + class Noop extends CompressionAlgorithm { + override def uncompressedSize = 0 + override def compressedSize = 0 + override def compressionRatio = 1.0 + override def 
supports(columnType: ColumnType[_, _]) = true + override def compressionType = CompressionType.Noop + + override def compress[T <: NativeType]( + from: ByteBuffer, + columnType: ColumnType[T, T#JvmType]) = { + + // Reserves 4 bytes for compression type + val to = ByteBuffer.allocate(from.limit + 4).order(ByteOrder.nativeOrder) + copyHeader(from, to) + + // Writes compression type ID and copies raw contents + to.putInt(CompressionType.Noop.id).put(from).rewind() + to + } + } + + class NoopDecoder[T <: NativeType](buffer: ByteBuffer, columnType: ColumnType[T, T#JvmType]) + extends Iterator[T#JvmType] { + + override def next() = columnType.extract(buffer) + + override def hasNext = buffer.hasRemaining + } + + def copyNullInfo(from: ByteBuffer, to: ByteBuffer) { + // Writes null count + val nullCount = from.getInt() + to.putInt(nullCount) + + // Writes null positions + var i = 0 + while (i < nullCount) { + to.putInt(from.getInt()) + i += 1 + } + } + + def copyHeader(from: ByteBuffer, to: ByteBuffer) { + // Writes column type ID + to.putInt(from.getInt()) + + // Copies null count and null positions + copyNullInfo(from, to) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala index 2970c609b928d..7d49ab07f7a53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnAccessor.scala @@ -29,7 +29,7 @@ private[sql] trait NullableColumnAccessor extends ColumnAccessor { private var nextNullIndex: Int = _ private var pos: Int = 0 - abstract override def initialize() { + abstract override protected def initialize() { nullsBuffer = underlyingBuffer.duplicate().order(ByteOrder.nativeOrder()) nullCount = nullsBuffer.getInt() nextNullIndex = if (nullCount > 0) nullsBuffer.getInt() else -1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala index 048d1f05c7df2..8712fdb283659 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/NullableColumnBuilder.scala @@ -71,7 +71,7 @@ private[sql] trait NullableColumnBuilder extends ColumnBuilder { // | ... | Non-null part (without column type ID) // +---------+ val buffer = ByteBuffer - .allocate(4 + nullDataLen + nonNulls.limit) + .allocate(4 + 4 + nullDataLen + nonNulls.remaining()) .order(ByteOrder.nativeOrder()) .putInt(typeId) .putInt(nullCount) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala new file mode 100644 index 0000000000000..0fa5323ba2060 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.columnar + +import org.scalatest.FunSuite +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow + +class ColumnStatsSuite extends FunSuite { + test("Boolean") { + val stats = new BooleanColumnStats + val row = new GenericMutableRow(1) + + row(0) = false + stats.gatherStats(row, 0) + assert(stats.lowerBound === false) + assert(stats.upperBound === false) + + row(0) = true + stats.gatherStats(row, 0) + assert(stats.lowerBound === false) + assert(stats.upperBound === true) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 0b687e74ed660..61f7338913791 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -17,11 +17,27 @@ package org.apache.spark.sql.columnar +import java.nio.ByteBuffer + import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.DataType +class TestNullableColumnAccessor[T <: DataType, JvmType]( + buffer: ByteBuffer, + columnType: ColumnType[T, JvmType]) + extends BasicColumnAccessor(buffer, columnType) + with NullableColumnAccessor + +object TestNullableColumnAccessor { + def apply[T <: DataType, JvmType](buffer: ByteBuffer, columnType: ColumnType[T, JvmType]) = { + // Skips the column type ID + buffer.getInt() + new TestNullableColumnAccessor(buffer, columnType) + } +} + class NullableColumnAccessorSuite extends FunSuite { import ColumnarTestData._ @@ -33,20 +49,20 @@ class NullableColumnAccessorSuite extends FunSuite { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName accessor: empty column") { - val builder = ColumnBuilder(columnType.typeId, 4) - val accessor = ColumnAccessor(builder.build()) + val builder = TestNullableColumnBuilder(columnType) + val accessor = TestNullableColumnAccessor(builder.build(), columnType) assert(!accessor.hasNext) } test(s"$typeName accessor: access null values") { - val builder = ColumnBuilder(columnType.typeId, 4) + val builder = TestNullableColumnBuilder(columnType) (0 until 4).foreach { _ => builder.appendFrom(nonNullRandomRow, columnType.typeId) builder.appendFrom(nullRow, columnType.typeId) } - val accessor = ColumnAccessor(builder.build()) + val accessor = TestNullableColumnAccessor(builder.build(), columnType) val row = new GenericMutableRow(1) (0 until 4).foreach { _ => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index 5222a47e1ab87..3732aaaba6687 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -19,9 +19,21 @@ package org.apache.spark.sql.columnar import org.scalatest.FunSuite 
-import org.apache.spark.sql.catalyst.types.DataType +import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.SparkSqlSerializer +class TestNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) + extends BasicColumnBuilder(columnType) + with NullableColumnBuilder + +object TestNullableColumnBuilder { + def apply[T <: DataType, JvmType](columnType: ColumnType[T, JvmType], initialSize: Int = 0) = { + val builder = new TestNullableColumnBuilder(columnType) + builder.initialize(initialSize) + builder + } +} + class NullableColumnBuilderSuite extends FunSuite { import ColumnarTestData._ @@ -30,23 +42,21 @@ class NullableColumnBuilderSuite extends FunSuite { } def testNullableColumnBuilder[T <: DataType, JvmType](columnType: ColumnType[T, JvmType]) { - val columnBuilder = ColumnBuilder(columnType.typeId) val typeName = columnType.getClass.getSimpleName.stripSuffix("$") test(s"$typeName column builder: empty column") { - columnBuilder.initialize(4) - + val columnBuilder = TestNullableColumnBuilder(columnType) val buffer = columnBuilder.build() // For column type ID assert(buffer.getInt() === columnType.typeId) // For null count - assert(buffer.getInt === 0) + assert(buffer.getInt() === 0) assert(!buffer.hasRemaining) } test(s"$typeName column builder: buffer size auto growth") { - columnBuilder.initialize(4) + val columnBuilder = TestNullableColumnBuilder(columnType) (0 until 4) foreach { _ => columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId) @@ -61,7 +71,7 @@ class NullableColumnBuilderSuite extends FunSuite { } test(s"$typeName column builder: null values") { - columnBuilder.initialize(4) + val columnBuilder = TestNullableColumnBuilder(columnType) (0 until 4) foreach { _ => columnBuilder.appendFrom(nonNullRandomRow, columnType.typeId)
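
Note on how the pieces above compose: the null and compression layers hook into the builder/accessor lifecycle through `abstract override` (`CompressedColumnBuilder.gatherStats`/`build`, `NullableColumnAccessor.initialize`, `CompressedColumnAccessor.initialize`), so a concrete `NativeColumnBuilder`/`NativeColumnAccessor` picks up null handling and compression just by mixing the traits in. The sketch below illustrates a round trip through that path. It is not part of the patch: it assumes the pre-existing `ColumnBuilder(typeId, initialSize)` and `ColumnAccessor(buffer)` factories referenced by the removed test lines, plus the `INT` column type and `GenericMutableRow` used elsewhere in the suites; treat it as illustrative only.

```scala
package org.apache.spark.sql.columnar

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

// Illustrative only: builds an INT column containing a null, then reads it back.
object ColumnarRoundTripSketch {
  def main(args: Array[String]) {
    val row = new GenericMutableRow(1)
    val builder = ColumnBuilder(INT.typeId, 4)

    Seq(Some(1), None, Some(3)).foreach {
      case Some(v) =>
        row(0) = v
        // Feeds ColumnStats (lower/upper bounds) and the candidate compression
        // schemes' compressibility statistics; only done for non-null values.
        builder.gatherStats(row, 0)
        builder.appendFrom(row, 0)
      case None =>
        row.setNullAt(0)
        builder.appendFrom(row, 0)
    }

    // With this patch the built buffer is laid out as:
    //   [column type ID][null count][null positions...][compression type ID][column data]
    val buffer = builder.build()

    // The accessor reads the type ID, skips the null section, then selects a decoder
    // based on the compression type ID (only Noop is implemented so far).
    val accessor = ColumnAccessor(buffer)
    val out = new GenericMutableRow(1)
    while (accessor.hasNext) {
      accessor.extractTo(out, 0)
      println(if (out.isNullAt(0)) "NULL" else out.getInt(0))
    }
  }
}
```

One wrinkle worth keeping in mind when reading `CompressedColumnBuilder.build()`: `isWorthCompressing` only accepts a scheme whose estimated ratio is below 0.8, so `Noop` (ratio 1.0) is effectively the fallback that keeps the raw encoding.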