[SPARK-8777] [SQL] Add random data generator test utilities to Spark SQL
This commit adds a set of random data generation utilities to Spark SQL, for use in its own unit tests.

- `RandomDataGenerator.forType(DataType)` returns an `Option[() => Any]` that, if defined, contains a function for generating random values for the given DataType.  The random values use the external representations for the given DataType (for example, for DateType we return `java.sql.Date` instances instead of longs).
- `DataTypeTestUtils` defines some convenience fields for looping over instances of data types.  For example, `numericTypes` holds `DataType` instances for all supported numeric types.  These constants will help us to raise the level of abstraction in our tests.  For example, it's now very easy to write a test which is parameterized by all common data types.
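
To make the shape of this API concrete, here is a minimal, Spark-free sketch of the same idea: a lookup that returns `Some(generator)` for supported types and `None` otherwise, with generators producing external representations such as `java.sql.Date`. The names `ToyType`, `ToyInt`, `ToyDate`, `ToyUnsupported`, and `ToyGenerators` are hypothetical stand-ins invented for this illustration; they are not part of Spark or of this commit.

```scala
import java.sql.Date
import scala.util.Random

// Hypothetical stand-ins for Spark SQL's DataType hierarchy; these are not Spark classes.
sealed trait ToyType
case object ToyInt extends ToyType
case object ToyDate extends ToyType
case object ToyUnsupported extends ToyType

object ToyGenerators {
  // Mirrors the shape of forType: Some(generator) when supported, None otherwise.
  def forType(t: ToyType, seed: Long): Option[() => Any] = {
    val rand = new Random(seed)
    t match {
      case ToyInt => Some(() => rand.nextInt())
      // External representation: a java.sql.Date rather than a raw number.
      case ToyDate => Some(() => new Date(rand.nextInt().toLong))
      case ToyUnsupported => None
    }
  }

  // One loop covers every type in the collection and skips unsupported ones.
  def sampleAll(types: Seq[ToyType], seed: Long): Map[ToyType, Any] =
    types.flatMap(t => forType(t, seed).map(gen => t -> gen())).toMap
}
```

A test parameterized over a collection of types then reduces to iterating that collection and skipping (or failing on) the types for which no generator is defined.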

Author: Josh Rosen <joshrosen@databricks.com>

Closes apache#7176 from JoshRosen/sql-random-data-generators and squashes the following commits:

f71634d [Josh Rosen] Roll back ScalaCheck usage
e0d7d49 [Josh Rosen] Bump ScalaCheck version in LICENSE
89d86b1 [Josh Rosen] Bump ScalaCheck version.
0c20905 [Josh Rosen] Initial attempt at using ScalaCheck.
b55875a [Josh Rosen] Generate doubles and floats over entire possible range.
5acdd5c [Josh Rosen] Infinity and NaN are interesting.
ab76cbd [Josh Rosen] Move code to Catalyst package.
d2b4a4a [Josh Rosen] Add random data generator test utilities to Spark SQL.
JoshRosen authored and rxin committed Jul 4, 2015
1 parent 9fb6b83 commit f32487b
Showing 3 changed files with 319 additions and 0 deletions.
@@ -0,0 +1,158 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.lang.Double.longBitsToDouble
import java.lang.Float.intBitsToFloat
import java.math.MathContext

import scala.util.Random

import org.apache.spark.sql.types._

/**
 * Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random
 * values; instead, they're biased to return "interesting" values (such as maximum / minimum values)
 * with higher probability.
 */
object RandomDataGenerator {

  /**
   * The conditional probability of a non-null value being drawn from a set of "interesting" values
   * instead of being chosen uniformly at random.
   */
  private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.5f

  /**
   * The probability of the generated value being null
   */
  private val PROBABILITY_OF_NULL: Float = 0.1f

  private val MAX_STR_LEN: Int = 1024
  private val MAX_ARR_SIZE: Int = 128
  private val MAX_MAP_SIZE: Int = 128

  /**
   * Helper function for constructing a biased random number generator which returns "interesting"
   * values with a higher probability.
   */
  private def randomNumeric[T](
      rand: Random,
      uniformRand: Random => T,
      interestingValues: Seq[T]): Some[() => T] = {
    val f = () => {
      if (rand.nextFloat() <= PROBABILITY_OF_INTERESTING_VALUE) {
        interestingValues(rand.nextInt(interestingValues.length))
      } else {
        uniformRand(rand)
      }
    }
    Some(f)
  }

  /**
   * Returns a function which generates random values for the given [[DataType]], or `None` if no
   * random data generator is defined for that data type. The generated values will use an external
   * representation of the data type; for example, the random generator for [[DateType]] will return
   * instances of [[java.sql.Date]] and the generator for [[StructType]] will return a
   * [[org.apache.spark.sql.Row]].
   *
   * @param dataType the type to generate values for
   * @param nullable whether null values should be generated
   * @param seed an optional seed for the random number generator
   * @return a function which can be called to generate random values.
   */
  def forType(
      dataType: DataType,
      nullable: Boolean = true,
      seed: Option[Long] = None): Option[() => Any] = {
    val rand = new Random()
    seed.foreach(rand.setSeed)

    val valueGenerator: Option[() => Any] = dataType match {
      case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN)))
      case BinaryType => Some(() => {
        val arr = new Array[Byte](rand.nextInt(MAX_STR_LEN))
        rand.nextBytes(arr)
        arr
      })
      case BooleanType => Some(() => rand.nextBoolean())
      case DateType => Some(() => new java.sql.Date(rand.nextInt()))
      case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong()))
      case DecimalType.Unlimited => Some(
        () => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED))
      case DoubleType => randomNumeric[Double](
        rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue,
          Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0))
      case FloatType => randomNumeric[Float](
        rand, r => intBitsToFloat(r.nextInt()), Seq(Float.MinValue, Float.MinPositiveValue,
          Float.MaxValue, Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN, 0.0f))
      case ByteType => randomNumeric[Byte](
        rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte))
      case IntegerType => randomNumeric[Int](
        rand, _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0))
      case LongType => randomNumeric[Long](
        rand, _.nextLong(), Seq(Long.MinValue, Long.MaxValue, 0L))
      case ShortType => randomNumeric[Short](
        rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort))
      case NullType => Some(() => null)
      case ArrayType(elementType, containsNull) => {
        forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map {
          elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
        }
      }
      case MapType(keyType, valueType, valueContainsNull) => {
        for (
          keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong()));
          valueGenerator <-
            forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong()))
        ) yield {
          () => {
            Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap
          }
        }
      }
      case StructType(fields) => {
        val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field =>
          forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong()))
        }
        if (maybeFieldGenerators.forall(_.isDefined)) {
          val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get)
          Some(() => Row.fromSeq(fieldGenerators.map(_.apply())))
        } else {
          None
        }
      }
      case unsupportedType => None
    }
    // Handle nullability by wrapping the non-null value generator:
    valueGenerator.map { valueGenerator =>
      if (nullable) {
        () => {
          if (rand.nextFloat() <= PROBABILITY_OF_NULL) {
            null
          } else {
            valueGenerator()
          }
        }
      } else {
        valueGenerator
      }
    }
  }
}
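
The two-stage sampling used by `RandomDataGenerator` (a null check wrapped around a choice between "interesting" and uniformly drawn values) can be sketched without Spark. This is an illustration under stated assumptions, not the committed code: `BiasedSketch`, `biased`, `withNulls`, and `nullableInts` are hypothetical names, and the probabilities are passed as parameters instead of being read from the private constants.

```scala
import scala.util.Random

object BiasedSketch {
  // Picks from `interesting` with probability pInteresting, otherwise defers to `uniform`.
  def biased[T](
      rand: Random,
      uniform: Random => T,
      interesting: Seq[T],
      pInteresting: Float): () => T =
    () =>
      if (rand.nextFloat() <= pInteresting) interesting(rand.nextInt(interesting.length))
      else uniform(rand)

  // Wraps a generator so that it returns null with probability pNull.
  def withNulls(rand: Random, gen: () => Any, pNull: Float): () => Any =
    () => if (rand.nextFloat() <= pNull) null else gen()

  // Builds a nullable Int generator from a seed. The 0.5 and 0.1 values are
  // illustrative assumptions that echo the constants in the file above.
  def nullableInts(seed: Long): () => Any = {
    val rand = new Random(seed)
    val ints = biased[Int](rand, _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0), 0.5f)
    withNulls(rand, () => ints(), 0.1f)
  }
}
```

Because every draw goes through one seeded `Random`, two generators built from the same seed produce the same sequence, which is what makes a failing randomized test reproducible once its seed is known.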
@@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._

/**
 * Tests of [[RandomDataGenerator]].
 */
class RandomDataGeneratorSuite extends SparkFunSuite {

  /**
   * Tests random data generation for the given type by using it to generate random values then
   * converting those values into their Catalyst equivalents using CatalystTypeConverters.
   */
  def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = {
    val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType)
    val generator = RandomDataGenerator.forType(dataType, nullable).getOrElse {
      fail(s"Random data generator was not defined for $dataType")
    }
    if (nullable) {
      assert(Iterator.fill(100)(generator()).contains(null))
    } else {
      assert(Iterator.fill(100)(generator()).forall(_ != null))
    }
    for (_ <- 1 to 10) {
      val generatedValue = generator()
      toCatalyst(generatedValue)
    }
  }

  // Basic types:
  for (
    dataType <- DataTypeTestUtils.atomicTypes;
    nullable <- Seq(true, false)
    if !dataType.isInstanceOf[DecimalType] ||
      dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty
  ) {
    test(s"$dataType (nullable=$nullable)") {
      // Pass `nullable` through so that the nullable=false case is actually exercised.
      testRandomDataGeneration(dataType, nullable)
    }
  }

  for (
    arrayType <- DataTypeTestUtils.atomicArrayTypes
    if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined
  ) {
    test(s"$arrayType") {
      testRandomDataGeneration(arrayType)
    }
  }

  val atomicTypesWithDataGenerators =
    DataTypeTestUtils.atomicTypes.filter(RandomDataGenerator.forType(_).isDefined)

  // Complex types:
  for (
    keyType <- atomicTypesWithDataGenerators;
    valueType <- atomicTypesWithDataGenerators
    // Scala's BigDecimal.hashCode can lead to OutOfMemoryError on Scala 2.10 (see SI-6173) and
    // Spark can hit NumberFormatException errors when converting certain BigDecimals (SPARK-8802).
    // For these reasons, we don't support generation of maps with decimal keys.
    if !keyType.isInstanceOf[DecimalType]
  ) {
    val mapType = MapType(keyType, valueType)
    test(s"$mapType") {
      testRandomDataGeneration(mapType)
    }
  }

  for (
    colOneType <- atomicTypesWithDataGenerators;
    colTwoType <- atomicTypesWithDataGenerators
  ) {
    val structType = StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil)
    test(s"$structType") {
      testRandomDataGeneration(structType)
    }
  }

}
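
The nullable branch of `testRandomDataGeneration` asserts that 100 draws contain at least one null, and the suite does not pass a seed, so the assertion can in principle fail by chance. The short calculation below estimates how often, assuming each draw is independently null with probability 0.1 (the value of `PROBABILITY_OF_NULL`). `NullCheckOdds` and `probabilityOfNoNulls` are hypothetical names used only for this sketch.

```scala
object NullCheckOdds {
  // Probability that `draws` independent samples contain no null at all,
  // assuming each sample is null with probability pNull.
  def probabilityOfNoNulls(pNull: Double, draws: Int): Double =
    math.pow(1.0 - pNull, draws.toDouble)

  def main(args: Array[String]): Unit = {
    val p = probabilityOfNoNulls(0.1, 100)
    println(f"chance that 100 draws contain no null: $p%.2e")
  }
}
```

Under that assumption the chance is 0.9^100, roughly 2.7 in 100,000 per test case, so spurious failures should be rare but are not impossible.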
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.types

/**
 * Utility functions for working with DataTypes in tests.
 */
object DataTypeTestUtils {

  /**
   * Instances of all [[IntegralType]]s.
   */
  val integralType: Set[IntegralType] = Set(
    ByteType, ShortType, IntegerType, LongType
  )

  /**
   * Instances of all [[FractionalType]]s, including both fixed- and unlimited-precision
   * decimal types.
   */
  val fractionalTypes: Set[FractionalType] = Set(
    DecimalType(precisionInfo = None),
    DecimalType(2, 1),
    DoubleType,
    FloatType
  )

  /**
   * Instances of all [[NumericType]]s.
   */
  val numericTypes: Set[NumericType] = integralType ++ fractionalTypes

  /**
   * Instances of all [[AtomicType]]s.
   */
  val atomicTypes: Set[DataType] = numericTypes ++ Set(
    BinaryType,
    BooleanType,
    DateType,
    StringType,
    TimestampType
  )

  /**
   * Instances of [[ArrayType]] for all [[AtomicType]]s. Arrays of these types may contain null.
   */
  val atomicArrayTypes: Set[ArrayType] = atomicTypes.map(ArrayType(_, containsNull = true))
}
