forked from alteryx/spark
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-8777] [SQL] Add random data generator test utilities to Spark SQL
This commit adds a set of random data generation utilities to Spark SQL, for use in its own unit tests. - `RandomDataGenerator.forType(DataType)` returns an `Option[() => Any]` that, if defined, contains a function for generating random values for the given DataType. The random values use the external representations for the given DataType (for example, for DateType we return `java.sql.Date` instances instead of longs). - `DateTypeTestUtilities` defines some convenience fields for looping over instances of data types. For example, `numericTypes` holds `DataType` instances for all supported numeric types. These constants will help us to raise the level of abstraction in our tests. For example, it's now very easy to write a test which is parameterized by all common data types. Author: Josh Rosen <joshrosen@databricks.com> Closes apache#7176 from JoshRosen/sql-random-data-generators and squashes the following commits: f71634d [Josh Rosen] Roll back ScalaCheck usage e0d7d49 [Josh Rosen] Bump ScalaCheck version in LICENSE 89d86b1 [Josh Rosen] Bump ScalaCheck version. 0c20905 [Josh Rosen] Initial attempt at using ScalaCheck. b55875a [Josh Rosen] Generate doubles and floats over entire possible range. 5acdd5c [Josh Rosen] Infinity and NaN are interesting. ab76cbd [Josh Rosen] Move code to Catalyst package. d2b4a4a [Josh Rosen] Add random data generator test utilities to Spark SQL.
- Loading branch information
Showing
3 changed files
with
319 additions
and
0 deletions.
There are no files selected for viewing
158 changes: 158 additions & 0 deletions
158
sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql | ||
|
||
import java.lang.Double.longBitsToDouble | ||
import java.lang.Float.intBitsToFloat | ||
import java.math.MathContext | ||
|
||
import scala.util.Random | ||
|
||
import org.apache.spark.sql.types._ | ||
|
||
/** | ||
* Random data generators for Spark SQL DataTypes. These generators do not generate uniformly random | ||
* values; instead, they're biased to return "interesting" values (such as maximum / minimum values) | ||
* with higher probability. | ||
*/ | ||
object RandomDataGenerator { | ||
|
||
/** | ||
* The conditional probability of a non-null value being drawn from a set of "interesting" values | ||
* instead of being chosen uniformly at random. | ||
*/ | ||
private val PROBABILITY_OF_INTERESTING_VALUE: Float = 0.5f | ||
|
||
/** | ||
* The probability of the generated value being null | ||
*/ | ||
private val PROBABILITY_OF_NULL: Float = 0.1f | ||
|
||
private val MAX_STR_LEN: Int = 1024 | ||
private val MAX_ARR_SIZE: Int = 128 | ||
private val MAX_MAP_SIZE: Int = 128 | ||
|
||
/** | ||
* Helper function for constructing a biased random number generator which returns "interesting" | ||
* values with a higher probability. | ||
*/ | ||
private def randomNumeric[T]( | ||
rand: Random, | ||
uniformRand: Random => T, | ||
interestingValues: Seq[T]): Some[() => T] = { | ||
val f = () => { | ||
if (rand.nextFloat() <= PROBABILITY_OF_INTERESTING_VALUE) { | ||
interestingValues(rand.nextInt(interestingValues.length)) | ||
} else { | ||
uniformRand(rand) | ||
} | ||
} | ||
Some(f) | ||
} | ||
|
||
/** | ||
* Returns a function which generates random values for the given [[DataType]], or `None` if no | ||
* random data generator is defined for that data type. The generated values will use an external | ||
* representation of the data type; for example, the random generator for [[DateType]] will return | ||
* instances of [[java.sql.Date]] and the generator for [[StructType]] will return a | ||
* [[org.apache.spark.Row]]. | ||
* | ||
* @param dataType the type to generate values for | ||
* @param nullable whether null values should be generated | ||
* @param seed an optional seed for the random number generator | ||
* @return a function which can be called to generate random values. | ||
*/ | ||
def forType( | ||
dataType: DataType, | ||
nullable: Boolean = true, | ||
seed: Option[Long] = None): Option[() => Any] = { | ||
val rand = new Random() | ||
seed.foreach(rand.setSeed) | ||
|
||
val valueGenerator: Option[() => Any] = dataType match { | ||
case StringType => Some(() => rand.nextString(rand.nextInt(MAX_STR_LEN))) | ||
case BinaryType => Some(() => { | ||
val arr = new Array[Byte](rand.nextInt(MAX_STR_LEN)) | ||
rand.nextBytes(arr) | ||
arr | ||
}) | ||
case BooleanType => Some(() => rand.nextBoolean()) | ||
case DateType => Some(() => new java.sql.Date(rand.nextInt())) | ||
case TimestampType => Some(() => new java.sql.Timestamp(rand.nextLong())) | ||
case DecimalType.Unlimited => Some( | ||
() => BigDecimal.apply(rand.nextLong, rand.nextInt, MathContext.UNLIMITED)) | ||
case DoubleType => randomNumeric[Double]( | ||
rand, r => longBitsToDouble(r.nextLong()), Seq(Double.MinValue, Double.MinPositiveValue, | ||
Double.MaxValue, Double.PositiveInfinity, Double.NegativeInfinity, Double.NaN, 0.0)) | ||
case FloatType => randomNumeric[Float]( | ||
rand, r => intBitsToFloat(r.nextInt()), Seq(Float.MinValue, Float.MinPositiveValue, | ||
Float.MaxValue, Float.PositiveInfinity, Float.NegativeInfinity, Float.NaN, 0.0f)) | ||
case ByteType => randomNumeric[Byte]( | ||
rand, _.nextInt().toByte, Seq(Byte.MinValue, Byte.MaxValue, 0.toByte)) | ||
case IntegerType => randomNumeric[Int]( | ||
rand, _.nextInt(), Seq(Int.MinValue, Int.MaxValue, 0)) | ||
case LongType => randomNumeric[Long]( | ||
rand, _.nextLong(), Seq(Long.MinValue, Long.MaxValue, 0L)) | ||
case ShortType => randomNumeric[Short]( | ||
rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) | ||
case NullType => Some(() => null) | ||
case ArrayType(elementType, containsNull) => { | ||
forType(elementType, nullable = containsNull, seed = Some(rand.nextLong())).map { | ||
elementGenerator => () => Array.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) | ||
} | ||
} | ||
case MapType(keyType, valueType, valueContainsNull) => { | ||
for ( | ||
keyGenerator <- forType(keyType, nullable = false, seed = Some(rand.nextLong())); | ||
valueGenerator <- | ||
forType(valueType, nullable = valueContainsNull, seed = Some(rand.nextLong())) | ||
) yield { | ||
() => { | ||
Seq.fill(rand.nextInt(MAX_MAP_SIZE))((keyGenerator(), valueGenerator())).toMap | ||
} | ||
} | ||
} | ||
case StructType(fields) => { | ||
val maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => | ||
forType(field.dataType, nullable = field.nullable, seed = Some(rand.nextLong())) | ||
} | ||
if (maybeFieldGenerators.forall(_.isDefined)) { | ||
val fieldGenerators: Seq[() => Any] = maybeFieldGenerators.map(_.get) | ||
Some(() => Row.fromSeq(fieldGenerators.map(_.apply()))) | ||
} else { | ||
None | ||
} | ||
} | ||
case unsupportedType => None | ||
} | ||
// Handle nullability by wrapping the non-null value generator: | ||
valueGenerator.map { valueGenerator => | ||
if (nullable) { | ||
() => { | ||
if (rand.nextFloat() <= PROBABILITY_OF_NULL) { | ||
null | ||
} else { | ||
valueGenerator() | ||
} | ||
} | ||
} else { | ||
valueGenerator | ||
} | ||
} | ||
} | ||
} |
98 changes: 98 additions & 0 deletions
98
sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGeneratorSuite.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql | ||
|
||
import org.apache.spark.SparkFunSuite | ||
import org.apache.spark.sql.catalyst.CatalystTypeConverters | ||
import org.apache.spark.sql.types._ | ||
|
||
/** | ||
* Tests of [[RandomDataGenerator]]. | ||
*/ | ||
class RandomDataGeneratorSuite extends SparkFunSuite { | ||
|
||
/** | ||
* Tests random data generation for the given type by using it to generate random values then | ||
* converting those values into their Catalyst equivalents using CatalystTypeConverters. | ||
*/ | ||
def testRandomDataGeneration(dataType: DataType, nullable: Boolean = true): Unit = { | ||
val toCatalyst = CatalystTypeConverters.createToCatalystConverter(dataType) | ||
val generator = RandomDataGenerator.forType(dataType, nullable).getOrElse { | ||
fail(s"Random data generator was not defined for $dataType") | ||
} | ||
if (nullable) { | ||
assert(Iterator.fill(100)(generator()).contains(null)) | ||
} else { | ||
assert(Iterator.fill(100)(generator()).forall(_ != null)) | ||
} | ||
for (_ <- 1 to 10) { | ||
val generatedValue = generator() | ||
toCatalyst(generatedValue) | ||
} | ||
} | ||
|
||
// Basic types: | ||
for ( | ||
dataType <- DataTypeTestUtils.atomicTypes; | ||
nullable <- Seq(true, false) | ||
if !dataType.isInstanceOf[DecimalType] || | ||
dataType.asInstanceOf[DecimalType].precisionInfo.isEmpty | ||
) { | ||
test(s"$dataType (nullable=$nullable)") { | ||
testRandomDataGeneration(dataType) | ||
} | ||
} | ||
|
||
for ( | ||
arrayType <- DataTypeTestUtils.atomicArrayTypes | ||
if RandomDataGenerator.forType(arrayType.elementType, arrayType.containsNull).isDefined | ||
) { | ||
test(s"$arrayType") { | ||
testRandomDataGeneration(arrayType) | ||
} | ||
} | ||
|
||
val atomicTypesWithDataGenerators = | ||
DataTypeTestUtils.atomicTypes.filter(RandomDataGenerator.forType(_).isDefined) | ||
|
||
// Complex types: | ||
for ( | ||
keyType <- atomicTypesWithDataGenerators; | ||
valueType <- atomicTypesWithDataGenerators | ||
// Scala's BigDecimal.hashCode can lead to OutOfMemoryError on Scala 2.10 (see SI-6173) and | ||
// Spark can hit NumberFormatException errors when converting certain BigDecimals (SPARK-8802). | ||
// For these reasons, we don't support generation of maps with decimal keys. | ||
if !keyType.isInstanceOf[DecimalType] | ||
) { | ||
val mapType = MapType(keyType, valueType) | ||
test(s"$mapType") { | ||
testRandomDataGeneration(mapType) | ||
} | ||
} | ||
|
||
for ( | ||
colOneType <- atomicTypesWithDataGenerators; | ||
colTwoType <- atomicTypesWithDataGenerators | ||
) { | ||
val structType = StructType(StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) | ||
test(s"$structType") { | ||
testRandomDataGeneration(structType) | ||
} | ||
} | ||
|
||
} |
63 changes: 63 additions & 0 deletions
63
sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeTestUtils.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.types | ||
|
||
/** | ||
* Utility functions for working with DataTypes in tests. | ||
*/ | ||
object DataTypeTestUtils { | ||
|
||
/** | ||
* Instances of all [[IntegralType]]s. | ||
*/ | ||
val integralType: Set[IntegralType] = Set( | ||
ByteType, ShortType, IntegerType, LongType | ||
) | ||
|
||
/** | ||
* Instances of all [[FractionalType]]s, including both fixed- and unlimited-precision | ||
* decimal types. | ||
*/ | ||
val fractionalTypes: Set[FractionalType] = Set( | ||
DecimalType(precisionInfo = None), | ||
DecimalType(2, 1), | ||
DoubleType, | ||
FloatType | ||
) | ||
|
||
/** | ||
* Instances of all [[NumericType]]s. | ||
*/ | ||
val numericTypes: Set[NumericType] = integralType ++ fractionalTypes | ||
|
||
/** | ||
* Instances of all [[AtomicType]]s. | ||
*/ | ||
val atomicTypes: Set[DataType] = numericTypes ++ Set( | ||
BinaryType, | ||
BooleanType, | ||
DateType, | ||
StringType, | ||
TimestampType | ||
) | ||
|
||
/** | ||
* Instances of [[ArrayType]] for all [[AtomicType]]s. Arrays of these types may contain null. | ||
*/ | ||
val atomicArrayTypes: Set[ArrayType] = atomicTypes.map(ArrayType(_, containsNull = true)) | ||
} |