Skip to content

Commit

Permalink
Add random data generator test utilities to Spark SQL.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Jul 2, 2015
1 parent 3a342de commit d2b4a4a
Show file tree
Hide file tree
Showing 3 changed files with 287 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.test

import org.apache.spark.sql.types._

/**
* Utility functions for working with DataTypes in tests.
*/
object DataTypeTestUtils {

/**
* Instances of all [[IntegralType]]s.
*/
val integralType: Set[IntegralType] = Set(
ByteType, ShortType, IntegerType, LongType
)

/**
* Instances of all [[FractionalType]]s, including both fixed- and unlimited-precision
* decimal types.
*/
val fractionalTypes: Set[FractionalType] = Set(
DecimalType(precisionInfo = None),
DecimalType(2, 1),
DoubleType,
FloatType
)

/**
* Instances of all [[NumericType]]s.
*/
val numericTypes: Set[NumericType] = integralType ++ fractionalTypes

/**
* Instances of all [[AtomicType]]s.
*/
val atomicTypes: Set[DataType] = Set(BinaryType, StringType, TimestampType) ++ numericTypes

/**
* Instances of [[ArrayType]] for all [[AtomicType]]s. Arrays of these types may contain null.
*/
val atomicArrayTypes: Set[ArrayType] = atomicTypes.map(ArrayType(_, containsNull = true))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.test

import org.apache.spark.sql.Row

import scala.util.Random

import org.apache.spark.sql.types._

/**
* Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random
* values; instead, they're biased to return "interesting" values (such as maximum / minimum values)
* with higher probability.
*/
object RandomDataGenerator {

/**
* The conditional probability of a non-null value being drawn from a set of "interesting" values
* instead of being chosen uniformly at random.
*/
private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.25f

/**
* The probability of the generated value being null
*/
private val PROBABILITY_OF_NULL: Float = 0.1f

private val MAX_STR_LEN: Int = 1024
private val MAX_ARR_SIZE: Int = 128
private val MAX_MAP_SIZE: Int = 128

/**
* Helper function for constructing a biased random number generator which returns "interesting"
* values with a higher probability.
*/
private def randomNumeric[T](
rand: Random,
uniformRand: Random => T,
interestingValues: Seq[T]): Some[() => T] = {
val f = () => {
if (rand.nextFloat() <= PROBABILITY_OF_INTERESTING_VALUE) {
interestingValues(rand.nextInt(interestingValues.length))
} else {
uniformRand(rand)
}
}
Some(f)
}

/**
* Returns a function which generates random values for the given [[DataType]], or `None` if no
* random data generator is defined for that data type. The generated values will use an external
* representation of the data type; for example, the random generator for [[DateType]] will return
* instances of [[java.sql.Date]] and the generator for [[StructType]] will return a
* [[org.apache.spark.Row]].
*
* @param dataType the type to generate values for
* @param nullable whether null values should be generated
* @param seed an optional seed for the random number generator
* @return a function which can be called to generate random values.
*/
def forType(
dataType: DataType,
nullable: Boolean = true,
seed: Option[Long] = None): Option[() => Any] = {
val rand = new Random()
seed.foreach(rand.setSeed)

val valueGenerator: Option[() => Any] = dataType match {
case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN)))
case BinaryType => Some(() => {
val arr = new Array[Byte](rand.nextInt(MAX_STR_LEN))
rand.nextBytes(arr)
arr
})
case BooleanType => Some(() => rand.nextBoolean())
case DateType => Some(() => new java.sql.Date(rand.nextInt(Int.MaxValue)))
case DoubleType => randomNumeric[Double](
rand, _.nextDouble(), Seq(Double.MinValue, Double.MinPositiveValue, Double.MaxValue, 0.0))
case FloatType => randomNumeric[Float](
rand, _.nextFloat(), Seq(Float.MinValue, Float.MinPositiveValue, Float.MaxValue, 0.0f))
case ByteType => randomNumeric[Byte](
rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte))
case IntegerType => randomNumeric[Int](
rand, _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0))
case LongType => randomNumeric[Long](
rand, _.nextLong(), Seq(Long.MinValue, Long.MaxValue, 0L))
case ShortType => randomNumeric[Short](
rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort))
case NullType => Some(() => null)
case ArrayType(elementType, containsNull) => {
forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map {
elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator())
}
}
case MapType(keyType, valueType, valueContainsNull) => {
for (
keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong()));
valueGenerator <-
forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong()))
) yield {
() => {
Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap
}
}
}
case StructType(fields) => {
val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field =>
forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong()))
}
if (maybeFieldGenerators.forall(_.isDefined)) {
val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get)
Some(() => Row.fromSeq(fieldGenerators.map(_.apply())))
} else {
None
}
}
case unsupportedType => None
}
// Handle nullability by wrapping the non-null value generator:
valueGenerator.map { valueGenerator =>
if (nullable) {
() => {
if (rand.nextFloat() <= PROBABILITY_OF_NULL) {
null
} else {
valueGenerator()
}
}
} else {
valueGenerator
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.test

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types.{StructField, StructType, MapType, DataType}

/**
* Tests of [[RandomDataGenerator]].
*/
class RandomDataGeneratorSuite extends SparkFunSuite {

/**
* Tests random data generation for the given type by using it to generate random values then
* converting those values into their Catalyst equivalents using CatalystTypeConverters.
*/
def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = {
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType)
RandomDataGenerator.forType(dataType, nullable, Some(42L)).foreach { generator =>
for (_ <- 1 to 10) {
val generatedValue = generator()
val convertedValue = toCatalyst(generatedValue)
if (!nullable) {
assert(convertedValue !== null)
}
}
}

}

// Basic types:

(DataTypeTestUtils.atomicTypes ++ DataTypeTestUtils.atomicArrayTypes).foreach { dataType =>
test(s"$dataType") {
testRandomDataGeneration(dataType)
}
}

// Complex types:

for (
keyType <- DataTypeTestUtils.atomicTypes;
valueType <- DataTypeTestUtils.atomicTypes
) {
val mapType = MapType(keyType, valueType)
test(s"$mapType") {
testRandomDataGeneration(mapType)
}
}

for (
colOneType <- DataTypeTestUtils.atomicTypes;
colTwoType <- DataTypeTestUtils.atomicTypes
) {
val structType = StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil)
test(s"$structType") {
testRandomDataGeneration(structType)
}
}

}

0 comments on commit d2b4a4a

Please sign in to comment.