From ed71bbd38f389dcdc11949513a02fe11ff7bdb6a Mon Sep 17 00:00:00 2001
From: Cheng Lian
Date: Wed, 2 Apr 2014 20:54:24 +0800
Subject: [PATCH] Addressed all PR comments by @marmbrus

https://github.com/apache/spark/pull/285
---
 .../spark/sql/columnar/ColumnStats.scala      | 15 +-
 .../spark/sql/columnar/ColumnType.scala       |  4 +-
 ....scala => InMemoryColumnarTableScan.scala} |  5 +-
 .../compression/CompressionScheme.scala       |  4 +-
 .../compression/compressionSchemes.scala      |  7 +-
 .../spark/sql/columnar/ColumnStatsSuite.scala |  8 +-
 .../spark/sql/columnar/ColumnTypeSuite.scala  | 134 +++++++++---------
 .../NullableColumnAccessorSuite.scala         |  6 +-
 .../columnar/NullableColumnBuilderSuite.scala |  28 ++--
 .../compression/DictionaryEncodingSuite.scala |  42 +++---
 .../compression/RunLengthEncodingSuite.scala  |  49 ++++---
 11 files changed, 156 insertions(+), 146 deletions(-)
 rename sql/core/src/main/scala/org/apache/spark/sql/columnar/{inMemoryColumnarOperators.scala => InMemoryColumnarTableScan.scala} (94%)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
index c7344dd9e1cd3..30c6bdc7912fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.columnar
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.types._
 
-private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable{
+private[sql] sealed abstract class ColumnStats[T <: DataType, JvmType] extends Serializable {
   /**
    * Closed lower bound of this column.
    */
@@ -246,7 +246,7 @@ private[sql] class FloatColumnStats extends BasicColumnStats(FLOAT) {
   }
 }
 
-object IntColumnStats {
+private[sql] object IntColumnStats {
   val UNINITIALIZED = 0
   val INITIALIZED = 1
   val ASCENDING = 2
@@ -254,6 +254,17 @@
   val UNORDERED = 4
 }
 
+/**
+ * Statistical information for `Int` columns. More information is collected since `Int` is
+ * frequently used. The extra information includes:
+ *
+ * - Ordering state (ascending/descending/unordered), which may be used to decide whether binary
+ *   search is applicable when searching for elements.
+ * - Maximum delta between adjacent elements, which may be used to guide the `IntDelta`
+ *   compression scheme.
+ *
+ * (These two kinds of information are not used anywhere yet and might be removed later.)
+ */
 private[sql] class IntColumnStats extends BasicColumnStats(INT) {
   import IntColumnStats._
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
index a761c42b2aba2..5be76890afe31 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala
@@ -71,6 +71,8 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType](
    * Creates a duplicated copy of the value.
*/ def clone(v: JvmType): JvmType = v + + override def toString = getClass.getSimpleName.stripSuffix("$") } private[sql] abstract class NativeColumnType[T <: NativeType]( @@ -258,7 +260,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](9, 16) { } private[sql] object ColumnType { - implicit def dataTypeToColumnType(dataType: DataType): ColumnType[_, _] = { + def apply(dataType: DataType): ColumnType[_, _] = { dataType match { case IntegerType => INT case LongType => LONG diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala rename to sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 07f8e59b061e2..8a24733047423 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/inMemoryColumnarOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -21,9 +21,6 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Attribute} import org.apache.spark.sql.execution.{SparkPlan, LeafNode} import org.apache.spark.sql.Row -/* Implicit conversions */ -import org.apache.spark.sql.columnar.ColumnType._ - private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], child: SparkPlan) extends LeafNode { @@ -33,7 +30,7 @@ private[sql] case class InMemoryColumnarTableScan(attributes: Seq[Attribute], ch val output = child.output val cached = child.execute().mapPartitions { iterator => val columnBuilders = output.map { attribute => - ColumnBuilder(attribute.dataType.typeId, 0, attribute.name) + ColumnBuilder(ColumnType(attribute.dataType).typeId, 0, attribute.name) }.toArray var row: Row = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala index be2f359553020..d3a4ac8df926b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala @@ -58,7 +58,9 @@ private[sql] trait WithCompressionSchemes { } private[sql] trait AllCompressionSchemes extends WithCompressionSchemes { - override val schemes = Seq(PassThrough, RunLengthEncoding, DictionaryEncoding) + override val schemes: Seq[CompressionScheme] = { + Seq(PassThrough, RunLengthEncoding, DictionaryEncoding) + } } private[sql] object CompressionScheme { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala index 50ebe6c907f32..dc2c153faf8ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.catalyst.types.NativeType import org.apache.spark.sql.columnar._ -private[sql] object PassThrough extends CompressionScheme { +private[sql] case object PassThrough extends CompressionScheme { override val typeId = 0 override def supports(columnType: ColumnType[_, _]) = true @@ -63,7 +63,7 @@ 
private[sql] object PassThrough extends CompressionScheme { } } -private[sql] object RunLengthEncoding extends CompressionScheme { +private[sql] case object RunLengthEncoding extends CompressionScheme { override def typeId = 1 override def encoder = new this.Encoder @@ -171,7 +171,7 @@ private[sql] object RunLengthEncoding extends CompressionScheme { } } -private[sql] object DictionaryEncoding extends CompressionScheme { +private[sql] case object DictionaryEncoding extends CompressionScheme { override def typeId: Int = 2 // 32K unique values allowed @@ -270,6 +270,7 @@ private[sql] object DictionaryEncoding extends CompressionScheme { extends compression.Decoder[T] { private val dictionary = { + // TODO Can we clean up this mess? Maybe move this to `DataType`? implicit val classTag = { val mirror = runtimeMirror(getClass.getClassLoader) ClassTag[T#JvmType](mirror.runtimeClass(columnType.scalaTag.tpe)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index f830bb974627b..78640b876d4aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -39,7 +39,9 @@ class ColumnStatsSuite extends FunSuite { test(s"$columnStatsName: empty") { val columnStats = columnStatsClass.newInstance() - assert((columnStats.lowerBound, columnStats.upperBound) === columnStats.initialBounds) + expectResult(columnStats.initialBounds, "Wrong initial bounds") { + (columnStats.lowerBound, columnStats.upperBound) + } } test(s"$columnStatsName: non-empty") { @@ -52,8 +54,8 @@ class ColumnStatsSuite extends FunSuite { val values = rows.map(_.head.asInstanceOf[T#JvmType]) val ordering = columnType.dataType.ordering.asInstanceOf[Ordering[T#JvmType]] - assert(columnStats.lowerBound === values.min(ordering)) - assert(columnStats.upperBound === values.max(ordering)) + expectResult(values.min(ordering), "Wrong lower bound")(columnStats.lowerBound) + expectResult(values.max(ordering), "Wrong upper bound")(columnStats.upperBound) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 1a98ec270c03b..1d3608ed2d9ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -28,35 +28,46 @@ import org.apache.spark.sql.execution.SparkSqlSerializer class ColumnTypeSuite extends FunSuite { val DEFAULT_BUFFER_SIZE = 512 - val columnTypes = Seq(INT, SHORT, LONG, BYTE, DOUBLE, FLOAT, STRING, BINARY, GENERIC) - test("defaultSize") { - val defaultSize = Seq(4, 2, 8, 1, 8, 4, 8, 16, 16) + val checks = Map( + INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, + BOOLEAN -> 1, STRING -> 8, BINARY -> 16, GENERIC -> 16) - columnTypes.zip(defaultSize).foreach { case (columnType, size) => - assert(columnType.defaultSize === size) + checks.foreach { case (columnType, expectedSize) => + expectResult(expectedSize, s"Wrong defaultSize for $columnType") { + columnType.defaultSize + } } } test("actualSize") { - val expectedSizes = Seq(4, 2, 8, 1, 8, 4, 4 + 5, 4 + 4, 4 + 11) - val actualSizes = Seq( - INT.actualSize(Int.MaxValue), - SHORT.actualSize(Short.MaxValue), - LONG.actualSize(Long.MaxValue), - BYTE.actualSize(Byte.MaxValue), - DOUBLE.actualSize(Double.MaxValue), 
- FLOAT.actualSize(Float.MaxValue), - STRING.actualSize("hello"), - BINARY.actualSize(new Array[Byte](4)), - GENERIC.actualSize(SparkSqlSerializer.serialize(Map(1 -> "a")))) - - expectedSizes.zip(actualSizes).foreach { case (expected, actual) => - assert(expected === actual) + def checkActualSize[T <: DataType, JvmType]( + columnType: ColumnType[T, JvmType], + value: JvmType, + expected: Int) { + + expectResult(expected, s"Wrong actualSize for $columnType") { + columnType.actualSize(value) + } } + + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(LONG, Long.MaxValue, 8) + checkActualSize(BYTE, Byte.MaxValue, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(FLOAT, Float.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(STRING, "hello", 4 + 5) + + val binary = Array.fill[Byte](4)(0: Byte) + checkActualSize(BINARY, binary, 4 + 4) + + val generic = Map(1 -> "a") + checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 11) } - testNativeColumnStats[BooleanType.type]( + testNativeColumnType[BooleanType.type]( BOOLEAN, (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) @@ -65,37 +76,19 @@ class ColumnTypeSuite extends FunSuite { buffer.get() == 1 }) - testNativeColumnStats[IntegerType.type]( - INT, - (_: ByteBuffer).putInt(_), - (_: ByteBuffer).getInt) - - testNativeColumnStats[ShortType.type]( - SHORT, - (_: ByteBuffer).putShort(_), - (_: ByteBuffer).getShort) - - testNativeColumnStats[LongType.type]( - LONG, - (_: ByteBuffer).putLong(_), - (_: ByteBuffer).getLong) - - testNativeColumnStats[ByteType.type]( - BYTE, - (_: ByteBuffer).put(_), - (_: ByteBuffer).get) - - testNativeColumnStats[DoubleType.type]( - DOUBLE, - (_: ByteBuffer).putDouble(_), - (_: ByteBuffer).getDouble) - - testNativeColumnStats[FloatType.type]( - FLOAT, - (_: ByteBuffer).putFloat(_), - (_: ByteBuffer).getFloat) - - testNativeColumnStats[StringType.type]( + testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) + + testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) + + testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) + + testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + + testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + + testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) + + testNativeColumnType[StringType.type]( STRING, (buffer: ByteBuffer, string: String) => { val bytes = string.getBytes() @@ -108,7 +101,7 @@ class ColumnTypeSuite extends FunSuite { new String(bytes) }) - testColumnStats[BinaryType.type, Array[Byte]]( + testColumnType[BinaryType.type, Array[Byte]]( BINARY, (buffer: ByteBuffer, bytes: Array[Byte]) => { buffer.putInt(bytes.length).put(bytes) @@ -131,51 +124,58 @@ class ColumnTypeSuite extends FunSuite { val length = buffer.getInt() assert(length === serializedObj.length) - val bytes = new Array[Byte](length) - buffer.get(bytes, 0, length) - assert(obj === SparkSqlSerializer.deserialize(bytes)) + expectResult(obj, "Deserialized object didn't equal to the original object") { + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + SparkSqlSerializer.deserialize(bytes) + } buffer.rewind() buffer.putInt(serializedObj.length).put(serializedObj) - buffer.rewind() - assert(obj === SparkSqlSerializer.deserialize(GENERIC.extract(buffer))) + expectResult(obj, "Deserialized object didn't equal to the original object") { + buffer.rewind() + 
SparkSqlSerializer.deserialize(GENERIC.extract(buffer)) + } } - def testNativeColumnStats[T <: NativeType]( + def testNativeColumnType[T <: NativeType]( columnType: NativeColumnType[T], putter: (ByteBuffer, T#JvmType) => Unit, getter: (ByteBuffer) => T#JvmType) { - testColumnStats[T, T#JvmType](columnType, putter, getter) + testColumnType[T, T#JvmType](columnType, putter, getter) } - def testColumnStats[T <: DataType, JvmType]( + def testColumnType[T <: DataType, JvmType]( columnType: ColumnType[T, JvmType], putter: (ByteBuffer, JvmType) => Unit, getter: (ByteBuffer) => JvmType) { val buffer = ByteBuffer.allocate(DEFAULT_BUFFER_SIZE) - val columnTypeName = columnType.getClass.getSimpleName.stripSuffix("$") val seq = (0 until 4).map(_ => makeRandomValue(columnType)) - test(s"$columnTypeName.extract") { + test(s"$columnType.extract") { buffer.rewind() seq.foreach(putter(buffer, _)) buffer.rewind() - seq.foreach { i => - assert(i === columnType.extract(buffer)) + seq.foreach { expected => + assert( + expected === columnType.extract(buffer), + "Extracted value didn't equal to the original one") } } - test(s"$columnTypeName.append") { + test(s"$columnType.append") { buffer.rewind() seq.foreach(columnType.append(_, buffer)) buffer.rewind() - seq.foreach { i => - assert(i === getter(buffer)) + seq.foreach { expected => + assert( + expected === getter(buffer), + "Extracted value didn't equal to the original one") } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 9b12441b99566..4a21eb6201a69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -49,13 +49,13 @@ class NullableColumnAccessorSuite extends FunSuite { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") val nullRow = makeNullRow(1) - test(s"Nullable $typeName accessor: empty column") { + test(s"Nullable $typeName column accessor: empty column") { val builder = TestNullableColumnBuilder(columnType) val accessor = TestNullableColumnAccessor(builder.build(), columnType) assert(!accessor.hasNext) } - test(s"Nullable $typeName accessor: access null values") { + test(s"Nullable $typeName column accessor: access null values") { val builder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) @@ -72,7 +72,7 @@ class NullableColumnAccessorSuite extends FunSuite { assert(row(0) === randomRow(0)) accessor.extractTo(row, 0) - assert(row(0) === null) + assert(row.isNullAt(0)) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index aaeeb1272b03f..d9d1e1bfddb75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -48,10 +48,8 @@ class NullableColumnBuilderSuite extends FunSuite { val columnBuilder = TestNullableColumnBuilder(columnType) val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt() === 0) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(0, "Wrong null count")(buffer.getInt()) 
assert(!buffer.hasRemaining) } @@ -59,16 +57,14 @@ class NullableColumnBuilderSuite extends FunSuite { val columnBuilder = TestNullableColumnBuilder(columnType) val randomRow = makeRandomRow(columnType) - (0 until 4) foreach { _ => + (0 until 4).foreach { _ => columnBuilder.appendFrom(randomRow, 0) } val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt() === 0) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(0, "Wrong null count")(buffer.getInt()) } test(s"$typeName column builder: null values") { @@ -76,19 +72,18 @@ class NullableColumnBuilderSuite extends FunSuite { val randomRow = makeRandomRow(columnType) val nullRow = makeNullRow(1) - (0 until 4) foreach { _ => + (0 until 4).foreach { _ => columnBuilder.appendFrom(randomRow, 0) columnBuilder.appendFrom(nullRow, 0) } val buffer = columnBuilder.build() - // For column type ID - assert(buffer.getInt() === columnType.typeId) - // For null count - assert(buffer.getInt() === 4) + expectResult(columnType.typeId, "Wrong column type ID")(buffer.getInt()) + expectResult(4, "Wrong null count")(buffer.getInt()) + // For null positions - (1 to 7 by 2).foreach(i => assert(buffer.getInt() === i)) + (1 to 7 by 2).foreach(expectResult(_, "Wrong null position")(buffer.getInt())) // For non-null values (0 until 4).foreach { _ => @@ -97,7 +92,8 @@ class NullableColumnBuilderSuite extends FunSuite { } else { columnType.extract(buffer) } - assert(actual === randomRow.head) + + assert(actual === randomRow(0), "Extracted value didn't equal to the original one") } assert(!buffer.hasRemaining) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala index 3a6cc2f2ba56e..184691ab5b46a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/DictionaryEncodingSuite.scala @@ -31,8 +31,6 @@ class DictionaryEncodingSuite extends FunSuite { testDictionaryEncoding(new LongColumnStats, LONG) testDictionaryEncoding(new StringColumnStats, STRING) - val schemeName = DictionaryEncoding.getClass.getSimpleName.stripSuffix("$") - def testDictionaryEncoding[T <: NativeType]( columnStats: NativeColumnStats[T], columnType: NativeColumnType[T]) { @@ -43,7 +41,7 @@ class DictionaryEncodingSuite extends FunSuite { (0 until buffer.getInt()).map(columnType.extract(buffer) -> _.toShort).toMap } - test(s"$schemeName with $typeName: simple case") { + test(s"$DictionaryEncoding with $typeName: simple case") { // ------------- // Tests encoder // ------------- @@ -59,25 +57,25 @@ class DictionaryEncodingSuite extends FunSuite { val buffer = builder.build() val headerSize = CompressionScheme.columnHeaderSize(buffer) - // 4 bytes for dictionary size + // 4 extra bytes for dictionary size val dictionarySize = 4 + values.map(columnType.actualSize).sum + // 4 `Short`s, 2 bytes each val compressedSize = dictionarySize + 2 * 4 - // 4 bytes for compression scheme type ID - assert(buffer.capacity === headerSize + 4 + compressedSize) + // 4 extra bytes for compression scheme type ID + expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity) // Skips column header buffer.position(headerSize) - // Checks compression scheme ID - assert(buffer.getInt() === 
DictionaryEncoding.typeId) + expectResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) val dictionary = buildDictionary(buffer) - assert(dictionary(values(0)) === (0: Short)) - assert(dictionary(values(1)) === (1: Short)) + Array[Short](0, 1).foreach { i => + expectResult(i, "Wrong dictionary entry")(dictionary(values(i))) + } - assert(buffer.getShort() === (0: Short)) - assert(buffer.getShort() === (1: Short)) - assert(buffer.getShort() === (0: Short)) - assert(buffer.getShort() === (1: Short)) + Array[Short](0, 1, 0, 1).foreach { + expectResult(_, "Wrong column element value")(buffer.getShort()) + } // ------------- // Tests decoder @@ -88,15 +86,15 @@ class DictionaryEncodingSuite extends FunSuite { val decoder = new DictionaryEncoding.Decoder[T](buffer, columnType) - assert(decoder.next() === values(0)) - assert(decoder.next() === values(1)) - assert(decoder.next() === values(0)) - assert(decoder.next() === values(1)) + Array[Short](0, 1, 0, 1).foreach { i => + expectResult(values(i), "Wrong decoded value")(decoder.next()) + } + assert(!decoder.hasNext) } } - test(s"$schemeName: overflow") { + test(s"$DictionaryEncoding: overflow") { val builder = TestCompressibleColumnBuilder(new IntColumnStats, INT, DictionaryEncoding) builder.initialize(0) @@ -106,8 +104,10 @@ class DictionaryEncodingSuite extends FunSuite { builder.appendFrom(row, 0) } - intercept[Throwable] { - builder.build() + withClue("Dictionary overflowed, encoding should fail") { + intercept[Throwable] { + builder.build() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala index ac9bc222aad14..2089ad120d4f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/compression/RunLengthEncodingSuite.scala @@ -35,10 +35,9 @@ class RunLengthEncodingSuite extends FunSuite { columnStats: NativeColumnStats[T], columnType: NativeColumnType[T]) { - val schemeName = RunLengthEncoding.getClass.getSimpleName.stripSuffix("$") val typeName = columnType.getClass.getSimpleName.stripSuffix("$") - test(s"$schemeName with $typeName: simple case") { + test(s"$RunLengthEncoding with $typeName: simple case") { // ------------- // Tests encoder // ------------- @@ -54,20 +53,19 @@ class RunLengthEncodingSuite extends FunSuite { val buffer = builder.build() val headerSize = CompressionScheme.columnHeaderSize(buffer) - // 4 bytes each run for run length + // 4 extra bytes each run for run length val compressedSize = values.map(columnType.actualSize(_) + 4).sum - // 4 bytes for compression scheme type ID - assert(buffer.capacity === headerSize + 4 + compressedSize) + // 4 extra bytes for compression scheme type ID + expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity) // Skips column header buffer.position(headerSize) - // Checks compression scheme ID - assert(buffer.getInt() === RunLengthEncoding.typeId) + expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) - assert(columnType.extract(buffer) === values(0)) - assert(buffer.getInt() === 2) - assert(columnType.extract(buffer) === values(1)) - assert(buffer.getInt() === 2) + Array(0, 1).foreach { i => + expectResult(values(i), "Wrong column element value")(columnType.extract(buffer)) + expectResult(2, "Wrong run 
length")(buffer.getInt()) + } // ------------- // Tests decoder @@ -78,14 +76,14 @@ class RunLengthEncodingSuite extends FunSuite { val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType) - assert(decoder.next() === values(0)) - assert(decoder.next() === values(0)) - assert(decoder.next() === values(1)) - assert(decoder.next() === values(1)) + Array(0, 0, 1, 1).foreach { i => + expectResult(values(i), "Wrong decoded value")(decoder.next()) + } + assert(!decoder.hasNext) } - test(s"$schemeName with $typeName: run length == 1") { + test(s"$RunLengthEncoding with $typeName: run length == 1") { // ------------- // Tests encoder // ------------- @@ -102,17 +100,16 @@ class RunLengthEncodingSuite extends FunSuite { // 4 bytes each run for run length val compressedSize = values.map(columnType.actualSize(_) + 4).sum // 4 bytes for compression scheme type ID - assert(buffer.capacity === headerSize + 4 + compressedSize) + expectResult(headerSize + 4 + compressedSize, "Wrong buffer capacity")(buffer.capacity) // Skips column header buffer.position(headerSize) - // Checks compression scheme ID - assert(buffer.getInt() === RunLengthEncoding.typeId) + expectResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) - assert(columnType.extract(buffer) === values(0)) - assert(buffer.getInt() === 1) - assert(columnType.extract(buffer) === values(1)) - assert(buffer.getInt() === 1) + Array(0, 1).foreach { i => + expectResult(values(i), "Wrong column element value")(columnType.extract(buffer)) + expectResult(1, "Wrong run length")(buffer.getInt()) + } // ------------- // Tests decoder @@ -123,8 +120,10 @@ class RunLengthEncodingSuite extends FunSuite { val decoder = new RunLengthEncoding.Decoder[T](buffer, columnType) - assert(decoder.next() === values(0)) - assert(decoder.next() === values(1)) + Array(0, 1).foreach { i => + expectResult(values(i), "Wrong decoded value")(decoder.next()) + } + assert(!decoder.hasNext) } }
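
The recurring change in the test suites above replaces bare `assert(a === b)` calls with ScalaTest's `expectResult(expected, clue)(actual)`, so that a failure reports the expected value, the actual value, and a descriptive clue rather than a bare equality failure. Below is a minimal sketch contrasting the two styles; it is not part of the patch, the suite name is made up for illustration, and it assumes a ScalaTest version that still provides `expectResult` (later releases rename it to `assertResult`).

import java.nio.ByteBuffer

import org.scalatest.FunSuite

// Illustrative only: contrasts the two assertion styles used in the suites above.
class AssertionStyleSketch extends FunSuite {
  test("assert vs. expectResult") {
    val buffer = ByteBuffer.allocate(4).putInt(42)

    buffer.rewind()
    // Plain assert: on failure the report only says the equality did not hold.
    assert(buffer.getInt() === 42)

    buffer.rewind()
    // expectResult: on failure the report includes the expected value, the actual
    // value, and the supplied clue string, which is why the patch prefers it.
    expectResult(42, "Wrong column element value") {
      buffer.getInt()
    }
  }
}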