From 16be3e56e83a406af86b9e7f18059ec7a2595a9e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 9 Jul 2014 11:51:24 -0700 Subject: [PATCH 01/34] This commit contains three changes: * Expose `DataType`s in the sql package (internal details are private to sql). * Introduce `createSchemaRDD` to create a `SchemaRDD` from an `RDD` with a provided schema (represented by a `StructType`) and a provided function to construct `Row`, * Add a function `simpleString` to every `DataType`. Also, the schema represented by a `StructType` can be visualized by `printSchema`. --- .../spark/sql/catalyst/expressions/Row.scala | 5 + .../catalyst/expressions/WrapDynamic.scala | 4 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 47 +--- .../spark/sql/catalyst/types/dataTypes.scala | 206 +++++++++++++----- .../scala/org/apache/spark/sql/package.scala | 44 ++++ .../org/apache/spark/sql/SQLContext.scala | 12 +- .../org/apache/spark/sql/SchemaRDD.scala | 1 - .../org/apache/spark/sql/SchemaRDDLike.scala | 7 +- .../org/apache/spark/sql/json/JsonRDD.scala | 1 + .../apache/spark/sql/hive/HiveContext.scala | 1 - 10 files changed, 221 insertions(+), 107 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index 74ae723686cfe..ff10e198a3cee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -32,6 +32,11 @@ object Row { * }}} */ def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) + + /** + * Construct a [[Row]] with the given values. + */ + def apply(values: Any*): Row = new GenericRow(values.toArray) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index e787c59e75723..c7f8e383ec868 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -21,7 +21,9 @@ import scala.language.dynamics import org.apache.spark.sql.catalyst.types.DataType -case object DynamicType extends DataType +case object DynamicType extends DataType { + def simpleString: String = "dynamic" +} case class WrapDynamic(children: Seq[Attribute]) extends Expression { type EvaluatedType = DynamicRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7b82e19b2e714..c6589d68100ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -125,52 +125,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy }.toSeq } - protected def generateSchemaString(schema: Seq[Attribute]): String = { - val builder = new StringBuilder - builder.append("root\n") - val prefix = " |" - schema.foreach { attribute => - val name = attribute.name - val dataType = attribute.dataType - dataType match { - case fields: StructType => - builder.append(s"$prefix-- $name: $StructType\n") - generateSchemaString(fields, s"$prefix |", builder) - case ArrayType(fields: StructType) => - builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") - 
generateSchemaString(fields, s"$prefix |", builder) - case ArrayType(elementType: DataType) => - builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") - case _ => builder.append(s"$prefix-- $name: $dataType\n") - } - } - - builder.toString() - } - - protected def generateSchemaString( - schema: StructType, - prefix: String, - builder: StringBuilder): StringBuilder = { - schema.fields.foreach { - case StructField(name, fields: StructType, _) => - builder.append(s"$prefix-- $name: $StructType\n") - generateSchemaString(fields, s"$prefix |", builder) - case StructField(name, ArrayType(fields: StructType), _) => - builder.append(s"$prefix-- $name: $ArrayType[$StructType]\n") - generateSchemaString(fields, s"$prefix |", builder) - case StructField(name, ArrayType(elementType: DataType), _) => - builder.append(s"$prefix-- $name: $ArrayType[$elementType]\n") - case StructField(name, fieldType: DataType, _) => - builder.append(s"$prefix-- $name: $fieldType\n") - } - - builder - } + def schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. */ - def schemaString: String = generateSchemaString(output) + def formattedSchemaString: String = schema.formattedSchemaString /** Prints out the schema in the tree format */ - def printSchema(): Unit = println(schemaString) + def printSchema(): Unit = println(formattedSchemaString) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index bb77bccf86176..fd67e37b40694 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -62,7 +62,6 @@ object DataType extends RegexParsers { "true" ^^^ true | "false" ^^^ false - protected lazy val structType: Parser[DataType] = "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { case fields => new StructType(fields) @@ -93,47 +92,56 @@ abstract class DataType { } def isPrimitive: Boolean = false + + def simpleString: String } -case object NullType extends DataType +case object NullType extends DataType { + def simpleString: String = "null" +} trait PrimitiveType extends DataType { override def isPrimitive = true } abstract class NativeType extends DataType { - type JvmType - @transient val tag: TypeTag[JvmType] - val ordering: Ordering[JvmType] + private[sql] type JvmType + @transient private[sql] val tag: TypeTag[JvmType] + private[sql] val ordering: Ordering[JvmType] - @transient val classTag = { + @transient private[sql] val classTag = { val mirror = runtimeMirror(Utils.getSparkClassLoader) ClassTag[JvmType](mirror.runtimeClass(tag.tpe)) } } case object StringType extends NativeType with PrimitiveType { - type JvmType = String - @transient lazy val tag = typeTag[JvmType] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = String + @transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "string" } case object BinaryType extends DataType with PrimitiveType { - type JvmType = Array[Byte] + private[sql] type JvmType = Array[Byte] + def simpleString: String = "binary" } case object BooleanType extends NativeType with PrimitiveType { - type JvmType = Boolean - @transient lazy val tag = typeTag[JvmType] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Boolean + 
@transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "boolean" } case object TimestampType extends NativeType { - type JvmType = Timestamp + private[sql] type JvmType = Timestamp - @transient lazy val tag = typeTag[JvmType] + @transient private[sql] lazy val tag = typeTag[JvmType] - val ordering = new Ordering[JvmType] { + private[sql] val ordering = new Ordering[JvmType] { def compare(x: Timestamp, y: Timestamp) = x.compareTo(y) } + + def simpleString: String = "timestamp" } abstract class NumericType extends NativeType with PrimitiveType { @@ -142,7 +150,7 @@ abstract class NumericType extends NativeType with PrimitiveType { // type parameter and and add a numeric annotation (i.e., [JvmType : Numeric]). This gets // desugared by the compiler into an argument to the objects constructor. This means there is no // longer an no argument constructor and thus the JVM cannot serialize the object anymore. - val numeric: Numeric[JvmType] + private[sql] val numeric: Numeric[JvmType] } /** Matcher for any expressions that evaluate to [[IntegralType]]s */ @@ -154,39 +162,43 @@ object IntegralType { } abstract class IntegralType extends NumericType { - val integral: Integral[JvmType] + private[sql] val integral: Integral[JvmType] } case object LongType extends IntegralType { - type JvmType = Long - @transient lazy val tag = typeTag[JvmType] - val numeric = implicitly[Numeric[Long]] - val integral = implicitly[Integral[Long]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Long + @transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val numeric = implicitly[Numeric[Long]] + private[sql] val integral = implicitly[Integral[Long]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "long" } case object IntegerType extends IntegralType { - type JvmType = Int - @transient lazy val tag = typeTag[JvmType] - val numeric = implicitly[Numeric[Int]] - val integral = implicitly[Integral[Int]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Int + @transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val numeric = implicitly[Numeric[Int]] + private[sql] val integral = implicitly[Integral[Int]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "integer" } case object ShortType extends IntegralType { - type JvmType = Short - @transient lazy val tag = typeTag[JvmType] - val numeric = implicitly[Numeric[Short]] - val integral = implicitly[Integral[Short]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Short + @transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val numeric = implicitly[Numeric[Short]] + private[sql] val integral = implicitly[Integral[Short]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "short" } case object ByteType extends IntegralType { - type JvmType = Byte - @transient lazy val tag = typeTag[JvmType] - val numeric = implicitly[Numeric[Byte]] - val integral = implicitly[Integral[Byte]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Byte + @transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val numeric = implicitly[Numeric[Byte]] + private[sql] val integral = implicitly[Integral[Byte]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "byte" } /** Matcher for any 
expressions that evaluate to [[FractionalType]]s */ @@ -197,47 +209,127 @@ object FractionalType { } } abstract class FractionalType extends NumericType { - val fractional: Fractional[JvmType] + private[sql] val fractional: Fractional[JvmType] } case object DecimalType extends FractionalType { - type JvmType = BigDecimal - @transient lazy val tag = typeTag[JvmType] - val numeric = implicitly[Numeric[BigDecimal]] - val fractional = implicitly[Fractional[BigDecimal]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = BigDecimal + @transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val numeric = implicitly[Numeric[BigDecimal]] + private[sql] val fractional = implicitly[Fractional[BigDecimal]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "decimal" } case object DoubleType extends FractionalType { - type JvmType = Double - @transient lazy val tag = typeTag[JvmType] - val numeric = implicitly[Numeric[Double]] - val fractional = implicitly[Fractional[Double]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Double + @transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val numeric = implicitly[Numeric[Double]] + private[sql] val fractional = implicitly[Fractional[Double]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "double" } case object FloatType extends FractionalType { - type JvmType = Float - @transient lazy val tag = typeTag[JvmType] - val numeric = implicitly[Numeric[Float]] - val fractional = implicitly[Fractional[Float]] - val ordering = implicitly[Ordering[JvmType]] + private[sql] type JvmType = Float + @transient private[sql] lazy val tag = typeTag[JvmType] + private[sql] val numeric = implicitly[Numeric[Float]] + private[sql] val fractional = implicitly[Fractional[Float]] + private[sql] val ordering = implicitly[Ordering[JvmType]] + def simpleString: String = "float" } -case class ArrayType(elementType: DataType) extends DataType +case class ArrayType(elementType: DataType) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- element: ${elementType.simpleString}\n") + elementType match { + case array: ArrayType => + array.buildFormattedString(s"$prefix |", builder) + case struct: StructType => + struct.buildFormattedString(s"$prefix |", builder) + case map: MapType => + map.buildFormattedString(s"$prefix |", builder) + case _ => + } + } -case class StructField(name: String, dataType: DataType, nullable: Boolean) + def simpleString: String = "array" +} + +case class StructField(name: String, dataType: DataType, nullable: Boolean) { + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") + dataType match { + case array: ArrayType => + array.buildFormattedString(s"$prefix |", builder) + case struct: StructType => + struct.buildFormattedString(s"$prefix |", builder) + case map: MapType => + map.buildFormattedString(s"$prefix |", builder) + case _ => + } + } +} object StructType { def fromAttributes(attributes: Seq[Attribute]): StructType = { StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) } + private def validateFields(fields: Seq[StructField]): Boolean = + fields.map(field => field.name).distinct.size == fields.size + // def apply(fields: Seq[StructField]) = new 
StructType(fields.toIndexedSeq) } case class StructType(fields: Seq[StructField]) extends DataType { + require(StructType.validateFields(fields), "Found fields with the same name.") + def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + + def formattedSchemaString: String = { + val builder = new StringBuilder + builder.append("root\n") + val prefix = " |" + fields.foreach(field => field.buildFormattedString(prefix, builder)) + + builder.toString() + } + + def printSchema(): Unit = println(formattedSchemaString) + + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + fields.foreach(field => field.buildFormattedString(prefix, builder)) + } + + def simpleString: String = "struct" } -case class MapType(keyType: DataType, valueType: DataType) extends DataType +case class MapType(keyType: DataType, valueType: DataType) extends DataType { + private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { + builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") + keyType match { + case array: ArrayType => + array.buildFormattedString(s"$prefix |", builder) + case struct: StructType => + struct.buildFormattedString(s"$prefix |", builder) + case map: MapType => + map.buildFormattedString(s"$prefix |", builder) + case _ => + } + + builder.append(s"${prefix}-- value: ${valueType.simpleString}\n") + valueType match { + case array: ArrayType => + array.buildFormattedString(s"$prefix |", builder) + case struct: StructType => + struct.buildFormattedString(s"$prefix |", builder) + case map: MapType => + map.buildFormattedString(s"$prefix |", builder) + case _ => + } + } + + def simpleString: String = "map" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala index 4589129cd1c90..2099804073c08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala @@ -33,4 +33,48 @@ package object sql { type Row = catalyst.expressions.Row val Row = catalyst.expressions.Row + + type DataType = catalyst.types.DataType + + val DataType = catalyst.types.DataType + + val NullType = catalyst.types.NullType + + val StringType = catalyst.types.StringType + + val BinaryType = catalyst.types.BinaryType + + val BooleanType = catalyst.types.BooleanType + + val TimestampType = catalyst.types.TimestampType + + val DecimalType = catalyst.types.DecimalType + + val DoubleType = catalyst.types.DoubleType + + val FloatType = catalyst.types.FloatType + + val ByteType = catalyst.types.ByteType + + val IntegerType = catalyst.types.IntegerType + + val LongType = catalyst.types.LongType + + val ShortType = catalyst.types.ShortType + + type ArrayType = catalyst.types.ArrayType + + val ArrayType = catalyst.types.ArrayType + + type MapType = catalyst.types.MapType + + val MapType = catalyst.types.MapType + + type StructType = catalyst.types.StructType + + val StructType = catalyst.types.StructType + + type StructField = catalyst.types.StructField + + val StructField = catalyst.types.StructField } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 4abd89955bd27..14904d416eae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import 
scala.language.implicitConversions +import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration @@ -28,7 +29,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -88,6 +88,16 @@ class SQLContext(@transient val sparkContext: SparkContext) implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) + /** + * Creates a SchemaRDD from an RDD by applying a schema and providing a function to construct + * a Row from a RDD record. + * + * @group userf + */ + def createSchemaRDD[A](rdd: RDD[A], schema: StructType, constructRow: A => Row) = { + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rdd.map(constructRow)))) + } + /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 8bcfc7c064c2f..304f17f4f5055 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.types.{ArrayType, BooleanType, StructType} import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.api.java.JavaRDD diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index fe81721943202..9737e62c0f839 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -123,9 +123,12 @@ private[sql] trait SchemaRDDLike { def saveAsTable(tableName: String): Unit = sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd + /** Returns the schema. */ + def schema: StructType = queryExecution.analyzed.schema + /** Returns the output schema in the tree format. */ - def schemaString: String = queryExecution.analyzed.schemaString + def formattedSchemaString: String = schema.formattedSchemaString /** Prints out the schema in the tree format. */ - def printSchema(): Unit = println(schemaString) + def printSchema(): Unit = println(formattedSchemaString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index f6cbca96483e2..6ffa514becd0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -344,6 +344,7 @@ private[sql] object JsonRDD extends Logging { } } + // TODO: Reuse the row instead of creating a new one for every record. 
private def asRow(json: Map[String,Any], schema: StructType): Row = { val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7aedfcd74189b..02a3dee67b464 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -35,7 +35,6 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.execution.{Command => PhysicalCommand} import org.apache.spark.sql.hive.execution.DescribeHiveTableCommand From 3fa0df5e888982c5a68240bd9ae139745a567a3a Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 9 Jul 2014 22:03:32 -0700 Subject: [PATCH 02/34] Provide easier ways to construct a StructType. --- .../apache/spark/sql/catalyst/types/dataTypes.scala | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index fd67e37b40694..5924eb3b063a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -273,14 +273,18 @@ case class StructField(name: String, dataType: DataType, nullable: Boolean) { } object StructType { - def fromAttributes(attributes: Seq[Attribute]): StructType = { + def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) - } private def validateFields(fields: Seq[StructField]): Boolean = fields.map(field => field.name).distinct.size == fields.size - // def apply(fields: Seq[StructField]) = new StructType(fields.toIndexedSeq) + def apply[A <: String: ClassTag, B <: DataType: ClassTag](fields: (A, B)*): StructType = + StructType(fields.map(field => StructField(field._1, field._2, true))) + + def apply[A <: String: ClassTag, B <: DataType: ClassTag, C <: Boolean: ClassTag]( + fields: (A, B, C)*): StructType = + StructType(fields.map(field => StructField(field._1, field._2, field._3))) } case class StructType(fields: Seq[StructField]) extends DataType { From 90460acf7aad1a5535e0c34f7f165f961c45b4d7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 10 Jul 2014 15:24:03 -0700 Subject: [PATCH 03/34] Infer the Catalyst data type from an object and cast a data value to the expected type. 
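The new `typeOfObject` partial function is meant to be composed with user-defined
rules via `orElse` when the built-in mapping is not enough. A minimal sketch,
mirroring the cases added to ScalaReflectionSuite in this patch (the helper name
`myTypeOfObject` is illustrative, not part of the change):

  import org.apache.spark.sql.catalyst.ScalaReflection.typeOfObject
  import org.apache.spark.sql.catalyst.types._

  // Fall back to user-defined rules for values the built-in inference does not
  // cover (there is no BigInteger-backed type; nested Seqs become ArrayTypes).
  def myTypeOfObject: PartialFunction[Any, DataType] = typeOfObject orElse {
    case _: java.math.BigInteger => DecimalType
    case c: Seq[_] => ArrayType(myTypeOfObject(c.head))
  }

  myTypeOfObject(1)             // IntegerType
  myTypeOfObject(Seq(1, 2, 3))  // ArrayType(IntegerType)
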
--- .../spark/sql/catalyst/ScalaReflection.scala | 20 ++ .../spark/sql/catalyst/plans/QueryPlan.scala | 4 +- .../spark/sql/catalyst/types/dataTypes.scala | 6 +- .../sql/catalyst/ScalaReflectionSuite.scala | 66 ++++++- .../org/apache/spark/sql/SQLContext.scala | 61 +++--- .../org/apache/spark/sql/SchemaRDDLike.scala | 4 +- .../spark/sql/api/java/JavaSQLContext.scala | 9 +- .../org/apache/spark/sql/json/JsonRDD.scala | 163 ++++++---------- .../org/apache/spark/sql/util/package.scala | 175 ++++++++++++++++++ .../spark/sql/api/java/JavaSQLSuite.scala | 2 +- .../org/apache/spark/sql/json/JsonSuite.scala | 20 +- 11 files changed, 386 insertions(+), 144 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/util/package.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 5a55be1e51558..8f7ecc8c46ff4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -85,6 +85,26 @@ object ScalaReflection { case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) } + def typeOfObject: PartialFunction[Any, DataType] = { + // The type of the can be determined without ambiguity. + case obj: BooleanType.JvmType => BooleanType + case obj: BinaryType.JvmType => BinaryType + case obj: StringType.JvmType => StringType + case obj: ByteType.JvmType => ByteType + case obj: ShortType.JvmType => ShortType + case obj: IntegerType.JvmType => IntegerType + case obj: LongType.JvmType => LongType + case obj: FloatType.JvmType => FloatType + case obj: DoubleType.JvmType => DoubleType + case obj: DecimalType.JvmType => DecimalType + case obj: TimestampType.JvmType => TimestampType + case null => NullType + // There is no obvious mapping from the type of the given object to a Catalyst data type. + // A user should provide his/her specific rules (in a user-defined PartialFunction) to infer + // the Catalyst data type for other types of objects and then compose the user-defined + // PartialFunction with this one. + } + implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c6589d68100ae..0320682d47ce5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -128,8 +128,8 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. 
*/ - def formattedSchemaString: String = schema.formattedSchemaString + def schemaString: String = schema.schemaString /** Prints out the schema in the tree format */ - def printSchema(): Unit = println(formattedSchemaString) + def printSchema(): Unit = println(schemaString) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 5924eb3b063a2..3627577aac064 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -121,10 +121,12 @@ case object StringType extends NativeType with PrimitiveType { private[sql] val ordering = implicitly[Ordering[JvmType]] def simpleString: String = "string" } + case object BinaryType extends DataType with PrimitiveType { private[sql] type JvmType = Array[Byte] def simpleString: String = "binary" } + case object BooleanType extends NativeType with PrimitiveType { private[sql] type JvmType = Boolean @transient private[sql] lazy val tag = typeTag[JvmType] @@ -292,7 +294,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) - def formattedSchemaString: String = { + def schemaString: String = { val builder = new StringBuilder builder.append("root\n") val prefix = " |" @@ -301,7 +303,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { builder.toString() } - def printSchema(): Unit = println(formattedSchemaString) + def printSchema(): Unit = println(schemaString) private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { fields.foreach(field => field.buildFormattedString(prefix, builder)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index c0438dbe52a47..e030d6e13d472 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst +import java.math.BigInteger import java.sql.Timestamp import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ case class PrimitiveData( @@ -148,4 +148,68 @@ class ScalaReflectionSuite extends FunSuite { StructField("_2", StringType, nullable = true))), nullable = true)) } + + test("get data type of a value") { + // BooleanType + assert(BooleanType === typeOfObject(true)) + assert(BooleanType === typeOfObject(false)) + + // BinaryType + assert(BinaryType === typeOfObject("string".getBytes)) + + // StringType + assert(StringType === typeOfObject("string")) + + // ByteType + assert(ByteType === typeOfObject(127.toByte)) + + // ShortType + assert(ShortType === typeOfObject(32767.toShort)) + + // IntegerType + assert(IntegerType === typeOfObject(2147483647)) + + // LongType + assert(LongType === typeOfObject(9223372036854775807L)) + + // FloatType + assert(FloatType === typeOfObject(3.4028235E38.toFloat)) + + // DoubleType + assert(DoubleType === typeOfObject(1.7976931348623157E308)) + + // DecimalType + assert(DecimalType === typeOfObject(BigDecimal("1.7976931348623157E318"))) + + // TimestampType + assert(TimestampType === 
typeOfObject(java.sql.Timestamp.valueOf("2014-7-25 10:26:00"))) + + // NullType + assert(NullType === typeOfObject(null)) + + def typeOfObject1: PartialFunction[Any, DataType] = typeOfObject orElse { + case value: java.math.BigInteger => DecimalType + case value: java.math.BigDecimal => DecimalType + case _ => StringType + } + + assert(DecimalType === typeOfObject1( + new BigInteger("92233720368547758070"))) + assert(DecimalType === typeOfObject1( + new java.math.BigDecimal("1.7976931348623157E318"))) + assert(StringType === typeOfObject1(BigInt("92233720368547758070"))) + + def typeOfObject2: PartialFunction[Any, DataType] = typeOfObject orElse { + case value: java.math.BigInteger => DecimalType + } + + intercept[MatchError](typeOfObject2(BigInt("92233720368547758070"))) + + def typeOfObject3: PartialFunction[Any, DataType] = typeOfObject orElse { + case c: Seq[_] => ArrayType(typeOfObject3(c.head)) + } + + assert(ArrayType(IntegerType) === typeOfObject3(Seq(1, 2, 3))) + assert(ArrayType(ArrayType(IntegerType)) === typeOfObject3(Seq(Seq(1,2,3)))) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 14904d416eae3..99aaffe1f5ce4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration @@ -26,12 +25,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies @@ -89,14 +88,31 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) /** - * Creates a SchemaRDD from an RDD by applying a schema and providing a function to construct - * a Row from a RDD record. + * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function + * that will be applied to each partition of the RDD to convert RDD records to [[Row]]s. * * @group userf */ - def createSchemaRDD[A](rdd: RDD[A], schema: StructType, constructRow: A => Row) = { - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rdd.map(constructRow)))) - } + def applySchema[A](rdd: RDD[A],schema: StructType, f: A => Row): SchemaRDD = + applySchemaPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f)) + + /** + * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function + * that will be applied to each partition of the RDD to convert RDD records to [[Row]]s. 
+ * + * @group userf + */ + def applySchemaPartitions[A]( + rdd: RDD[A], + schema: StructType, + f: Iterator[A] => Iterator[Row]): SchemaRDD = + new SchemaRDD(this, makeCustomRDDScan(rdd, schema, f)) + + protected[sql] def makeCustomRDDScan[A]( + rdd: RDD[A], + schema: StructType, + f: Iterator[A] => Iterator[Row]): LogicalPlan = + SparkLogicalPlan(ExistingRdd(schema.toAttributes, rdd.mapPartitions(f))) /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. @@ -136,8 +152,10 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = - new SchemaRDD(this, JsonRDD.inferSchema(json, samplingRatio)) + def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) + applySchemaPartitions(json, schema, JsonRDD.jsonStringToRow(schema, _: Iterator[String])) + } /** * :: Experimental :: @@ -352,28 +370,29 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * Peek at the first row of the RDD and infer its schema. - * TODO: consolidate this with the type system developed in SPARK-2060. */ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { import scala.collection.JavaConversions._ - def typeFor(obj: Any): DataType = obj match { - case c: java.lang.String => StringType - case c: java.lang.Integer => IntegerType - case c: java.lang.Long => LongType - case c: java.lang.Double => DoubleType - case c: java.lang.Boolean => BooleanType - case c: java.util.List[_] => ArrayType(typeFor(c.head)) - case c: java.util.Set[_] => ArrayType(typeFor(c.head)) + def typeOfComplexValue: PartialFunction[Any, DataType] = { + case c: java.util.List[_] => + ArrayType(ScalaReflection.typeOfObject(c.head)) + case c: java.util.Set[_] => + ArrayType(ScalaReflection.typeOfObject(c.head)) case c: java.util.Map[_, _] => val (key, value) = c.head - MapType(typeFor(key), typeFor(value)) + MapType( + ScalaReflection.typeOfObject(key), + ScalaReflection.typeOfObject(value)) case c if c.getClass.isArray => val elem = c.asInstanceOf[Array[_]].head - ArrayType(typeFor(elem)) + ArrayType(ScalaReflection.typeOfObject(elem)) case c => throw new Exception(s"Object of type $c cannot be used") } + + def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue + val schema = rdd.first().map { case (fieldName, obj) => - AttributeReference(fieldName, typeFor(obj), true)() + AttributeReference(fieldName, typeOfObject(obj), true)() }.toSeq val rowRdd = rdd.mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index 9737e62c0f839..d60b4eca52ff0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -127,8 +127,8 @@ private[sql] trait SchemaRDDLike { def schema: StructType = queryExecution.analyzed.schema /** Returns the output schema in the tree format. */ - def formattedSchemaString: String = schema.formattedSchemaString + def schemaString: String = schema.schemaString /** Prints out the schema in the tree format. 
*/ - def printSchema(): Unit = println(formattedSchemaString) + def printSchema(): Unit = println(schemaString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 790d9ef22cf16..0f925cca07e25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -119,8 +119,13 @@ class JavaSQLContext(val sqlContext: SQLContext) { * * @group userf */ - def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = - new JavaSchemaRDD(sqlContext, JsonRDD.inferSchema(json, 1.0)) + def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { + val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0)) + val logicalPlan = + sqlContext.makeCustomRDDScan[String](json, schema, JsonRDD.jsonStringToRow(schema, _)) + + new JavaSchemaRDD(sqlContext, logicalPlan) + } /** * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 6ffa514becd0c..f8aba3d543932 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -25,30 +25,25 @@ import com.fasterxml.jackson.databind.ObjectMapper import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.Logging +import org.apache.spark.sql.util private[sql] object JsonRDD extends Logging { + def jsonStringToRow(schema: StructType, jsonIter: Iterator[String]): Iterator[Row] = { + parseJson(jsonIter).map(parsed => asRow(parsed, schema)) + } + private[sql] def inferSchema( json: RDD[String], - samplingRatio: Double = 1.0): LogicalPlan = { + samplingRatio: Double = 1.0): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) - val baseSchema = createSchema(allKeys) - - createLogicalPlan(json, baseSchema) - } - - private def createLogicalPlan( - json: RDD[String], - baseSchema: StructType): LogicalPlan = { - val schema = nullTypeToStringType(baseSchema) - - SparkLogicalPlan(ExistingRdd(asAttributes(schema), parseJson(json).map(asRow(_, schema)))) + val allKeys = + schemaData.mapPartitions(iter => parseJson(iter)).map(allKeysWithValueTypes).reduce(_ ++ _) + createSchema(allKeys) } private def createSchema(allKeys: Set[(String, DataType)]): StructType = { @@ -106,6 +101,22 @@ private[sql] object JsonRDD extends Logging { makeStruct(resolved.keySet.toSeq, Nil) } + private[sql] def nullTypeToStringType(struct: StructType): StructType = { + val fields = struct.fields.map { + case StructField(fieldName, dataType, nullable) => { + val newType = dataType match { + case NullType => StringType + case ArrayType(NullType) => ArrayType(StringType) + case struct: StructType => nullTypeToStringType(struct) + case other: DataType => other + } + 
StructField(fieldName, newType, nullable) + } + } + + StructType(fields) + } + /** * Returns the most general data type for two given data types. */ @@ -145,18 +156,13 @@ private[sql] object JsonRDD extends Logging { } } - private def typeOfPrimitiveValue(value: Any): DataType = { - value match { - case value: java.lang.String => StringType - case value: java.lang.Integer => IntegerType - case value: java.lang.Long => LongType + private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = { + ScalaReflection.typeOfObject orElse { // Since we do not have a data type backed by BigInteger, // when we see a Java BigInteger, we use DecimalType. case value: java.math.BigInteger => DecimalType - case value: java.lang.Double => DoubleType + // DecimalType's JVMType is scala BigDecimal. case value: java.math.BigDecimal => DecimalType - case value: java.lang.Boolean => BooleanType - case null => NullType // Unexpected data type. case _ => StringType } @@ -245,7 +251,7 @@ private[sql] object JsonRDD extends Logging { case atom => atom } - private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = { + private def parseJson(jsonIter: Iterator[String]): Iterator[Map[String, Any]] = { // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], // ObjectMapper will not return BigDecimal when // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled @@ -254,38 +260,23 @@ private[sql] object JsonRDD extends Logging { // for every float number, which will be slow. // So, right now, we will have Infinity for those BigDecimal number. // TODO: Support BigDecimal. - json.mapPartitions(iter => { - // When there is a key appearing multiple times (a duplicate key), - // the ObjectMapper will take the last value associated with this duplicate key. - // For example: for {"key": 1, "key":2}, we will get "key"->2. - val mapper = new ObjectMapper() - iter.map(record => mapper.readValue(record, classOf[java.util.Map[String, Any]])) - }).map(scalafy).map(_.asInstanceOf[Map[String, Any]]) - } - - private def toLong(value: Any): Long = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toLong - case value: java.lang.Long => value.asInstanceOf[Long] + // Also, when there is a key appearing multiple times (a duplicate key), + // the ObjectMapper will take the last value associated with this duplicate key. + // For example: for {"key": 1, "key":2}, we will get "key"->2. 
+ val mapper = new ObjectMapper() + jsonIter.map { + record => + val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) + parsed.asInstanceOf[Map[String, Any]] } } - private def toDouble(value: Any): Double = { - value match { - case value: java.lang.Integer => value.asInstanceOf[Int].toDouble - case value: java.lang.Long => value.asInstanceOf[Long].toDouble - case value: java.lang.Double => value.asInstanceOf[Double] + private def toDecimalValue: PartialFunction[Any, BigDecimal] = { + def bigIntegerToDecimalValue: PartialFunction[Any, BigDecimal] = { + case v: java.math.BigInteger => BigDecimal(v) } - } - private def toDecimal(value: Any): BigDecimal = { - value match { - case value: java.lang.Integer => BigDecimal(value) - case value: java.lang.Long => BigDecimal(value) - case value: java.math.BigInteger => BigDecimal(value) - case value: java.lang.Double => BigDecimal(value) - case value: java.math.BigDecimal => BigDecimal(value) - } + bigIntegerToDecimalValue orElse util.toDecimalValue } private def toJsonArrayString(seq: Seq[Any]): String = { @@ -296,7 +287,7 @@ private[sql] object JsonRDD extends Logging { element => if (count > 0) builder.append(",") count += 1 - builder.append(toString(element)) + builder.append(toStringValue(element)) } builder.append("]") @@ -311,41 +302,35 @@ private[sql] object JsonRDD extends Logging { case (key, value) => if (count > 0) builder.append(",") count += 1 - builder.append(s"""\"${key}\":${toString(value)}""") + builder.append(s"""\"${key}\":${toStringValue(value)}""") } builder.append("}") builder.toString() } - private def toString(value: Any): String = { - value match { - case value: Map[String, Any] => toJsonObjectString(value) - case value: Seq[Any] => toJsonArrayString(value) - case value => Option(value).map(_.toString).orNull + private def toStringValue: PartialFunction[Any, String] = { + def complexValueToStringValue: PartialFunction[Any, String] = { + case v: Map[String, Any] => toJsonObjectString(v) + case v: Seq[Any] => toJsonArrayString(v) } + + complexValueToStringValue orElse util.toStringValue } - private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ - if (value == null) { - null - } else { - desiredType match { - case ArrayType(elementType) => - value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) - case StringType => toString(value) - case IntegerType => value.asInstanceOf[IntegerType.JvmType] - case LongType => toLong(value) - case DoubleType => toDouble(value) - case DecimalType => toDecimal(value) - case BooleanType => value.asInstanceOf[BooleanType.JvmType] - case NullType => null - } + private[json] def castToType: PartialFunction[(Any, DataType), Any] = { + def jsonSpecificCast: PartialFunction[(Any, DataType), Any] = { + case (v, StringType) => toStringValue(v) + case (v, DecimalType) => toDecimalValue(v) + case (v, ArrayType(elementType)) => + v.asInstanceOf[Seq[Any]].map(castToType(_, elementType)) } + + jsonSpecificCast orElse util.castToType } - // TODO: Reuse the row instead of creating a new one for every record. private def asRow(json: Map[String,Any], schema: StructType): Row = { + // TODO: Reuse the row instead of creating a new one for every record. 
val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { // StructType @@ -363,37 +348,9 @@ private[sql] object JsonRDD extends Logging { // Other cases case (StructField(name, dataType, _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( - enforceCorrectType(_, dataType)).getOrElse(null)) + castToType(_, dataType)).getOrElse(null)) } row } - - private def nullTypeToStringType(struct: StructType): StructType = { - val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable) => { - val newType = dataType match { - case NullType => StringType - case ArrayType(NullType) => ArrayType(StringType) - case struct: StructType => nullTypeToStringType(struct) - case other: DataType => other - } - StructField(fieldName, newType, nullable) - } - } - - StructType(fields) - } - - private def asAttributes(struct: StructType): Seq[AttributeReference] = { - struct.fields.map(f => AttributeReference(f.name, f.dataType, nullable = true)()) - } - - private def asStruct(attributes: Seq[AttributeReference]): StructType = { - val fields = attributes.map { - case AttributeReference(name, dataType, nullable) => StructField(name, dataType, nullable) - } - - StructType(fields) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala new file mode 100644 index 0000000000000..2be4d5cf53af2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.math.BigDecimal + +import org.apache.spark.annotation.DeveloperApi + +package object util { + + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toBooleanValue: PartialFunction[Any, BooleanType.JvmType] = { + case v: BooleanType.JvmType => v + case v: ByteType.JvmType if v == 1 => true + case v: ByteType.JvmType if v == 0 => false + case v: ShortType.JvmType if v == 1 => true + case v: ShortType.JvmType if v == 0 => false + case v: IntegerType.JvmType if v == 1 => true + case v: IntegerType.JvmType if v == 0 => false + case v: LongType.JvmType if v == 1 => true + case v: LongType.JvmType if v == 0 => false + case v: StringType.JvmType if v.toLowerCase == "true" => true + case v: StringType.JvmType if v.toLowerCase == "false" => false + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toStringValue: PartialFunction[Any, StringType.JvmType] = { + case v => Option(v).map(_.toString).orNull + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toByteValue: PartialFunction[Any, ByteType.JvmType] = { + case v: BooleanType.JvmType => if (v) 1.toByte else 0.toByte + case v: ByteType.JvmType => v + case v: StringType.JvmType => v.toByte + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toShortValue: PartialFunction[Any, ShortType.JvmType] = { + case v: BooleanType.JvmType => if (v) 1.toShort else 0.toShort + case v: ByteType.JvmType => v.toShort + case v: ShortType.JvmType => v + case v: StringType.JvmType => v.toShort + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toIntegerValue: PartialFunction[Any, IntegerType.JvmType] = { + case v: BooleanType.JvmType => if (v) 1 else 0 + case v: ByteType.JvmType => v.toInt + case v: ShortType.JvmType => v.toInt + case v: IntegerType.JvmType => v + case v: StringType.JvmType => v.toInt + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toLongValue: PartialFunction[Any, LongType.JvmType] = { + case v: BooleanType.JvmType => if (v) 1.toLong else 0.toLong + case v: ByteType.JvmType => v.toLong + case v: ShortType.JvmType => v.toLong + case v: IntegerType.JvmType => v.toLong + case v: LongType.JvmType => v + // We can convert a Timestamp object to a Long because a Long representation of + // a Timestamp object has a clear meaning + // (milliseconds since January 1, 1970, 00:00:00 GMT). 
+ case v: TimestampType.JvmType => v.getTime + case v: StringType.JvmType => v.toLong + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toFloatValue: PartialFunction[Any, FloatType.JvmType] = { + case v: BooleanType.JvmType => if (v) 1.toFloat else 0.toFloat + case v: ByteType.JvmType => v.toFloat + case v: ShortType.JvmType => v.toFloat + case v: IntegerType.JvmType => v.toFloat + case v: FloatType.JvmType => v + case v: StringType.JvmType => v.toFloat + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toDoubleValue: PartialFunction[Any, DoubleType.JvmType] = { + case v: BooleanType.JvmType => if (v) 1.toDouble else 0.toDouble + case v: ByteType.JvmType => v.toDouble + case v: ShortType.JvmType => v.toDouble + case v: IntegerType.JvmType => v.toDouble + case v: LongType.JvmType => v.toDouble + case v: FloatType.JvmType => v.toDouble + case v: DoubleType.JvmType => v + case v: StringType.JvmType => v.toDouble + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toDecimalValue: PartialFunction[Any, DecimalType.JvmType] = { + case v: BooleanType.JvmType => if (v) BigDecimal(1) else BigDecimal(0) + case v: ByteType.JvmType => BigDecimal(v) + case v: ShortType.JvmType => BigDecimal(v) + case v: IntegerType.JvmType => BigDecimal(v) + case v: LongType.JvmType => BigDecimal(v) + case v: FloatType.JvmType => BigDecimal(v) + case v: DoubleType.JvmType => BigDecimal(v) + case v: TimestampType.JvmType => BigDecimal(v.getTime) + case v: StringType.JvmType => BigDecimal(v) + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def toTimestampValue: PartialFunction[Any, TimestampType.JvmType] = { + case v: LongType.JvmType => new java.sql.Timestamp(v) + case v: TimestampType.JvmType => v + case v: StringType.JvmType => java.sql.Timestamp.valueOf(v) + } + + /** + * :: DeveloperApi :: + */ + @DeveloperApi + def castToType: PartialFunction[(Any, DataType), Any] = { + case (null, _) => null + case (_, NullType) => null + case (v, BooleanType) => toBooleanValue(v) + case (v, StringType) => toStringValue(v) + case (v, ByteType) => toByteValue(v) + case (v, ShortType) => toShortValue(v) + case (v, IntegerType) => toIntegerValue(v) + case (v, LongType) => toLongValue(v) + case (v, FloatType) => toFloatValue(v) + case (v, DoubleType) => toDoubleValue(v) + case (v, DecimalType) => toDecimalValue(v) + case (v, TimestampType) => toTimestampValue(v) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index 020baf0c7ec6f..e05c684d97eee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -47,7 +47,7 @@ class AllTypesBean extends Serializable { @BeanProperty var booleanField: java.lang.Boolean = _ } -class JavaSQLSuite extends FunSuite { + class JavaSQLSuite extends FunSuite { val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) val javaSqlCtx = new JavaSQLContext(javaCtx) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index e765cfc83a397..27391e6708076 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import 
org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} +import org.apache.spark.sql.json.JsonRDD.{castToType, compatibleType} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.TestSQLContext._ @@ -41,19 +41,19 @@ class JsonSuite extends QueryTest { } val intNumber: Int = 2147483647 - checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) - checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) - checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) - checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType)) + checkTypePromotion(intNumber, castToType(intNumber, IntegerType)) + checkTypePromotion(intNumber.toLong, castToType(intNumber, LongType)) + checkTypePromotion(intNumber.toDouble, castToType(intNumber, DoubleType)) + checkTypePromotion(BigDecimal(intNumber), castToType(intNumber, DecimalType)) val longNumber: Long = 9223372036854775807L - checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) - checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) - checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType)) + checkTypePromotion(longNumber, castToType(longNumber, LongType)) + checkTypePromotion(longNumber.toDouble, castToType(longNumber, DoubleType)) + checkTypePromotion(BigDecimal(longNumber), castToType(longNumber, DecimalType)) val doubleNumber: Double = 1.7976931348623157E308d - checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) - checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) + checkTypePromotion(doubleNumber.toDouble, castToType(doubleNumber, DoubleType)) + checkTypePromotion(BigDecimal(doubleNumber), castToType(doubleNumber, DecimalType)) } test("Get compatible type") { From 0266761373a2295c6dddc16e72fe97dcab8fa09e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 10 Jul 2014 15:26:27 -0700 Subject: [PATCH 04/34] Format --- .../test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala index e05c684d97eee..020baf0c7ec6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/JavaSQLSuite.scala @@ -47,7 +47,7 @@ class AllTypesBean extends Serializable { @BeanProperty var booleanField: java.lang.Boolean = _ } - class JavaSQLSuite extends FunSuite { +class JavaSQLSuite extends FunSuite { val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) val javaSqlCtx = new JavaSQLContext(javaCtx) From 43a45e170577198b2c424e45f7c90dfa928031a7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 10 Jul 2014 20:10:58 -0700 Subject: [PATCH 05/34] Remove sql.util.package introduced in a previous commit. 
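This restores `JsonRDD.enforceCorrectType` as the JSON-specific coercion path.
For reference, its behaviour is the one the existing JsonSuite assertions pin
down; roughly (the method is `private[json]`, so these calls are only reachable
from that package or its tests):

  import org.apache.spark.sql.catalyst.types._
  import org.apache.spark.sql.json.JsonRDD.enforceCorrectType

  // Promote a parsed JSON int to the type recorded in the inferred schema.
  enforceCorrectType(2147483647, IntegerType)  // 2147483647
  enforceCorrectType(2147483647, LongType)     // 2147483647L
  enforceCorrectType(2147483647, DecimalType)  // BigDecimal(2147483647)
  enforceCorrectType(null, StringType)         // null
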
--- .../org/apache/spark/sql/SQLContext.scala | 6 +- .../org/apache/spark/sql/json/JsonRDD.scala | 66 ++++--- .../org/apache/spark/sql/util/package.scala | 175 ------------------ .../org/apache/spark/sql/json/JsonSuite.scala | 20 +- 4 files changed, 57 insertions(+), 210 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/util/package.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 99aaffe1f5ce4..024dc337cd047 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -94,7 +94,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def applySchema[A](rdd: RDD[A],schema: StructType, f: A => Row): SchemaRDD = - applySchemaPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f)) + applySchemaToPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f)) /** * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function @@ -102,7 +102,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def applySchemaPartitions[A]( + def applySchemaToPartitions[A]( rdd: RDD[A], schema: StructType, f: Iterator[A] => Iterator[Row]): SchemaRDD = @@ -154,7 +154,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) - applySchemaPartitions(json, schema, JsonRDD.jsonStringToRow(schema, _: Iterator[String])) + applySchemaToPartitions(json, schema, JsonRDD.jsonStringToRow(schema, _: Iterator[String])) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index f8aba3d543932..bec741c96b678 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.Logging -import org.apache.spark.sql.util private[sql] object JsonRDD extends Logging { @@ -271,12 +270,29 @@ private[sql] object JsonRDD extends Logging { } } - private def toDecimalValue: PartialFunction[Any, BigDecimal] = { - def bigIntegerToDecimalValue: PartialFunction[Any, BigDecimal] = { - case v: java.math.BigInteger => BigDecimal(v) + private def toLong(value: Any): Long = { + value match { + case value: java.lang.Integer => value.asInstanceOf[Int].toLong + case value: java.lang.Long => value.asInstanceOf[Long] } + } - bigIntegerToDecimalValue orElse util.toDecimalValue + private def toDouble(value: Any): Double = { + value match { + case value: java.lang.Integer => value.asInstanceOf[Int].toDouble + case value: java.lang.Long => value.asInstanceOf[Long].toDouble + case value: java.lang.Double => value.asInstanceOf[Double] + } + } + + private def toDecimal(value: Any): BigDecimal = { + value match { + case value: java.lang.Integer => BigDecimal(value) + case value: java.lang.Long => BigDecimal(value) + case value: java.math.BigInteger => BigDecimal(value) + case value: java.lang.Double => BigDecimal(value) + case value: java.math.BigDecimal => BigDecimal(value) + } } private def toJsonArrayString(seq: Seq[Any]): 
String = { @@ -287,7 +303,7 @@ private[sql] object JsonRDD extends Logging { element => if (count > 0) builder.append(",") count += 1 - builder.append(toStringValue(element)) + builder.append(toString(element)) } builder.append("]") @@ -302,31 +318,37 @@ private[sql] object JsonRDD extends Logging { case (key, value) => if (count > 0) builder.append(",") count += 1 - builder.append(s"""\"${key}\":${toStringValue(value)}""") + builder.append(s"""\"${key}\":${toString(value)}""") } builder.append("}") builder.toString() } - private def toStringValue: PartialFunction[Any, String] = { - def complexValueToStringValue: PartialFunction[Any, String] = { - case v: Map[String, Any] => toJsonObjectString(v) - case v: Seq[Any] => toJsonArrayString(v) + private def toString(value: Any): String = { + value match { + case value: Map[String, Any] => toJsonObjectString(value) + case value: Seq[Any] => toJsonArrayString(value) + case value => Option(value).map(_.toString).orNull } - - complexValueToStringValue orElse util.toStringValue } - private[json] def castToType: PartialFunction[(Any, DataType), Any] = { - def jsonSpecificCast: PartialFunction[(Any, DataType), Any] = { - case (v, StringType) => toStringValue(v) - case (v, DecimalType) => toDecimalValue(v) - case (v, ArrayType(elementType)) => - v.asInstanceOf[Seq[Any]].map(castToType(_, elementType)) + private[json] def enforceCorrectType(value: Any, desiredType: DataType): Any ={ + if (value == null) { + null + } else { + desiredType match { + case ArrayType(elementType) => + value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) + case StringType => toString(value) + case IntegerType => value.asInstanceOf[IntegerType.JvmType] + case LongType => toLong(value) + case DoubleType => toDouble(value) + case DecimalType => toDecimal(value) + case BooleanType => value.asInstanceOf[BooleanType.JvmType] + case NullType => null + } } - - jsonSpecificCast orElse util.castToType } private def asRow(json: Map[String,Any], schema: StructType): Row = { @@ -348,7 +370,7 @@ private[sql] object JsonRDD extends Logging { // Other cases case (StructField(name, dataType, _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( - castToType(_, dataType)).getOrElse(null)) + enforceCorrectType(_, dataType)).getOrElse(null)) } row diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala deleted file mode 100644 index 2be4d5cf53af2..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/package.scala +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql - -import scala.math.BigDecimal - -import org.apache.spark.annotation.DeveloperApi - -package object util { - - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toBooleanValue: PartialFunction[Any, BooleanType.JvmType] = { - case v: BooleanType.JvmType => v - case v: ByteType.JvmType if v == 1 => true - case v: ByteType.JvmType if v == 0 => false - case v: ShortType.JvmType if v == 1 => true - case v: ShortType.JvmType if v == 0 => false - case v: IntegerType.JvmType if v == 1 => true - case v: IntegerType.JvmType if v == 0 => false - case v: LongType.JvmType if v == 1 => true - case v: LongType.JvmType if v == 0 => false - case v: StringType.JvmType if v.toLowerCase == "true" => true - case v: StringType.JvmType if v.toLowerCase == "false" => false - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toStringValue: PartialFunction[Any, StringType.JvmType] = { - case v => Option(v).map(_.toString).orNull - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toByteValue: PartialFunction[Any, ByteType.JvmType] = { - case v: BooleanType.JvmType => if (v) 1.toByte else 0.toByte - case v: ByteType.JvmType => v - case v: StringType.JvmType => v.toByte - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toShortValue: PartialFunction[Any, ShortType.JvmType] = { - case v: BooleanType.JvmType => if (v) 1.toShort else 0.toShort - case v: ByteType.JvmType => v.toShort - case v: ShortType.JvmType => v - case v: StringType.JvmType => v.toShort - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toIntegerValue: PartialFunction[Any, IntegerType.JvmType] = { - case v: BooleanType.JvmType => if (v) 1 else 0 - case v: ByteType.JvmType => v.toInt - case v: ShortType.JvmType => v.toInt - case v: IntegerType.JvmType => v - case v: StringType.JvmType => v.toInt - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toLongValue: PartialFunction[Any, LongType.JvmType] = { - case v: BooleanType.JvmType => if (v) 1.toLong else 0.toLong - case v: ByteType.JvmType => v.toLong - case v: ShortType.JvmType => v.toLong - case v: IntegerType.JvmType => v.toLong - case v: LongType.JvmType => v - // We can convert a Timestamp object to a Long because a Long representation of - // a Timestamp object has a clear meaning - // (milliseconds since January 1, 1970, 00:00:00 GMT). 
- case v: TimestampType.JvmType => v.getTime - case v: StringType.JvmType => v.toLong - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toFloatValue: PartialFunction[Any, FloatType.JvmType] = { - case v: BooleanType.JvmType => if (v) 1.toFloat else 0.toFloat - case v: ByteType.JvmType => v.toFloat - case v: ShortType.JvmType => v.toFloat - case v: IntegerType.JvmType => v.toFloat - case v: FloatType.JvmType => v - case v: StringType.JvmType => v.toFloat - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toDoubleValue: PartialFunction[Any, DoubleType.JvmType] = { - case v: BooleanType.JvmType => if (v) 1.toDouble else 0.toDouble - case v: ByteType.JvmType => v.toDouble - case v: ShortType.JvmType => v.toDouble - case v: IntegerType.JvmType => v.toDouble - case v: LongType.JvmType => v.toDouble - case v: FloatType.JvmType => v.toDouble - case v: DoubleType.JvmType => v - case v: StringType.JvmType => v.toDouble - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toDecimalValue: PartialFunction[Any, DecimalType.JvmType] = { - case v: BooleanType.JvmType => if (v) BigDecimal(1) else BigDecimal(0) - case v: ByteType.JvmType => BigDecimal(v) - case v: ShortType.JvmType => BigDecimal(v) - case v: IntegerType.JvmType => BigDecimal(v) - case v: LongType.JvmType => BigDecimal(v) - case v: FloatType.JvmType => BigDecimal(v) - case v: DoubleType.JvmType => BigDecimal(v) - case v: TimestampType.JvmType => BigDecimal(v.getTime) - case v: StringType.JvmType => BigDecimal(v) - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def toTimestampValue: PartialFunction[Any, TimestampType.JvmType] = { - case v: LongType.JvmType => new java.sql.Timestamp(v) - case v: TimestampType.JvmType => v - case v: StringType.JvmType => java.sql.Timestamp.valueOf(v) - } - - /** - * :: DeveloperApi :: - */ - @DeveloperApi - def castToType: PartialFunction[(Any, DataType), Any] = { - case (null, _) => null - case (_, NullType) => null - case (v, BooleanType) => toBooleanValue(v) - case (v, StringType) => toStringValue(v) - case (v, ByteType) => toByteValue(v) - case (v, ShortType) => toShortValue(v) - case (v, IntegerType) => toIntegerValue(v) - case (v, LongType) => toLongValue(v) - case (v, FloatType) => toFloatValue(v) - case (v, DoubleType) => toDoubleValue(v) - case (v, DecimalType) => toDecimalValue(v) - case (v, TimestampType) => toTimestampValue(v) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 27391e6708076..e765cfc83a397 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.json.JsonRDD.{castToType, compatibleType} +import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.TestSQLContext._ @@ -41,19 +41,19 @@ class JsonSuite extends QueryTest { } val intNumber: Int = 2147483647 - checkTypePromotion(intNumber, castToType(intNumber, IntegerType)) - checkTypePromotion(intNumber.toLong, castToType(intNumber, LongType)) - checkTypePromotion(intNumber.toDouble, castToType(intNumber, DoubleType)) - 
checkTypePromotion(BigDecimal(intNumber), castToType(intNumber, DecimalType)) + checkTypePromotion(intNumber, enforceCorrectType(intNumber, IntegerType)) + checkTypePromotion(intNumber.toLong, enforceCorrectType(intNumber, LongType)) + checkTypePromotion(intNumber.toDouble, enforceCorrectType(intNumber, DoubleType)) + checkTypePromotion(BigDecimal(intNumber), enforceCorrectType(intNumber, DecimalType)) val longNumber: Long = 9223372036854775807L - checkTypePromotion(longNumber, castToType(longNumber, LongType)) - checkTypePromotion(longNumber.toDouble, castToType(longNumber, DoubleType)) - checkTypePromotion(BigDecimal(longNumber), castToType(longNumber, DecimalType)) + checkTypePromotion(longNumber, enforceCorrectType(longNumber, LongType)) + checkTypePromotion(longNumber.toDouble, enforceCorrectType(longNumber, DoubleType)) + checkTypePromotion(BigDecimal(longNumber), enforceCorrectType(longNumber, DecimalType)) val doubleNumber: Double = 1.7976931348623157E308d - checkTypePromotion(doubleNumber.toDouble, castToType(doubleNumber, DoubleType)) - checkTypePromotion(BigDecimal(doubleNumber), castToType(doubleNumber, DecimalType)) + checkTypePromotion(doubleNumber.toDouble, enforceCorrectType(doubleNumber, DoubleType)) + checkTypePromotion(BigDecimal(doubleNumber), enforceCorrectType(doubleNumber, DecimalType)) } test("Get compatible type") { From 7a6a7e5088322e70d06c1e991a6c5915b32279a0 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 10 Jul 2014 22:09:01 -0700 Subject: [PATCH 06/34] Fix bug introduced by the change made on SQLContext.inferSchema. --- .../main/scala/org/apache/spark/sql/SQLContext.scala | 10 ++++------ .../main/scala/org/apache/spark/sql/json/JsonRDD.scala | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 024dc337cd047..21ada3b859980 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -375,17 +375,15 @@ class SQLContext(@transient val sparkContext: SparkContext) import scala.collection.JavaConversions._ def typeOfComplexValue: PartialFunction[Any, DataType] = { case c: java.util.List[_] => - ArrayType(ScalaReflection.typeOfObject(c.head)) + ArrayType(typeOfObject(c.head)) case c: java.util.Set[_] => - ArrayType(ScalaReflection.typeOfObject(c.head)) + ArrayType(typeOfObject(c.head)) case c: java.util.Map[_, _] => val (key, value) = c.head - MapType( - ScalaReflection.typeOfObject(key), - ScalaReflection.typeOfObject(value)) + MapType(typeOfObject(key), typeOfObject(value)) case c if c.getClass.isArray => val elem = c.asInstanceOf[Array[_]].head - ArrayType(ScalaReflection.typeOfObject(elem)) + ArrayType(typeOfObject(elem)) case c => throw new Exception(s"Object of type $c cannot be used") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index bec741c96b678..da384ab6c673a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.Logging private[sql] object JsonRDD extends Logging { - def jsonStringToRow(schema: StructType, jsonIter: Iterator[String]): Iterator[Row] = { + private[sql] def jsonStringToRow(schema: StructType, jsonIter: Iterator[String]): Iterator[Row] = { parseJson(jsonIter).map(parsed => 
asRow(parsed, schema)) } From 949d6bbae66ecbb9c11a0e9d0d14f09f51f5c46a Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 11 Jul 2014 10:22:29 -0700 Subject: [PATCH 07/34] When creating a SchemaRDD for a JSON dataset, users can apply an existing schema. --- .../org/apache/spark/sql/SQLContext.scala | 35 +++++++++++++++---- .../org/apache/spark/sql/json/JsonRDD.scala | 4 ++- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 21ada3b859980..628f7cb84c61f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -128,15 +128,23 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0, None) + + /** + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a [[SchemaRDD]]. + * + * @group userf + */ + def jsonFile(path: String, schema: StructType): SchemaRDD = jsonFile(path, 1.0, Option(schema)) /** * :: Experimental :: */ @Experimental - def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { + def jsonFile(path: String, samplingRatio: Double, schema: Option[StructType]): SchemaRDD = { val json = sparkContext.textFile(path) - jsonRDD(json, samplingRatio) + jsonRDD(json, samplingRatio, schema) } /** @@ -146,15 +154,28 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0, None) + + /** + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a [[SchemaRDD]]. 
+ * + * @group userf + */ + def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = jsonRDD(json, 1.0, Option(schema)) /** * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { - val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) - applySchemaToPartitions(json, schema, JsonRDD.jsonStringToRow(schema, _: Iterator[String])) + def jsonRDD(json: RDD[String], samplingRatio: Double, schema: Option[StructType]): SchemaRDD = { + val appliedSchema = + schema.getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))) + + applySchemaToPartitions( + json, + appliedSchema, + JsonRDD.jsonStringToRow(appliedSchema, _: Iterator[String])) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index da384ab6c673a..4fd55ff13dcce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -31,7 +31,9 @@ import org.apache.spark.sql.Logging private[sql] object JsonRDD extends Logging { - private[sql] def jsonStringToRow(schema: StructType, jsonIter: Iterator[String]): Iterator[Row] = { + private[sql] def jsonStringToRow( + schema: StructType, + jsonIter: Iterator[String]): Iterator[Row] = { parseJson(jsonIter).map(parsed => asRow(parsed, schema)) } From eca7d04db5e3240b046bfaab09d0b1e9b9546f51 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 11 Jul 2014 12:13:57 -0700 Subject: [PATCH 08/34] Add two apply methods which will be used to extract StructField(s) from a StructType. --- .../org/apache/spark/sql/catalyst/types/dataTypes.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 3627577aac064..8e73e5afd0ed1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -292,6 +292,15 @@ object StructType { case class StructType(fields: Seq[StructField]) extends DataType { require(StructType.validateFields(fields), "Found fields with the same name.") + def apply(name: String): StructField = { + fields.find(f => f.name == name).orNull + } + + def apply(names: String*): StructType = { + val nameSet = names.toSet + StructType(fields.filter(f => nameSet.contains(f.name))) + } + def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) def schemaString: String = { From 9168b833b542b1f03f352862a6dc4f1f5586a327 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 11 Jul 2014 12:23:11 -0700 Subject: [PATCH 09/34] Update comments. --- .../apache/spark/sql/catalyst/ScalaReflection.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8f7ecc8c46ff4..0d26b52a84695 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -86,7 +86,7 @@ object ScalaReflection { } def typeOfObject: PartialFunction[Any, DataType] = { - // The type of the can be determined without ambiguity. 
+ // The data type can be determined without ambiguity. case obj: BooleanType.JvmType => BooleanType case obj: BinaryType.JvmType => BinaryType case obj: StringType.JvmType => StringType @@ -99,10 +99,10 @@ object ScalaReflection { case obj: DecimalType.JvmType => DecimalType case obj: TimestampType.JvmType => TimestampType case null => NullType - // There is no obvious mapping from the type of the given object to a Catalyst data type. - // A user should provide his/her specific rules (in a user-defined PartialFunction) to infer - // the Catalyst data type for other types of objects and then compose the user-defined - // PartialFunction with this one. + // For other cases, there is no obvious mapping from the type of the given object to a + // Catalyst data type. A user should provide his/her specific rules + // (in a user-defined PartialFunction) to infer the Catalyst data type for other types of + // objects and then compose the user-defined PartialFunction with this one. } implicit class CaseClassRelation[A <: Product : TypeTag](data: Seq[A]) { From dcaf22fe0e96fc8946686563a1409d535b572653 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Fri, 11 Jul 2014 18:20:37 -0700 Subject: [PATCH 10/34] Add a field containsNull to ArrayType to indicate if an array can contain null values or not. If an ArrayType is constructed by "ArrayType(elementType)" (the existing constructor), the value of containsNull is false. --- .../catalyst/expressions/complexTypes.scala | 2 +- .../sql/catalyst/expressions/generators.scala | 4 +-- .../plans/logical/basicOperators.scala | 2 +- .../spark/sql/catalyst/types/dataTypes.scala | 13 +++++++-- .../org/apache/spark/sql/SQLContext.scala | 2 +- .../org/apache/spark/sql/SchemaRDD.scala | 2 +- .../org/apache/spark/sql/json/JsonRDD.scala | 29 ++++++++++--------- .../spark/sql/parquet/ParquetConverter.scala | 4 +-- .../sql/parquet/ParquetTableSupport.scala | 2 +- .../spark/sql/parquet/ParquetTypes.scala | 14 ++++----- .../org/apache/spark/sql/json/JsonSuite.scala | 16 ++++++++-- .../apache/spark/sql/hive/HiveContext.scala | 4 +-- .../spark/sql/hive/HiveMetastoreCatalog.scala | 6 ++-- .../org/apache/spark/sql/hive/hiveUdfs.scala | 3 +- 14 files changed, 64 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index 5d3bb25ad568c..f13a6d5f98382 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -31,7 +31,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { override def foldable = child.foldable && ordinal.foldable override def references = children.flatMap(_.references).toSet def dataType = child.dataType match { - case ArrayType(dt) => dt + case ArrayType(dt, _) => dt case MapType(_, vt) => vt } override lazy val resolved = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index dd78614754e12..0a8d4dd718329 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -84,7 +84,7 @@ case class Explode(attributeNames: Seq[String], child: Expression) 
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) private lazy val elementTypes = child.dataType match { - case ArrayType(et) => et :: Nil + case ArrayType(et, _) => et :: Nil case MapType(kt,vt) => kt :: vt :: Nil } @@ -102,7 +102,7 @@ case class Explode(attributeNames: Seq[String], child: Expression) override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { - case ArrayType(_) => + case ArrayType(_, _) => val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) case MapType(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 1537de259c5b4..3cb407217c4c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -177,7 +177,7 @@ case class LowerCaseSchema(child: LogicalPlan) extends UnaryNode { case StructType(fields) => StructType(fields.map(f => StructField(f.name.toLowerCase(), lowerCaseSchema(f.dataType), f.nullable))) - case ArrayType(elemType) => ArrayType(lowerCaseSchema(elemType)) + case ArrayType(elemType, containsNull) => ArrayType(lowerCaseSchema(elemType), containsNull) case otherType => otherType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 8e73e5afd0ed1..4f7bc23a7412e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -45,7 +45,9 @@ object DataType extends RegexParsers { "TimestampType" ^^^ TimestampType protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType <~ ")" ^^ ArrayType + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } protected lazy val mapType: Parser[DataType] = "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ { @@ -241,9 +243,14 @@ case object FloatType extends FractionalType { def simpleString: String = "float" } -case class ArrayType(elementType: DataType) extends DataType { +object ArrayType { + def apply(elementType: DataType): ArrayType = ArrayType(elementType, false) +} + +case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { - builder.append(s"${prefix}-- element: ${elementType.simpleString}\n") + builder.append( + s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") elementType match { case array: ArrayType => array.buildFormattedString(s"$prefix |", builder) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 628f7cb84c61f..355d545cad89e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -93,7 +93,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def applySchema[A](rdd: RDD[A],schema: StructType, f: A => Row): SchemaRDD = + def applySchema[A](rdd: RDD[A], schema: 
StructType, f: A => Row): SchemaRDD = applySchemaToPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f)) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 4e8faddedfe87..723c78c596646 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -382,7 +382,7 @@ class SchemaRDD( case (obj, (name, dataType)) => dataType match { case struct: StructType => map.put(name, rowToMap(obj.asInstanceOf[Row], struct)) - case array @ ArrayType(struct: StructType) => + case array @ ArrayType(struct: StructType, _) => val arrayValues = obj match { case seq: Seq[Any] => seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 4fd55ff13dcce..913d2368b82f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -68,8 +68,8 @@ private[sql] object JsonRDD extends Logging { val (topLevel, structLike) = values.partition(_.size == 1) val topLevelFields = topLevel.filter { name => resolved.get(prefix ++ name).get match { - case ArrayType(StructType(Nil)) => false - case ArrayType(_) => true + case ArrayType(StructType(Nil), _) => false + case ArrayType(_, _) => true case struct: StructType => false case _ => true } @@ -83,7 +83,8 @@ private[sql] object JsonRDD extends Logging { val structType = makeStruct(nestedFields, prefix :+ name) val dataType = resolved.get(prefix :+ name).get dataType match { - case array: ArrayType => Some(StructField(name, ArrayType(structType), nullable = true)) + case array: ArrayType => + Some(StructField(name, ArrayType(structType, array.containsNull), nullable = true)) case struct: StructType => Some(StructField(name, structType, nullable = true)) // dataType is StringType means that we have resolved type conflicts involving // primitive types and complex types. So, the type of name has been relaxed to @@ -107,7 +108,7 @@ private[sql] object JsonRDD extends Logging { case StructField(fieldName, dataType, nullable) => { val newType = dataType match { case NullType => StringType - case ArrayType(NullType) => ArrayType(StringType) + case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) case struct: StructType => nullTypeToStringType(struct) case other: DataType => other } @@ -148,8 +149,8 @@ private[sql] object JsonRDD extends Logging { case StructField(name, _, _) => name }) } - case (ArrayType(elementType1), ArrayType(elementType2)) => - ArrayType(compatibleType(elementType1, elementType2)) + case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => + ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) // TODO: We should use JsonObjectStringType to mark that values of field will be // strings and every string is a Json object. case (_, _) => StringType @@ -176,12 +177,13 @@ private[sql] object JsonRDD extends Logging { * treat the element as String. */ private def typeOfArray(l: Seq[Any]): ArrayType = { + val containsNull = l.exists(v => v == null) val elements = l.flatMap(v => Option(v)) if (elements.isEmpty) { // If this JSON array is empty, we use NullType as a placeholder. 
// If this array is not empty in other JSON objects, we can resolve // the type after we have passed through all JSON objects. - ArrayType(NullType) + ArrayType(NullType, containsNull) } else { val elementType = elements.map { e => e match { @@ -193,7 +195,7 @@ private[sql] object JsonRDD extends Logging { } }.reduce((type1: DataType, type2: DataType) => compatibleType(type1, type2)) - ArrayType(elementType) + ArrayType(elementType, containsNull) } } @@ -220,15 +222,16 @@ private[sql] object JsonRDD extends Logging { case (key: String, array: List[Any]) => { // The value associted with the key is an array. typeOfArray(array) match { - case ArrayType(StructType(Nil)) => { + case ArrayType(StructType(Nil), containsNull) => { // The elements of this arrays are structs. array.asInstanceOf[List[Map[String, Any]]].flatMap { element => allKeysWithValueTypes(element) }.map { case (k, dataType) => (s"$key.$k", dataType) - } :+ (key, ArrayType(StructType(Nil))) + } :+ (key, ArrayType(StructType(Nil), containsNull)) } - case ArrayType(elementType) => (key, ArrayType(elementType)) :: Nil + case ArrayType(elementType, containsNull) => + (key, ArrayType(elementType, containsNull)) :: Nil } } case (key: String, value) => (key, typeOfPrimitiveValue(value)) :: Nil @@ -340,7 +343,7 @@ private[sql] object JsonRDD extends Logging { null } else { desiredType match { - case ArrayType(elementType) => + case ArrayType(elementType, _) => value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType)) case StringType => toString(value) case IntegerType => value.asInstanceOf[IntegerType.JvmType] @@ -363,7 +366,7 @@ private[sql] object JsonRDD extends Logging { v => asRow(v.asInstanceOf[Map[String, Any]], fields)).orNull) // ArrayType(StructType) - case (StructField(name, ArrayType(structType: StructType), _), i) => + case (StructField(name, ArrayType(structType: StructType, _), _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( v => v.asInstanceOf[Seq[Any]].map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 75748b2b54400..4f1874a251f67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -75,11 +75,11 @@ private[sql] object CatalystConverter { val fieldType: DataType = field.dataType fieldType match { // For native JVM types we use a converter with native arrays - case ArrayType(elementType: NativeType) => { + case ArrayType(elementType: NativeType, false) => { new CatalystNativeArrayConverter(elementType, fieldIndex, parent) } // This is for other types of arrays, including those with nested fields - case ArrayType(elementType: DataType) => { + case ArrayType(elementType: DataType, false) => { new CatalystArrayConverter(elementType, fieldIndex, parent) } case StructType(fields: Seq[StructField]) => { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 108f8b6815423..f5e72801963f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -169,7 +169,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { private[parquet] def writeValue(schema: DataType, value: Any): Unit = { if 
(value != null) { schema match { - case t @ ArrayType(_) => writeArray( + case t @ ArrayType(_, false) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) case t @ MapType(_, _) => writeMap( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index f9046368e7ced..4fc0dfb8554a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -113,7 +113,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - new ArrayType(toDataType(field)) + ArrayType(toDataType(field), false) } case ParquetOriginalType.MAP => { assert( @@ -127,7 +127,7 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - new MapType(keyType, valueType) + MapType(keyType, valueType) } case _ => { // Note: the order of these checks is important! @@ -137,10 +137,10 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) - new MapType(keyType, valueType) + MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType val elementType = toDataType(groupType.getFields.apply(0)) - new ArrayType(elementType) + ArrayType(elementType, false) } else { // everything else: StructType val fields = groupType .getFields @@ -148,7 +148,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ptype.getName, toDataType(ptype), ptype.getRepetition != Repetition.REQUIRED)) - new StructType(fields) + StructType(fields) } } } @@ -168,7 +168,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case StringType => Some(ParquetPrimitiveTypeName.BINARY) case BooleanType => Some(ParquetPrimitiveTypeName.BOOLEAN) case DoubleType => Some(ParquetPrimitiveTypeName.DOUBLE) - case ArrayType(ByteType) => + case ArrayType(ByteType, false) => Some(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY) case FloatType => Some(ParquetPrimitiveTypeName.FLOAT) case IntegerType => Some(ParquetPrimitiveTypeName.INT32) @@ -231,7 +231,7 @@ private[parquet] object ParquetTypesConverter extends Logging { new ParquetPrimitiveType(repetition, primitiveType.get, name) } else { ctype match { - case ArrayType(elementType) => { + case ArrayType(elementType, false) => { val parquetElementType = fromDataType( elementType, CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index e765cfc83a397..6a780c5dfdf5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -127,6 +127,18 @@ class JsonSuite extends QueryTest { checkDataType(ArrayType(IntegerType), ArrayType(LongType), ArrayType(LongType)) checkDataType(ArrayType(IntegerType), ArrayType(StringType), ArrayType(StringType)) 
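    // Editorial aside, not part of the patch: the checks added just below exercise the
    // merge rule in JsonRDD.compatibleType for arrays -- the merged ArrayType can contain
    // null if either input can. A hypothetical standalone sketch of just that rule, where
    // mergeElem stands in for the recursive element-type merge:
    //
    //   def mergeArrayTypes(a: ArrayType, b: ArrayType,
    //       mergeElem: (DataType, DataType) => DataType): ArrayType =
    //     ArrayType(mergeElem(a.elementType, b.elementType), a.containsNull || b.containsNull)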
checkDataType(ArrayType(IntegerType), StructType(Nil), StringType) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, false), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, true), ArrayType(IntegerType, true), ArrayType(IntegerType, true)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType, false)) + checkDataType( + ArrayType(IntegerType, false), ArrayType(IntegerType, false), ArrayType(IntegerType)) // StructType checkDataType(StructType(Nil), StructType(Nil), StructType(Nil)) @@ -200,7 +212,7 @@ class JsonSuite extends QueryTest { AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() :: AttributeReference("arrayOfInteger", ArrayType(IntegerType), true)() :: AttributeReference("arrayOfLong", ArrayType(LongType), true)() :: - AttributeReference("arrayOfNull", ArrayType(StringType), true)() :: + AttributeReference("arrayOfNull", ArrayType(StringType, true), true)() :: AttributeReference("arrayOfString", ArrayType(StringType), true)() :: AttributeReference("arrayOfStruct", ArrayType( StructType(StructField("field1", BooleanType, true) :: @@ -451,7 +463,7 @@ class JsonSuite extends QueryTest { val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) val expectedSchema = - AttributeReference("array1", ArrayType(StringType), true)() :: + AttributeReference("array1", ArrayType(StringType, true), true)() :: AttributeReference("array2", ArrayType(StructType( StructField("field", LongType, true) :: Nil)), true)() :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 02a3dee67b464..93cb0360fa9fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -257,7 +257,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ)) => + case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_,_], MapType(kType, vType)) => map.map { @@ -274,7 +274,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { struct.zip(fields).map { case (v, t) => s""""${t.name}":${toHiveStructString(v, t.dataType)}""" }.mkString("{", ",", "}") - case (seq: Seq[_], ArrayType(typ)) => + case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") case (map: Map[_,_], MapType(kType, vType)) => map.map { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f83068860701f..26681c63b4c34 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -199,7 +199,9 @@ object HiveMetastoreTypes extends RegexParsers { "varchar\\((\\d+)\\)".r ^^^ StringType protected lazy val arrayType: Parser[DataType] = - "array" ~> "<" ~> dataType <~ ">" ^^ ArrayType + "array" ~> "<" ~> dataType <~ ">" ^^ { + case 
tpe => ArrayType(tpe) + } protected lazy val mapType: Parser[DataType] = "map" ~> "<" ~> dataType ~ "," ~ dataType <~ ">" ^^ { @@ -228,7 +230,7 @@ object HiveMetastoreTypes extends RegexParsers { } def toMetastoreType(dt: DataType): String = dt match { - case ArrayType(elementType) => s"array<${toMetastoreType(elementType)}>" + case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" case MapType(keyType, valueType) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 9b105308ab7cf..ae0e3cb0a2c89 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -319,7 +319,8 @@ private[hive] trait HiveInspectors { } def toInspector(dataType: DataType): ObjectInspector = dataType match { - case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) + case ArrayType(tpe, _) => + ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) case MapType(keyType, valueType) => ObjectInspectorFactory.getStandardMapObjectInspector( toInspector(keyType), toInspector(valueType)) From 32091087c479d5c71b093072a9a3a2d4c37aad78 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 15 Jul 2014 18:36:56 -0700 Subject: [PATCH 11/34] Add unit tests. --- .../org/apache/spark/sql/SchemaSuite.scala | 49 +++++++++++++++++++ 1 file changed, 49 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala new file mode 100644 index 0000000000000..c1e1b5333927d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala @@ -0,0 +1,49 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +class SchemaSuite extends FunSuite { + + test("constructing an ArrayType") { + val array = ArrayType(StringType) + + assert(ArrayType(StringType, false) === array) + } + + test("extracting fields from a StructType") { + val struct = StructType( + StructField("a", IntegerType, true) :: + StructField("b", LongType, false) :: + StructField("c", StringType, true) :: + StructField("d", FloatType, true) :: Nil) + + assert(StructField("b", LongType, false) === struct("b")) + + assert(struct("e") === null) + + val expectedStruct = StructType( + StructField("b", LongType, false) :: + StructField("d", FloatType, true) :: Nil) + + assert(expectedStruct === struct(Set("b", "d"))) + // struct does not have a field called e. So e is ignored. + assert(expectedStruct === struct(Set("b", "d", "e"))) + } +} From 68525a25687e0be091a709a0bb49b5a6812ba081 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 15 Jul 2014 18:37:16 -0700 Subject: [PATCH 12/34] Update JSON unit test. --- .../org/apache/spark/sql/json/JsonSuite.scala | 184 +++++++++++------- 1 file changed, 116 insertions(+), 68 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 6a780c5dfdf5a..10ad4e2e3dd7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -25,8 +25,6 @@ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.TestSQLContext._ -protected case class Schema(output: Seq[Attribute]) extends LeafNode - class JsonSuite extends QueryTest { import TestJsonData._ TestJsonData @@ -176,16 +174,16 @@ class JsonSuite extends QueryTest { test("Primitive field and type inferring") { val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) - val expectedSchema = - AttributeReference("bigInteger", DecimalType, true)() :: - AttributeReference("boolean", BooleanType, true)() :: - AttributeReference("double", DoubleType, true)() :: - AttributeReference("integer", IntegerType, true)() :: - AttributeReference("long", LongType, true)() :: - AttributeReference("null", StringType, true)() :: - AttributeReference("string", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -204,27 +202,28 @@ class JsonSuite extends QueryTest { test("Complex field and type inferring") { val jsonSchemaRDD = jsonRDD(complexFieldAndType) - val expectedSchema = - AttributeReference("arrayOfArray1", ArrayType(ArrayType(StringType)), true)() :: - AttributeReference("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true)() :: - AttributeReference("arrayOfBigInteger", ArrayType(DecimalType), true)() :: - AttributeReference("arrayOfBoolean", ArrayType(BooleanType), true)() :: - AttributeReference("arrayOfDouble", ArrayType(DoubleType), true)() :: - AttributeReference("arrayOfInteger", ArrayType(IntegerType), 
true)() :: - AttributeReference("arrayOfLong", ArrayType(LongType), true)() :: - AttributeReference("arrayOfNull", ArrayType(StringType, true), true)() :: - AttributeReference("arrayOfString", ArrayType(StringType), true)() :: - AttributeReference("arrayOfStruct", ArrayType( - StructType(StructField("field1", BooleanType, true) :: - StructField("field2", StringType, true) :: Nil)), true)() :: - AttributeReference("struct", StructType( - StructField("field1", BooleanType, true) :: - StructField("field2", DecimalType, true) :: Nil), true)() :: - AttributeReference("structWithArrayFields", StructType( + val expectedSchema = StructType( + StructField("arrayOfArray1", ArrayType(ArrayType(StringType)), true) :: + StructField("arrayOfArray2", ArrayType(ArrayType(DoubleType)), true) :: + StructField("arrayOfBigInteger", ArrayType(DecimalType), true) :: + StructField("arrayOfBoolean", ArrayType(BooleanType), true) :: + StructField("arrayOfDouble", ArrayType(DoubleType), true) :: + StructField("arrayOfInteger", ArrayType(IntegerType), true) :: + StructField("arrayOfLong", ArrayType(LongType), true) :: + StructField("arrayOfNull", ArrayType(StringType, true), true) :: + StructField("arrayOfString", ArrayType(StringType), true) :: + StructField("arrayOfStruct", ArrayType( + StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", StringType, true) :: Nil)), true) :: + StructField("struct", StructType( + StructField("field1", BooleanType, true) :: + StructField("field2", DecimalType, true) :: Nil), true) :: + StructField("structWithArrayFields", StructType( StructField("field1", ArrayType(IntegerType), true) :: - StructField("field2", ArrayType(StringType), true) :: Nil), true)() :: Nil + StructField("field2", ArrayType(StringType), true) :: Nil), true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -313,15 +312,15 @@ class JsonSuite extends QueryTest { test("Type conflict in primitive field values") { val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - val expectedSchema = - AttributeReference("num_bool", StringType, true)() :: - AttributeReference("num_num_1", LongType, true)() :: - AttributeReference("num_num_2", DecimalType, true)() :: - AttributeReference("num_num_3", DoubleType, true)() :: - AttributeReference("num_str", StringType, true)() :: - AttributeReference("str_bool", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("num_bool", StringType, true) :: + StructField("num_num_1", LongType, true) :: + StructField("num_num_2", DecimalType, true) :: + StructField("num_num_3", DoubleType, true) :: + StructField("num_str", StringType, true) :: + StructField("str_bool", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -438,15 +437,15 @@ class JsonSuite extends QueryTest { test("Type conflict in complex field values") { val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) - val expectedSchema = - AttributeReference("array", ArrayType(IntegerType), true)() :: - AttributeReference("num_struct", StringType, true)() :: - AttributeReference("str_array", StringType, true)() :: - AttributeReference("struct", StructType( - StructField("field", StringType, true) :: Nil), true)() :: - AttributeReference("struct_array", StringType, true)() :: Nil + val 
expectedSchema = StructType( + StructField("array", ArrayType(IntegerType), true) :: + StructField("num_struct", StringType, true) :: + StructField("str_array", StringType, true) :: + StructField("struct", StructType( + StructField("field", StringType, true) :: Nil), true) :: + StructField("struct_array", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -462,12 +461,12 @@ class JsonSuite extends QueryTest { test("Type conflict in array elements") { val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) - val expectedSchema = - AttributeReference("array1", ArrayType(StringType, true), true)() :: - AttributeReference("array2", ArrayType(StructType( - StructField("field", LongType, true) :: Nil)), true)() :: Nil + val expectedSchema = StructType( + StructField("array1", ArrayType(StringType, true), true) :: + StructField("array2", ArrayType(StructType( + StructField("field", LongType, true) :: Nil)), true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -487,15 +486,15 @@ class JsonSuite extends QueryTest { test("Handling missing fields") { val jsonSchemaRDD = jsonRDD(missingFields) - val expectedSchema = - AttributeReference("a", BooleanType, true)() :: - AttributeReference("b", LongType, true)() :: - AttributeReference("c", ArrayType(IntegerType), true)() :: - AttributeReference("d", StructType( - StructField("field", BooleanType, true) :: Nil), true)() :: - AttributeReference("e", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("a", BooleanType, true) :: + StructField("b", LongType, true) :: + StructField("c", ArrayType(IntegerType), true) :: + StructField("d", StructType( + StructField("field", BooleanType, true) :: Nil), true) :: + StructField("e", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") } @@ -506,16 +505,16 @@ class JsonSuite extends QueryTest { primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) val jsonSchemaRDD = jsonFile(path) - val expectedSchema = - AttributeReference("bigInteger", DecimalType, true)() :: - AttributeReference("boolean", BooleanType, true)() :: - AttributeReference("double", DoubleType, true)() :: - AttributeReference("integer", IntegerType, true)() :: - AttributeReference("long", LongType, true)() :: - AttributeReference("null", StringType, true)() :: - AttributeReference("string", StringType, true)() :: Nil + val expectedSchema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) - comparePlans(Schema(expectedSchema), Schema(jsonSchemaRDD.logicalPlan.output)) + assert(expectedSchema === jsonSchemaRDD.schema) jsonSchemaRDD.registerAsTable("jsonTable") @@ -530,4 +529,53 @@ class JsonSuite extends QueryTest { "this is a simple string.") :: Nil ) } + + test("Applying schemas") { + val file = getTempFilePath("json") + val path = file.toString + primitiveFieldAndType.map(record => 
record.replaceAll("\n", " ")).saveAsTextFile(path) + + val schema = StructType( + StructField("bigInteger", DecimalType, true) :: + StructField("boolean", BooleanType, true) :: + StructField("double", DoubleType, true) :: + StructField("integer", IntegerType, true) :: + StructField("long", LongType, true) :: + StructField("null", StringType, true) :: + StructField("string", StringType, true) :: Nil) + + val jsonSchemaRDD1 = jsonFile(path, schema) + + assert(schema === jsonSchemaRDD1.schema) + + jsonSchemaRDD1.registerAsTable("jsonTable1") + + checkAnswer( + sql("select * from jsonTable1"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + + val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) + + assert(schema === jsonSchemaRDD2.schema) + + jsonSchemaRDD2.registerAsTable("jsonTable2") + + checkAnswer( + sql("select * from jsonTable2"), + (BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.") :: Nil + ) + } } From b8b7db44c052aa959581be62aa4f2e44505e4289 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 15 Jul 2014 18:40:06 -0700 Subject: [PATCH 13/34] 1. Move sql package object and package-info to sql-core. 2. Minor updates on APIs. 3. Update scala doc. --- .../catalyst/expressions/BoundAttribute.scala | 3 +- .../spark/sql/catalyst/expressions/Row.scala | 2 +- .../sql/catalyst/planning/QueryPlanner.scala | 3 +- .../sql/catalyst/planning/patterns.scala | 2 +- .../spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../spark/sql/catalyst/rules/Rule.scala | 3 +- .../sql/catalyst/rules/RuleExecutor.scala | 6 +- .../spark/sql/catalyst/trees/package.scala | 5 +- .../spark/sql/catalyst/types/dataTypes.scala | 89 ++--- .../scala/org/apache/spark/sql/package.scala | 80 ---- .../org/apache/spark/sql/SQLContext.scala | 37 +- .../org/apache/spark/sql/SchemaRDDLike.scala | 17 +- .../org/apache/spark/sql/package-info.java | 0 .../scala/org/apache/spark/sql/package.scala | 350 ++++++++++++++++++ 14 files changed, 438 insertions(+), 161 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala rename sql/{catalyst => core}/src/main/scala/org/apache/spark/sql/package-info.java (100%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/package.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 9ce1f01056462..cbc214d442064 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,11 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import com.typesafe.scalalogging.slf4j.Logging + import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.Logging /** * A bound reference points to a specific slot in the input tuple, allowing the actual value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index ff10e198a3cee..b8f810447862f 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -34,7 +34,7 @@ object Row { def unapplySeq(row: Row): Some[Seq[Any]] = Some(row) /** - * Construct a [[Row]] with the given values. + * This method can be used to construct a [[Row]] with the given values. */ def apply(values: Any*): Row = new GenericRow(values.toArray) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 67833664b35ae..4ff5791635f4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.sql.Logging +import com.typesafe.scalalogging.slf4j.Logging + import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 026692abe067d..b8ae326be6fab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec -import org.apache.spark.sql.Logging +import com.typesafe.scalalogging.slf4j.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0320682d47ce5..1e8fe098f7c1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -128,7 +128,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. 
*/ - def schemaString: String = schema.schemaString + def schemaString: String = schema.structString /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index 1076537bc7602..f39bff8c25164 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.sql.Logging +import com.typesafe.scalalogging.slf4j.Logging + import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e32adb76fe146..e70ce66cb745f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql -package catalyst -package rules +package org.apache.spark.sql.catalyst.rules + +import com.typesafe.scalalogging.slf4j.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala index d159ecdd5d781..9a28d035a10a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst -import org.apache.spark.sql.Logger - /** * A library for easily manipulating trees of operators. Operators that extend TreeNode are * granted the following interface: @@ -35,5 +33,6 @@ import org.apache.spark.sql.Logger */ package object trees { // Since we want tree nodes to be lightweight, we create one logger for all treenode instances. 
- protected val logger = Logger("catalyst.trees") + protected val logger = + com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger("catalyst.trees")) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 4f7bc23a7412e..e07db00b749c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.Utils /** * */ -object DataType extends RegexParsers { +protected[sql] object DataType extends RegexParsers { protected lazy val primitiveType: Parser[DataType] = "StringType" ^^^ StringType | "FloatType" ^^^ FloatType | @@ -84,6 +84,21 @@ object DataType extends RegexParsers { case Success(result, _) => result case failure: NoSuccess => sys.error(s"Unsupported dataType: $asString, $failure") } + + protected[types] def buildFormattedString( + dataType: DataType, + prefix: String, + builder: StringBuilder): Unit = { + dataType match { + case array: ArrayType => + array.buildFormattedString(prefix, builder) + case struct: StructType => + struct.buildFormattedString(prefix, builder) + case map: MapType => + map.buildFormattedString(prefix, builder) + case _ => + } + } } abstract class DataType { @@ -244,6 +259,7 @@ case object FloatType extends FractionalType { } object ArrayType { + /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is false. */ def apply(elementType: DataType): ArrayType = ArrayType(elementType, false) } @@ -251,15 +267,7 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( s"${prefix}-- element: ${elementType.simpleString} (containsNull = ${containsNull})\n") - elementType match { - case array: ArrayType => - array.buildFormattedString(s"$prefix |", builder) - case struct: StructType => - struct.buildFormattedString(s"$prefix |", builder) - case map: MapType => - map.buildFormattedString(s"$prefix |", builder) - case _ => - } + DataType.buildFormattedString(elementType, s"$prefix |", builder) } def simpleString: String = "array" @@ -269,48 +277,41 @@ case class StructField(name: String, dataType: DataType, nullable: Boolean) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n") - dataType match { - case array: ArrayType => - array.buildFormattedString(s"$prefix |", builder) - case struct: StructType => - struct.buildFormattedString(s"$prefix |", builder) - case map: MapType => - map.buildFormattedString(s"$prefix |", builder) - case _ => - } + DataType.buildFormattedString(dataType, s"$prefix |", builder) } } object StructType { - def fromAttributes(attributes: Seq[Attribute]): StructType = + protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) private def validateFields(fields: Seq[StructField]): Boolean = fields.map(field => field.name).distinct.size == fields.size - - def apply[A <: String: ClassTag, B <: DataType: ClassTag](fields: (A, B)*): StructType = - StructType(fields.map(field => StructField(field._1, field._2, true))) - - def apply[A <: String: ClassTag, B <: DataType: 
ClassTag, C <: Boolean: ClassTag]( - fields: (A, B, C)*): StructType = - StructType(fields.map(field => StructField(field._1, field._2, field._3))) } case class StructType(fields: Seq[StructField]) extends DataType { require(StructType.validateFields(fields), "Found fields with the same name.") + /** + * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not + * have a name matching the given name, `null` will be returned. + */ def apply(name: String): StructField = { fields.find(f => f.name == name).orNull } - def apply(names: String*): StructType = { - val nameSet = names.toSet - StructType(fields.filter(f => nameSet.contains(f.name))) + /** + * Returns a [[StructType]] containing [[StructField]]s of the given names. + * Those names which do not have matching fields will be ignored. + */ + def apply(names: Set[String]): StructType = { + StructType(fields.filter(f => names.contains(f.name))) } - def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + protected[sql] def toAttributes = + fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) - def schemaString: String = { + def structString: String = { val builder = new StringBuilder builder.append("root\n") val prefix = " |" @@ -319,7 +320,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { builder.toString() } - def printSchema(): Unit = println(schemaString) + def printStruct(): Unit = println(structString) private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { fields.foreach(field => field.buildFormattedString(prefix, builder)) @@ -331,26 +332,8 @@ case class StructType(fields: Seq[StructField]) extends DataType { case class MapType(keyType: DataType, valueType: DataType) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - keyType match { - case array: ArrayType => - array.buildFormattedString(s"$prefix |", builder) - case struct: StructType => - struct.buildFormattedString(s"$prefix |", builder) - case map: MapType => - map.buildFormattedString(s"$prefix |", builder) - case _ => - } - - builder.append(s"${prefix}-- value: ${valueType.simpleString}\n") - valueType match { - case array: ArrayType => - array.buildFormattedString(s"$prefix |", builder) - case struct: StructType => - struct.buildFormattedString(s"$prefix |", builder) - case map: MapType => - map.buildFormattedString(s"$prefix |", builder) - case _ => - } + DataType.buildFormattedString(keyType, s"$prefix |", builder) + DataType.buildFormattedString(valueType, s"$prefix |", builder) } def simpleString: String = "map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala deleted file mode 100644 index 2099804073c08..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/package.scala +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -/** - * Allows the execution of relational queries, including those expressed in SQL using Spark. - * - * Note that this package is located in catalyst instead of in core so that all subprojects can - * inherit the settings from this package object. - */ -package object sql { - - protected[sql] def Logger(name: String) = - com.typesafe.scalalogging.slf4j.Logger(org.slf4j.LoggerFactory.getLogger(name)) - - protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging - - type Row = catalyst.expressions.Row - - val Row = catalyst.expressions.Row - - type DataType = catalyst.types.DataType - - val DataType = catalyst.types.DataType - - val NullType = catalyst.types.NullType - - val StringType = catalyst.types.StringType - - val BinaryType = catalyst.types.BinaryType - - val BooleanType = catalyst.types.BooleanType - - val TimestampType = catalyst.types.TimestampType - - val DecimalType = catalyst.types.DecimalType - - val DoubleType = catalyst.types.DoubleType - - val FloatType = catalyst.types.FloatType - - val ByteType = catalyst.types.ByteType - - val IntegerType = catalyst.types.IntegerType - - val LongType = catalyst.types.LongType - - val ShortType = catalyst.types.ShortType - - type ArrayType = catalyst.types.ArrayType - - val ArrayType = catalyst.types.ArrayType - - type MapType = catalyst.types.MapType - - val MapType = catalyst.types.MapType - - type StructType = catalyst.types.StructType - - val StructType = catalyst.types.StructType - - type StructField = catalyst.types.StructField - - val StructField = catalyst.types.StructField -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 355d545cad89e..197942c7b0f66 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -99,7 +99,9 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function * that will be applied to each partition of the RDD to convert RDD records to [[Row]]s. - * + * Similar to `RDD.mapPartitions``, this function can be used to improve performance where there + * is other setup work that can be amortized and used repeatedly for all of the + * elements in a partition. 
* @group userf */ def applySchemaToPartitions[A]( @@ -128,7 +130,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0, None) + def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) /** * Loads a JSON file (one object per line) and applies the given schema, @@ -136,15 +138,18 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def jsonFile(path: String, schema: StructType): SchemaRDD = jsonFile(path, 1.0, Option(schema)) + def jsonFile(path: String, schema: StructType): SchemaRDD = { + val json = sparkContext.textFile(path) + jsonRDD(json, schema) + } /** * :: Experimental :: */ @Experimental - def jsonFile(path: String, samplingRatio: Double, schema: Option[StructType]): SchemaRDD = { + def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { val json = sparkContext.textFile(path) - jsonRDD(json, samplingRatio, schema) + jsonRDD(json, samplingRatio) } /** @@ -154,7 +159,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0, None) + def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) /** * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, @@ -162,22 +167,30 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = jsonRDD(json, 1.0, Option(schema)) + def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + val appliedSchema = + Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) + + applySchemaToPartitions( + json, + appliedSchema, + JsonRDD.jsonStringToRow(appliedSchema, _: Iterator[String])) + } /** * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double, schema: Option[StructType]): SchemaRDD = { - val appliedSchema = - schema.getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio))) + def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) applySchemaToPartitions( json, - appliedSchema, - JsonRDD.jsonStringToRow(appliedSchema, _: Iterator[String])) + schema, + JsonRDD.jsonStringToRow(schema, _: Iterator[String])) } + /** * :: Experimental :: * Creates an empty parquet file with the schema of class `A`, which can be registered as a table. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index d60b4eca52ff0..bf1a2a866e58d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -123,12 +123,21 @@ private[sql] trait SchemaRDDLike { def saveAsTable(tableName: String): Unit = sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd - /** Returns the schema. */ + /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ def schema: StructType = queryExecution.analyzed.schema - /** Returns the output schema in the tree format. */ - def schemaString: String = schema.schemaString + /** Returns the schema as a string in the tree format. + * + * @group schema + */ + def schemaString: String = schema.structString - /** Prints out the schema in the tree format. 
*/ + /** Prints out the schema. + * + * @group schema + */ def printSchema(): Unit = println(schemaString) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java b/sql/core/src/main/scala/org/apache/spark/sql/package-info.java similarity index 100% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/package-info.java rename to sql/core/src/main/scala/org/apache/spark/sql/package-info.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala new file mode 100644 index 0000000000000..4d36f639f4a00 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -0,0 +1,350 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.apache.spark.annotation.DeveloperApi + +/** + * Allows the execution of relational queries, including those expressed in SQL using Spark. + * + * @groupname dataType Data types + * @groupdesc Spark SQL data types. + * @groupprio dataType -2 + * @groupname row Row + * @groupprio row -1 + */ +package object sql { + + protected[sql] type Logging = com.typesafe.scalalogging.slf4j.Logging + + /** + * :: DeveloperApi :: + * + * Represents one row of output from a relational operator. + * @group row + */ + @DeveloperApi + type Row = catalyst.expressions.Row + + /** + * :: DeveloperApi :: + * + * A [[Row]] object can be constructed by providing field values. Example: + * {{{ + * import org.apache.spark.sql._ + * + * Row(value1, value2, value3, ...) + * }}} + * + * Fields in a [[Row]] object can be extracted in a pattern match. Example: + * {{{ + * import org.apache.spark.sql._ + * + * val pairs = sql("SELECT key, value FROM src").rdd.map { + * case Row(key: Int, value: String) => + * key -> value + * } + * }}} + * @group row + */ + @DeveloperApi + val Row = catalyst.expressions.Row + + /** + * :: DeveloperApi :: + * + * The base type of all Spark SQL data types. + * + * @group dataType + */ + @DeveloperApi + type DataType = catalyst.types.DataType + + /** + * :: DeveloperApi :: + * + * The data type representing `String` values + * + * @group dataType + */ + @DeveloperApi + val StringType = catalyst.types.StringType + + /** + * :: DeveloperApi :: + * + * The data type representing `Array[Byte]` values. + * + * @group dataType + */ + @DeveloperApi + val BinaryType = catalyst.types.BinaryType + + /** + * :: DeveloperApi :: + * + * The data type representing `Boolean` values. 
+ * + *@group dataType + */ + @DeveloperApi + val BooleanType = catalyst.types.BooleanType + + /** + * :: DeveloperApi :: + * + * The data type representing `java.sql.Timestamp` values + * + * @group dataType + */ + @DeveloperApi + val TimestampType = catalyst.types.TimestampType + + /** + * :: DeveloperApi :: + * + * The data type representing `scala.math.BigDecimal` values. + * + * @group dataType + */ + @DeveloperApi + val DecimalType = catalyst.types.DecimalType + + /** + * :: DeveloperApi :: + * + * The data type representing `Double` values. + * + * @group dataType + */ + @DeveloperApi + val DoubleType = catalyst.types.DoubleType + + /** + * :: DeveloperApi :: + * + * The data type representing `Float` values. + * + * @group dataType + */ + @DeveloperApi + val FloatType = catalyst.types.FloatType + + /** + * :: DeveloperApi :: + * + * The data type representing `Byte` values + * + * @group dataType + */ + @DeveloperApi + val ByteType = catalyst.types.ByteType + + /** + * :: DeveloperApi :: + * + * The data type representing `Int` values. + * + * @group dataType + */ + @DeveloperApi + val IntegerType = catalyst.types.IntegerType + + /** + * :: DeveloperApi :: + * + * The data type representing `Long` values. + * + * @group dataType + */ + @DeveloperApi + val LongType = catalyst.types.LongType + + /** + * :: DeveloperApi :: + * + * The data type representing `Short` values. + * + * @group dataType + */ + @DeveloperApi + val ShortType = catalyst.types.ShortType + + /** + * :: DeveloperApi :: + * + * The data type representing `Seq`s. + * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and + * `containsNull: Boolean`. The field of `elementType` is used to specify the type of + * array elements. The field of `containsNull` is used to specify if the array can have + * any `null` value. + * + * @group dataType + */ + @DeveloperApi + type ArrayType = catalyst.types.ArrayType + + /** + * :: DeveloperApi :: + * + * An [[ArrayType]] object can be constructed with two ways, + * {{{ + * ArrayType(elementType: DataType, containsNull: Boolean) + * }}} and + * {{{ + * ArrayType(elementType: DataType) + * }}} + * For `ArrayType(elementType)`, the field of `containsNull` is set to `false`. + * + * @group dataType + */ + @DeveloperApi + val ArrayType = catalyst.types.ArrayType + + /** + * :: DeveloperApi :: + * + * The data type representing `Map`s. A [[MapType]] object comprises two fields, + * `keyType: [[DataType]]` and `valueType: [[DataType]]`. + * The field of `keyType` is used to specify the type of keys in the map. + * The field of `valueType` is used to specify the type of values in the map. + * For a [[MapType]] column, keys and values should not contain any `null` value. + * + * @group dataType + */ + @DeveloperApi + type MapType = catalyst.types.MapType + + /** + * :: DeveloperApi :: + * + * A [[MapType]] can be constructed by + * {{{ + * MapType(keyType: DataType, valueType: DataType) + * }}} + * + * @group dataType + */ + @DeveloperApi + val MapType = catalyst.types.MapType + + /** + * :: DeveloperApi :: + * + * The data type representing [[Row]]s. + * A [[StructType]] object comprises a [[Seq]] of [[StructField]]s. + * + * @group dataType + */ + @DeveloperApi + type StructType = catalyst.types.StructType + + /** + * :: DeveloperApi :: + * + * A [[StructType]] object can be constructed by + * {{{ + * StructType(fields: Seq[StructField]) + * }}} + * For a [[StructType]] object, one or multiple [[StructField]]s can be extracted by names. 
+ * If multiple [[StructField]]s are extracted, a [[StructType]] object will be returned. + * If a provided name does not have a matching field, it will be ignored. For the case + * of extracting a single StructField, a `null` will be returned. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val struct = + * StructType( + * StructField("a", IntegerType, true) :: + * StructField("b", LongType, false) :: + * StructField("c", BooleanType, false) :: Nil) + * + * // Extract a single StructField. + * val singleField = struct("b") + * // singleField: StructField = StructField(b,LongType,false) + * + * // This struct does not have a field called "d". null will be returned. + * val nonExisting = struct("d") + * // nonExisting: StructField = null + * + * // Extract multiple StructFields. Field names are provided in a set. + * // A StructType object will be returned. + * val twoFields = struct(Set("b", "c")) + * // twoFields: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * + * // Those names do not have matching fields will be ignored. + * // For the case shown below, "d" will be ignored and + * // it is treated as struct(Set("b", "c")). + * val ignoreNonExisting = struct(Set("b", "c", "d")) + * // ignoreNonExisting: StructType = + * // StructType(List(StructField(b,LongType,false), StructField(c,BooleanType,false))) + * }}} + * + * A [[Row]] object is used as a value of the StructType. + * Example: + * {{{ + * import org.apache.spark.sql._ + * + * val innerStruct = + * StructType( + * StructField("f1", IntegerType, true) :: + * StructField("f2", LongType, false) :: + * StructField("f3", BooleanType, false) :: Nil) + * + * val struct = StructType( + * StructField("a", innerStruct, true) :: Nil) + * + * // Create a Row with the schema defined by struct + * val row = Row(Row(1, 2, true)) + * // row: Row = [[1,2,true]] + * }}} + * + * @group dataType + */ + @DeveloperApi + val StructType = catalyst.types.StructType + + /** + * :: DeveloperApi :: + * + * A [[StructField]] object represents a field in a [[StructType]] object. + * A [[StructField]] object comprises three fields, `name: [[String]]`, `dataType: [[DataType]]`, + * and `nullable: Boolean`. The field of `name` is the name of a `StructField`. The field of + * `dataType` specifies the data type of a `StructField`. + * The field of `nullable` specifies if values of a `StructField` can contain `null` values. + * + * @group dataType + */ + @DeveloperApi + type StructField = catalyst.types.StructField + + /** + * :: DeveloperApi :: + * + * A [[StructField]] object can be constructed by + * {{{ + * StructField(name: String, dataType: DataType, nullable: Boolean) + * }}} + * + * @group dataType + */ + @DeveloperApi + val StructField = catalyst.types.StructField +} From e495e4e1bc3bd8993cdd94413aaa00e8c1f95423 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 16 Jul 2014 10:54:44 -0700 Subject: [PATCH 14/34] More comments. --- .../scala/org/apache/spark/sql/package.scala | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 4d36f639f4a00..355b7f85950d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -51,6 +51,52 @@ package object sql { * Row(value1, value2, value3, ...) 
* }}} * + * A value of a row can be accessed through both generic access by ordinal, + * which will incur boxing overhead for primitives, as well as native primitive access. + * An example of generic access by ordinal: + * {{{ + * import org.apache.spark.sql._ + * + * val row = Row(1, true, "a string", null) + * // row: Row = [1,true,a string,null] + * val firstValue = row(0) + * // firstValue: Any = 1 + * val fourthValue = row(3) + * // fourthValue: Any = null + * }}} + * + * For native primitive access, it is invalid to use the native primitive interface to retrieve + * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a value + * that might be null. + * An example of native primitive access: + * {{{ + * // using the row from the previous example. + * val firstValue = row.getInt(0) + * // firstValue: Int = 1 + * val isNull = row.isNullAt(3) + * // isNull: Boolean = true + * }}} + * + * Interfaces related to native primitive access are: + * + * `isNullAt(i: Int): Boolean` + * + * `getInt(i: Int): Int` + * + * `getLong(i: Int): Long` + * + * `getDouble(i: Int): Double` + * + * `getFloat(i: Int): Float` + * + * `getBoolean(i: Int): Boolean` + * + * `getShort(i: Int): Short` + * + * `getByte(i: Int): Byte` + * + * `getString(i: Int): String` + * * Fields in a [[Row]] object can be extracted in a pattern match. Example: * {{{ * import org.apache.spark.sql._ @@ -60,6 +106,7 @@ package object sql { * key -> value * } * }}} + * * @group row */ @DeveloperApi From 1d9c13a3f15a360f613a90ff4b460529d20e8fd4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 23 Jul 2014 12:13:30 -0700 Subject: [PATCH 15/34] Update applySchema API. --- .../org/apache/spark/sql/SQLContext.scala | 53 +++++++------------ .../org/apache/spark/sql/json/JsonRDD.scala | 31 +++++------ .../org/apache/spark/sql/json/JsonSuite.scala | 2 - 3 files changed, 34 insertions(+), 52 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 197942c7b0f66..fd0aa363acbab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -88,33 +88,18 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, SparkLogicalPlan(ExistingRdd.fromProductRdd(rdd))) /** - * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function - * that will be applied to each partition of the RDD to convert RDD records to [[Row]]s. + * :: DeveloperApi :: + * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. * * @group userf */ - def applySchema[A](rdd: RDD[A], schema: StructType, f: A => Row): SchemaRDD = - applySchemaToPartitions(rdd, schema, (iter: Iterator[A]) => iter.map(f)) - - /** - * Creates a [[SchemaRDD]] from an [[RDD]] by applying a schema to this RDD and using a function - * that will be applied to each partition of the RDD to convert RDD records to [[Row]]s. - * Similar to `RDD.mapPartitions``, this function can be used to improve performance where there - * is other setup work that can be amortized and used repeatedly for all of the - * elements in a partition. 
- * @group userf - */ - def applySchemaToPartitions[A]( - rdd: RDD[A], - schema: StructType, - f: Iterator[A] => Iterator[Row]): SchemaRDD = - new SchemaRDD(this, makeCustomRDDScan(rdd, schema, f)) - - protected[sql] def makeCustomRDDScan[A]( - rdd: RDD[A], - schema: StructType, - f: Iterator[A] => Iterator[Row]): LogicalPlan = - SparkLogicalPlan(ExistingRdd(schema.toAttributes, rdd.mapPartitions(f))) + @DeveloperApi + def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { + val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD)) + new SchemaRDD(this, logicalPlan) + } /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. @@ -133,11 +118,13 @@ class SQLContext(@transient val sparkContext: SparkContext) def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) /** + * :: Experimental :: * Loads a JSON file (one object per line) and applies the given schema, * returning the result as a [[SchemaRDD]]. * * @group userf */ + @Experimental def jsonFile(path: String, schema: StructType): SchemaRDD = { val json = sparkContext.textFile(path) jsonRDD(json, schema) @@ -162,19 +149,18 @@ class SQLContext(@transient val sparkContext: SparkContext) def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) /** + * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, * returning the result as a [[SchemaRDD]]. * * @group userf */ + @Experimental def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { val appliedSchema = Option(schema).getOrElse(JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0))) - - applySchemaToPartitions( - json, - appliedSchema, - JsonRDD.jsonStringToRow(appliedSchema, _: Iterator[String])) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + applySchema(rowRDD, appliedSchema) } /** @@ -182,12 +168,9 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @Experimental def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { - val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) - - applySchemaToPartitions( - json, - schema, - JsonRDD.jsonStringToRow(schema, _: Iterator[String])) + val appliedSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, samplingRatio)) + val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema) + applySchema(rowRDD, appliedSchema) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 8e22f880810ff..a3fac2a5adbb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.Logging private[sql] object JsonRDD extends Logging { private[sql] def jsonStringToRow( - schema: StructType, - jsonIter: Iterator[String]): Iterator[Row] = { - parseJson(jsonIter).map(parsed => asRow(parsed, schema)) + json: RDD[String], + schema: StructType): RDD[Row] = { + parseJson(json).map(parsed => asRow(parsed, schema)) } private[sql] def inferSchema( @@ -42,8 +42,7 @@ private[sql] object JsonRDD extends Logging { samplingRatio: Double = 1.0): StructType = { require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0") val schemaData = if (samplingRatio > 0.99) json else json.sample(false, samplingRatio, 1) - val allKeys = - schemaData.mapPartitions(iter => parseJson(iter)).map(allKeysWithValueTypes).reduce(_ ++ _) + val allKeys 
= parseJson(schemaData).map(allKeysWithValueTypes).reduce(_ ++ _) createSchema(allKeys) } @@ -255,7 +254,7 @@ private[sql] object JsonRDD extends Logging { case atom => atom } - private def parseJson(jsonIter: Iterator[String]): Iterator[Map[String, Any]] = { + private def parseJson(json: RDD[String]): RDD[Map[String, Any]] = { // According to [Jackson-72: https://jira.codehaus.org/browse/JACKSON-72], // ObjectMapper will not return BigDecimal when // "DeserializationFeature.USE_BIG_DECIMAL_FOR_FLOATS" is disabled @@ -264,15 +263,17 @@ private[sql] object JsonRDD extends Logging { // for every float number, which will be slow. // So, right now, we will have Infinity for those BigDecimal number. // TODO: Support BigDecimal. - // Also, when there is a key appearing multiple times (a duplicate key), - // the ObjectMapper will take the last value associated with this duplicate key. - // For example: for {"key": 1, "key":2}, we will get "key"->2. - val mapper = new ObjectMapper() - jsonIter.map { - record => - val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) - parsed.asInstanceOf[Map[String, Any]] - } + json.mapPartitions(iter => { + // Also, when there is a key appearing multiple times (a duplicate key), + // the ObjectMapper will take the last value associated with this duplicate key. + // For example: for {"key": 1, "key":2}, we will get "key"->2. + val mapper = new ObjectMapper() + iter.map { + record => + val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) + parsed.asInstanceOf[Map[String, Any]] + } + }) } private def toLong(value: Any): Long = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 10ad4e2e3dd7f..9d9cfdd7c92e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.json -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.json.JsonRDD.{enforceCorrectType, compatibleType} From 9c99bc0428c9b4eb011d51b8e964fec104e8d3f7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 23 Jul 2014 12:15:14 -0700 Subject: [PATCH 16/34] Several minor updates. --- .../org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 2 +- .../org/apache/spark/sql/catalyst/types/dataTypes.scala | 5 +++-- .../src/main/scala/org/apache/spark/sql/SchemaRDD.scala | 2 +- .../src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/package.scala | 6 ++++-- 5 files changed, 10 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 1e8fe098f7c1b..0988b0c6d990c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -128,7 +128,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def schema: StructType = StructType.fromAttributes(output) /** Returns the output schema in the tree format. 
*/ - def schemaString: String = schema.structString + def schemaString: String = schema.treeString /** Prints out the schema in the tree format */ def printSchema(): Unit = println(schemaString) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 0946e4dd2cfdd..a14041d4ccb5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -314,7 +314,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { protected[sql] def toAttributes = fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) - def structString: String = { + def treeString: String = { val builder = new StringBuilder builder.append("root\n") val prefix = " |" @@ -323,7 +323,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { builder.toString() } - def printStruct(): Unit = println(structString) + def printTreeString(): Unit = println(treeString) private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { fields.foreach(field => field.buildFormattedString(prefix, builder)) @@ -335,6 +335,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { case class MapType(keyType: DataType, valueType: DataType) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") + builder.append(s"${prefix}-- value: ${valueType.simpleString}\n") DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 2fd211eba86b1..1c8c5276761d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -382,7 +382,7 @@ class SchemaRDD( case (obj, (attrName, dataType)) => dataType match { case struct: StructType => map.put(attrName, rowToMap(obj.asInstanceOf[Row], struct)) - case array @ ArrayType(struct: StructType) => + case array @ ArrayType(struct: StructType, _) => val arrayValues = obj match { case seq: Seq[Any] => seq.map(element => rowToMap(element.asInstanceOf[Row], struct)).asJava diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index bf1a2a866e58d..cf9baa34eb2bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -133,7 +133,7 @@ private[sql] trait SchemaRDDLike { * * @group schema */ - def schemaString: String = schema.structString + def schemaString: String = schema.treeString /** Prints out the schema. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 355b7f85950d0..c7a1af05093f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -24,7 +24,9 @@ import org.apache.spark.annotation.DeveloperApi * * @groupname dataType Data types * @groupdesc Spark SQL data types. 
- * @groupprio dataType -2 + * @groupprio dataType -3 + * @groupname field Field + * @groupprio field -2 * @groupname row Row * @groupprio row -1 */ @@ -377,7 +379,7 @@ package object sql { * `dataType` specifies the data type of a `StructField`. * The field of `nullable` specifies if values of a `StructField` can contain `null` values. * - * @group dataType + * @group field */ @DeveloperApi type StructField = catalyst.types.StructField From 8da1a17907adf48b2acf9339c957f74bbba9598f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 23 Jul 2014 15:35:05 -0700 Subject: [PATCH 17/34] Add Row.fromSeq. --- .../spark/sql/catalyst/expressions/Row.scala | 5 ++ .../scala/org/apache/spark/sql/package.scala | 4 +- .../scala/org/apache/spark/sql/RowSuite.scala | 46 +++++++++++++++++++ 3 files changed, 54 insertions(+), 1 deletion(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala index b8f810447862f..22a058191d415 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala @@ -37,6 +37,11 @@ object Row { * This method can be used to construct a [[Row]] with the given values. */ def apply(values: Any*): Row = new GenericRow(values.toArray) + + /** + * This method can be used to construct a [[Row]] from a [[Seq]] of values. + */ + def fromSeq(values: Seq[Any]): Row = new GenericRow(values.toArray) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index c7a1af05093f6..25d370a539e03 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -50,7 +50,10 @@ package object sql { * {{{ * import org.apache.spark.sql._ * + * // Create a Row from values. * Row(value1, value2, value3, ...) + * // Create a Row from a Seq of values. + * Row.fromSeq(Seq(value1, value2, ...)) * }}} * * A value of a row can be accessed through both generic access by ordinal, @@ -272,7 +275,6 @@ package object sql { * `keyType: [[DataType]]` and `valueType: [[DataType]]`. * The field of `keyType` is used to specify the type of keys in the map. * The field of `valueType` is used to specify the type of values in the map. - * For a [[MapType]] column, keys and values should not contain any `null` value. * * @group dataType */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala new file mode 100644 index 0000000000000..651cb735ab7d9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -0,0 +1,46 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow + +class RowSuite extends FunSuite { + + test("create row") { + val expected = new GenericMutableRow(4) + expected.update(0, 2147483647) + expected.update(1, "this is a string") + expected.update(2, false) + expected.update(3, null) + val actual1 = Row(2147483647, "this is a string", false, null) + assert(expected.size === actual1.size) + assert(expected.getInt(0) === actual1.getInt(0)) + assert(expected.getString(1) === actual1.getString(1)) + assert(expected.getBoolean(2) === actual1.getBoolean(2)) + assert(expected(3) === actual1(3)) + + val actual2 = Row.fromSeq(Seq(2147483647, "this is a string", false, null)) + assert(expected.size === actual2.size) + assert(expected.getInt(0) === actual2.getInt(0)) + assert(expected.getString(1) === actual2.getString(1)) + assert(expected.getBoolean(2) === actual2.getBoolean(2)) + assert(expected(3) === actual2(3)) + } +} From aa92e844137831de4d66ccdd3a317eda7c800dcf Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 23 Jul 2014 15:38:47 -0700 Subject: [PATCH 18/34] Update data type tests. --- .../spark/sql/catalyst/types/dataTypes.scala | 14 +++++++++++++- .../{SchemaSuite.scala => DataTypeSuite.scala} | 15 +++++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/{SchemaSuite.scala => DataTypeSuite.scala} (83%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index a14041d4ccb5d..93ffee327cefe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -295,12 +295,19 @@ object StructType { case class StructType(fields: Seq[StructField]) extends DataType { require(StructType.validateFields(fields), "Found fields with the same name.") + /** + * Returns all field names in a [[Seq]]. + */ + lazy val fieldNames: Seq[String] = fields.map(_.name) + private lazy val fieldNamesSet: Set[String] = fieldNames.toSet + /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a name matching the given name, `null` will be returned. */ def apply(name: String): StructField = { - fields.find(f => f.name == name).orNull + fields.find(f => f.name == name).getOrElse( + throw new IllegalArgumentException(s"Field ${name} does not exist.")) } /** @@ -308,6 +315,11 @@ case class StructType(fields: Seq[StructField]) extends DataType { * Those names which do not have matching fields will be ignored. 
*/ def apply(names: Set[String]): StructType = { + val nonExistFields = names -- fieldNamesSet + if (!nonExistFields.isEmpty) { + throw new IllegalArgumentException( + s"Field ${nonExistFields.mkString(",")} does not exist.") + } StructType(fields.filter(f => names.contains(f.name))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala similarity index 83% rename from sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index c1e1b5333927d..c5bd7b391db41 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.sql import org.scalatest.FunSuite -class SchemaSuite extends FunSuite { +class DataTypeSuite extends FunSuite { - test("constructing an ArrayType") { + test("construct an ArrayType") { val array = ArrayType(StringType) assert(ArrayType(StringType, false) === array) } - test("extracting fields from a StructType") { + test("extract fields from a StructType") { val struct = StructType( StructField("a", IntegerType, true) :: StructField("b", LongType, false) :: @@ -36,14 +36,17 @@ class SchemaSuite extends FunSuite { assert(StructField("b", LongType, false) === struct("b")) - assert(struct("e") === null) + intercept[IllegalArgumentException] { + struct("e") + } val expectedStruct = StructType( StructField("b", LongType, false) :: StructField("d", FloatType, true) :: Nil) assert(expectedStruct === struct(Set("b", "d"))) - // struct does not have a field called e. So e is ignored. - assert(expectedStruct === struct(Set("b", "d", "e"))) + intercept[IllegalArgumentException] { + struct(Set("b", "d", "e", "f")) + } } } From 624765cb0122cc715cbbe5b4967f6b5e4e0ce0ec Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 23 Jul 2014 15:39:14 -0700 Subject: [PATCH 19/34] Tests for applySchema. --- .../org/apache/spark/sql/SQLContext.scala | 2 ++ .../org/apache/spark/sql/SQLQuerySuite.scala | 32 +++++++++++++++++++ .../scala/org/apache/spark/sql/TestData.scala | 7 ++++ 3 files changed, 41 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index fd0aa363acbab..8433b215cb3d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -97,6 +97,8 @@ class SQLContext(@transient val sparkContext: SparkContext) */ @DeveloperApi def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { + // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied + // schema differs from the existing schema on any field data type. 
val logicalPlan = SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRDD)) new SchemaRDD(this, logicalPlan) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0743cfe8cff0f..29c9dbb9457f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -431,4 +431,36 @@ class SQLQuerySuite extends QueryTest { ) clear() } + + test("apply schema") { + val schema = StructType( + StructField("f1", IntegerType, false) :: + StructField("f2", StringType, false) :: + StructField("f3", BooleanType, false) :: + StructField("f4", IntegerType, true) :: Nil) + + val rowRDD = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(values(0).toInt, values(1), values(2).toBoolean, v4) + } + + val schemaRDD = applySchema(rowRDD, schema) + schemaRDD.registerAsTable("applySchema") + checkAnswer( + sql("SELECT * FROM applySchema"), + (1, "A1", true, null) :: + (2, "B2", false, null) :: + (3, "C3", true, null) :: + (4, "D4", true, 2147483644) :: Nil) + + checkAnswer( + sql("SELECT f1, f4 FROM applySchema"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 330b20b315d63..213190e812026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -128,4 +128,11 @@ object TestData { case class TableName(tableName: String) TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerAsTable("tableName") + + val unparsedStrings = + TestSQLContext.sparkContext.parallelize( + "1, A1, true, null" :: + "2, B2, false, null" :: + "3, C3, true, null" :: + "4, D4, true, 2147483644" :: Nil) } From 1c9f33c000dd4c7ec7cb9e99bb4d1afe9be20211 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 24 Jul 2014 11:44:15 -0700 Subject: [PATCH 20/34] Java APIs for DataTypes and Row. 
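For context, the Scala-side workflow that the preceding patches establish looks roughly like the sketch below (an explicit StructType, an RDD of Rows, and applySchema); this patch adds the corresponding Java representations of the data types and the Row interface under org.apache.spark.sql.api.java.types. The sketch is illustrative only: `sc`, `sqlContext`, and the input path are placeholder names, not code from this patch series.

    import org.apache.spark.sql._

    // Describe the result schema explicitly.
    val schema = StructType(
      StructField("name", StringType, true) ::
      StructField("age", IntegerType, true) :: Nil)

    // Build an RDD[Row] whose values line up with the schema above.
    val rowRDD = sc.textFile("people.txt")
      .map(_.split(","))
      .map(fields => Row(fields(0).trim, fields(1).trim.toInt))

    // Apply the schema, register the result as a table, and query it with SQL.
    val people = sqlContext.applySchema(rowRDD, schema)
    people.registerAsTable("people")
    sqlContext.sql("SELECT name FROM people WHERE age > 21")

The files listed below mirror these type definitions for Java users.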
--- .../spark/sql/api/java/types/ArrayType.java | 63 +++++++ .../spark/sql/api/java/types/BinaryType.java | 25 +++ .../spark/sql/api/java/types/BooleanType.java | 22 +++ .../spark/sql/api/java/types/ByteType.java | 25 +++ .../spark/sql/api/java/types/DataType.java | 161 ++++++++++++++++ .../spark/sql/api/java/types/DecimalType.java | 25 +++ .../spark/sql/api/java/types/DoubleType.java | 25 +++ .../spark/sql/api/java/types/FloatType.java | 25 +++ .../spark/sql/api/java/types/IntegerType.java | 25 +++ .../spark/sql/api/java/types/LongType.java | 25 +++ .../spark/sql/api/java/types/MapType.java | 62 +++++++ .../spark/sql/api/java/types/ShortType.java | 25 +++ .../spark/sql/api/java/types/StringType.java | 25 +++ .../spark/sql/api/java/types/StructField.java | 72 +++++++ .../spark/sql/api/java/types/StructType.java | 54 ++++++ .../sql/api/java/types/TimestampType.java | 25 +++ .../sql/api/java/types/package-info.java | 22 +++ .../org/apache/spark/sql/api/java/Row.scala | 56 +++++- .../scala/org/apache/spark/sql/package.scala | 8 +- .../spark/sql/api/java/JavaRowSuite.java | 170 +++++++++++++++++ .../java/JavaSideDataTypeConversionSuite.java | 175 ++++++++++++++++++ .../ScalaSideDataTypeConversionSuite.scala | 82 ++++++++ 22 files changed, 1192 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java new file mode 100644 index 0000000000000..61f52055842e6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java @@ -0,0 +1,63 @@ +/* + * Licensed to the 
Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Lists. + * An ArrayType object comprises two fields, {@code DataType elementType} and + * {@code boolean containsNull}. The field of {@code elementType} is used to specify the type of + * array elements. The field of {@code containsNull} is used to specify if the array can have + * any {@code null} value. + */ +public class ArrayType extends DataType { + private DataType elementType; + private boolean containsNull; + + protected ArrayType(DataType elementType, boolean containsNull) { + this.elementType = elementType; + this.containsNull = containsNull; + } + + public DataType getElementType() { + return elementType; + } + + public boolean isContainsNull() { + return containsNull; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ArrayType arrayType = (ArrayType) o; + + if (containsNull != arrayType.containsNull) return false; + if (!elementType.equals(arrayType.elementType)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = elementType.hashCode(); + result = 31 * result + (containsNull ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java new file mode 100644 index 0000000000000..c33ee5e25cd32 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing byte[] values. 
+ */ +public class BinaryType extends DataType { + protected BinaryType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java new file mode 100644 index 0000000000000..38981a21da58d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +public class BooleanType extends DataType { + protected BooleanType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java new file mode 100644 index 0000000000000..16b0d9ecf688c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Byte values. + */ +public class ByteType extends DataType { + protected ByteType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java new file mode 100644 index 0000000000000..7399bf9876f50 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * The base type of all Spark SQL data types. + */ +public abstract class DataType { + + /** + * Gets the StringType object. + */ + public static final StringType StringType = new StringType(); + + /** + * Gets the BinaryType object. + */ + public static final BinaryType BinaryType = new BinaryType(); + + /** + * Gets the BooleanType object. + */ + public static final BooleanType BooleanType = new BooleanType(); + + /** + * Gets the TimestampType object. + */ + public static final TimestampType TimestampType = new TimestampType(); + + /** + * Gets the DecimalType object. + */ + public static final DecimalType DecimalType = new DecimalType(); + + /** + * Gets the DoubleType object. + */ + public static final DoubleType DoubleType = new DoubleType(); + + /** + * Gets the FloatType object. + */ + public static final FloatType FloatType = new FloatType(); + + /** + * Gets the ByteType object. + */ + public static final ByteType ByteType = new ByteType(); + + /** + * Gets the IntegerType object. + */ + public static final IntegerType IntegerType = new IntegerType(); + + /** + * Gets the LongType object. + */ + public static final LongType LongType = new LongType(); + + /** + * Gets the ShortType object. + */ + public static final ShortType ShortType = new ShortType(); + + /** + * Creates an ArrayType by specifying the data type of elements ({@code elementType}) and + * whether the array contains null values ({@code containsNull}). + * @param elementType + * @param containsNull + * @return + */ + public static ArrayType createArrayType(DataType elementType, boolean containsNull) { + if (elementType == null) { + throw new IllegalArgumentException("elementType should not be null."); + } + + return new ArrayType(elementType, containsNull); + } + + /** + * Creates a MapType by specifying the data type of keys ({@code keyType}) and values + * ({@code keyType}). + * @param keyType + * @param valueType + * @return + */ + public static MapType createMapType(DataType keyType, DataType valueType) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + + return new MapType(keyType, valueType); + } + + /** + * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and + * whether values of this field can be null values ({@code nullable}). + * @param name + * @param dataType + * @param nullable + * @return + */ + public static StructField createStructField(String name, DataType dataType, boolean nullable) { + if (name == null) { + throw new IllegalArgumentException("name should not be null."); + } + if (dataType == null) { + throw new IllegalArgumentException("dataType should not be null."); + } + + return new StructField(name, dataType, nullable); + } + + /** + * Creates a StructType with the given StructFields ({@code fields}). 
+ * @param fields + * @return + */ + public static StructType createStructType(List fields) { + if (fields == null) { + throw new IllegalArgumentException("fields should not be null."); + } + Set distinctNames = new HashSet(); + for (StructField field: fields) { + if (field == null) { + throw new IllegalArgumentException( + "fields should not contain any null."); + } + + distinctNames.add(field.getName()); + } + if (distinctNames.size() != fields.size()) { + throw new IllegalArgumentException( + "fields should have distinct names."); + } + + return new StructType(fields); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java new file mode 100644 index 0000000000000..d483824999e85 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing java.math.BigDecimal values. + */ +public class DecimalType extends DataType { + protected DecimalType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java new file mode 100644 index 0000000000000..13a7bf6bbb5ed --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Double values. 
+ */ +public class DoubleType extends DataType { + protected DoubleType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java new file mode 100644 index 0000000000000..bf47d4fc1fa07 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Float values. + */ +public class FloatType extends DataType { + protected FloatType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java new file mode 100644 index 0000000000000..f41ec2260df6b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Int values. + */ +public class IntegerType extends DataType { + protected IntegerType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java new file mode 100644 index 0000000000000..7c73a7b506a2b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Long values. + */ +public class LongType extends DataType { + protected LongType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java new file mode 100644 index 0000000000000..d946d967a33fc --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Maps. A MapType object comprises two fields, + * {@code DataType keyType} and {@code DataType valueType}. + * The field of {@code keyType} is used to specify the type of keys in the map. + * The field of {@code valueType} is used to specify the type of values in the map. + */ +public class MapType extends DataType { + private DataType keyType; + private DataType valueType; + + protected MapType(DataType keyType, DataType valueType) { + this.keyType = keyType; + this.valueType = valueType; + } + + public DataType getKeyType() { + return keyType; + } + + public DataType getValueType() { + return valueType; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + MapType mapType = (MapType) o; + + if (!keyType.equals(mapType.keyType)) return false; + if (!valueType.equals(mapType.valueType)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = keyType.hashCode(); + result = 31 * result + valueType.hashCode(); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java new file mode 100644 index 0000000000000..8ffa75a835e63 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing Short values. + */ +public class ShortType extends DataType { + protected ShortType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java new file mode 100644 index 0000000000000..dd9be52f8c53b --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing String values. + */ +public class StringType extends DataType { + protected StringType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java new file mode 100644 index 0000000000000..25c82de9641c5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * A StructField object represents a field in a StructType object. + * A StructField object comprises three fields, {@code String name}, {@code DataType dataType}, + * and {@code boolean nullable}. The field of {@code name} is the name of a StructField. 
+ * The field of {@code dataType} specifies the data type of a StructField. + * The field of {@code nullable} specifies if values of a StructField can contain {@code null} + * values. + */ +public class StructField { + private String name; + private DataType dataType; + private boolean nullable; + + protected StructField(String name, DataType dataType, boolean nullable) { + this.name = name; + this.dataType = dataType; + this.nullable = nullable; + } + + public String getName() { + return name; + } + + public DataType getDataType() { + return dataType; + } + + public boolean isNullable() { + return nullable; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StructField that = (StructField) o; + + if (nullable != that.nullable) return false; + if (!dataType.equals(that.dataType)) return false; + if (!name.equals(that.name)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = name.hashCode(); + result = 31 * result + dataType.hashCode(); + result = 31 * result + (nullable ? 1 : 0); + return result; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java new file mode 100644 index 0000000000000..cd397d42135f1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +import java.util.Arrays; +import java.util.List; + +/** + * The data type representing Rows. + * A StructType object comprises a List of StructFields. + */ +public class StructType extends DataType { + private StructField[] fields; + + protected StructType(List fields) { + this.fields = fields.toArray(new StructField[0]); + } + + public StructField[] getFields() { + return fields; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + StructType that = (StructType) o; + + if (!Arrays.equals(fields, that.fields)) return false; + + return true; + } + + @Override + public int hashCode() { + return Arrays.hashCode(fields); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java new file mode 100644 index 0000000000000..8c2f203d950c4 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java.types; + +/** + * The data type representing java.sql.Timestamp values. + */ +public class TimestampType extends DataType { + protected TimestampType() {} +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java new file mode 100644 index 0000000000000..a1c6fcf1430f5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +/** + * Allows users to get and create Spark SQL data types. + */ +package org.apache.spark.sql.api.java.types; \ No newline at end of file diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index 9b0dd2176149b..a87d6f25f6130 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -17,6 +17,11 @@ package org.apache.spark.sql.api.java +import scala.annotation.varargs +import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} +import scala.collection.JavaConversions +import scala.math.BigDecimal + import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow} /** @@ -29,7 +34,7 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { /** Returns the value of column `i`. */ def get(i: Int): Any = - row(i) + Row.toJavaValue(row(i)) /** Returns true if value at column `i` is NULL. 
*/ def isNullAt(i: Int) = get(i) == null @@ -89,5 +94,54 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { */ def getString(i: Int): String = row.getString(i) + + def canEqual(other: Any): Boolean = other.isInstanceOf[Row] + + override def equals(other: Any): Boolean = other match { + case that: Row => + (that canEqual this) && + row == that.row + case _ => false + } + + override def hashCode(): Int = row.hashCode() } +object Row { + + private def toJavaValue(value: Any): Any = value match { + case row: ScalaRow => new Row(row) + case map: scala.collection.Map[_, _] => + JavaConversions.mapAsJavaMap( + map.map { + case (key, value) => (toJavaValue(key), toJavaValue(value)) + } + ) + case seq: scala.collection.Seq[_] => + JavaConversions.seqAsJavaList(seq.map(toJavaValue)) + case decimal: BigDecimal => decimal.underlying() + case other => other + } + + // TODO: Consolidate the toScalaValue at here with the scalafy in JsonRDD? + private def toScalaValue(value: Any): Any = value match { + case row: Row => row.row + case map: java.util.Map[_, _] => + JMapWrapper(map).map { + case (key, value) => (toScalaValue(key), toScalaValue(value)) + } + case list: java.util.List[_] => + JListWrapper(list).map(toScalaValue) + case decimal: java.math.BigDecimal => BigDecimal(decimal) + case other => other + } + + /** + * Creates a Row with the given values. + */ + @varargs def create(values: Any*): Row = { + // Right now, we cannot use @varargs to annotate the constructor of + // org.apache.spark.sql.api.java.Row. See https://issues.scala-lang.org/browse/SI-8383. + new Row(ScalaRow(values.map(toScalaValue):_*)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 25d370a539e03..8f87e6e19baf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -71,8 +71,8 @@ package object sql { * }}} * * For native primitive access, it is invalid to use the native primitive interface to retrieve - * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a value - * that might be null. + * a value that is null, instead a user must check `isNullAt` before attempting to retrieve a + * value that might be null. * An example of native primitive access: * {{{ * // using the row from the previous example. @@ -160,7 +160,7 @@ package object sql { /** * :: DeveloperApi :: * - * The data type representing `java.sql.Timestamp` values + * The data type representing `java.sql.Timestamp` values. * * @group dataType */ @@ -200,7 +200,7 @@ package object sql { /** * :: DeveloperApi :: * - * The data type representing `Byte` values + * The data type representing `Byte` values. * * @group dataType */ diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java new file mode 100644 index 0000000000000..52d07b5425cc3 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class JavaRowSuite { + private byte byteValue; + private short shortValue; + private int intValue; + private long longValue; + private float floatValue; + private double doubleValue; + private BigDecimal decimalValue; + private boolean booleanValue; + private String stringValue; + private byte[] binaryValue; + private Timestamp timestampValue; + + @Before + public void setUp() { + byteValue = (byte)127; + shortValue = (short)32767; + intValue = 2147483647; + longValue = 9223372036854775807L; + floatValue = (float)3.4028235E38; + doubleValue = 1.7976931348623157E308; + decimalValue = new BigDecimal("1.7976931348623157E328"); + booleanValue = true; + stringValue = "this is a string"; + binaryValue = stringValue.getBytes(); + timestampValue = Timestamp.valueOf("2014-06-30 09:20:00.0"); + } + + @Test + public void constructSimpleRow() { + Row simpleRow = Row.create( + byteValue, // ByteType + new Byte(byteValue), + shortValue, // ShortType + new Short(shortValue), + intValue, // IntegerType + new Integer(intValue), + longValue, // LongType + new Long(longValue), + floatValue, // FloatType + new Float(floatValue), + doubleValue, // DoubleType + new Double(doubleValue), + decimalValue, // DecimalType + booleanValue, // BooleanType + new Boolean(booleanValue), + stringValue, // StringType + binaryValue, // BinaryType + timestampValue, // TimestampType + null // null + ); + + Assert.assertEquals(byteValue, simpleRow.getByte(0)); + Assert.assertEquals(byteValue, simpleRow.get(0)); + Assert.assertEquals(byteValue, simpleRow.getByte(1)); + Assert.assertEquals(byteValue, simpleRow.get(1)); + Assert.assertEquals(shortValue, simpleRow.getShort(2)); + Assert.assertEquals(shortValue, simpleRow.get(2)); + Assert.assertEquals(shortValue, simpleRow.getShort(3)); + Assert.assertEquals(shortValue, simpleRow.get(3)); + Assert.assertEquals(intValue, simpleRow.getInt(4)); + Assert.assertEquals(intValue, simpleRow.get(4)); + Assert.assertEquals(intValue, simpleRow.getInt(5)); + Assert.assertEquals(intValue, simpleRow.get(5)); + Assert.assertEquals(longValue, simpleRow.getLong(6)); + Assert.assertEquals(longValue, simpleRow.get(6)); + Assert.assertEquals(longValue, simpleRow.getLong(7)); + Assert.assertEquals(longValue, simpleRow.get(7)); + // When we create the row, we do not do any conversion + // for a float/double value, so we just set the delta to 0. 
+ Assert.assertEquals(floatValue, simpleRow.getFloat(8), 0); + Assert.assertEquals(floatValue, simpleRow.get(8)); + Assert.assertEquals(floatValue, simpleRow.getFloat(9), 0); + Assert.assertEquals(floatValue, simpleRow.get(9)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(10), 0); + Assert.assertEquals(doubleValue, simpleRow.get(10)); + Assert.assertEquals(doubleValue, simpleRow.getDouble(11), 0); + Assert.assertEquals(doubleValue, simpleRow.get(11)); + Assert.assertEquals(decimalValue, simpleRow.get(12)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(13)); + Assert.assertEquals(booleanValue, simpleRow.get(13)); + Assert.assertEquals(booleanValue, simpleRow.getBoolean(14)); + Assert.assertEquals(booleanValue, simpleRow.get(14)); + Assert.assertEquals(stringValue, simpleRow.getString(15)); + Assert.assertEquals(stringValue, simpleRow.get(15)); + Assert.assertEquals(binaryValue, simpleRow.get(16)); + Assert.assertEquals(timestampValue, simpleRow.get(17)); + Assert.assertEquals(true, simpleRow.isNullAt(18)); + Assert.assertEquals(null, simpleRow.get(18)); + } + + @Test + public void constructComplexRow() { + // Simple array + List simpleStringArray = Arrays.asList( + stringValue + " (1)", stringValue + " (2)", stringValue + "(3)"); + + // Simple map + Map simpleMap = new HashMap(); + simpleMap.put(stringValue + " (1)", longValue); + simpleMap.put(stringValue + " (2)", longValue - 1); + simpleMap.put(stringValue + " (3)", longValue - 2); + + // Simple struct + Row simpleStruct = Row.create( + doubleValue, stringValue, timestampValue, null); + + // Complex array + List> arrayOfMaps = Arrays.asList(simpleMap); + List arrayOfRows = Arrays.asList(simpleStruct); + + // Complex map + Map, Row> complexMap = new HashMap, Row>(); + complexMap.put(arrayOfRows, simpleStruct); + + // Complex struct + Row complexStruct = Row.create( + simpleStringArray, + simpleMap, + simpleStruct, + arrayOfMaps, + arrayOfRows, + complexMap, + null); + Assert.assertEquals(simpleStringArray, complexStruct.get(0)); + Assert.assertEquals(simpleMap, complexStruct.get(1)); + Assert.assertEquals(simpleStruct, complexStruct.get(2)); + Assert.assertEquals(arrayOfMaps, complexStruct.get(3)); + Assert.assertEquals(arrayOfRows, complexStruct.get(4)); + Assert.assertEquals(complexMap, complexStruct.get(5)); + Assert.assertEquals(null, complexStruct.get(6)); + + // A very complex row + Row complexRow = Row.create(arrayOfMaps, arrayOfRows, complexMap, complexStruct); + Assert.assertEquals(arrayOfMaps, complexRow.get(0)); + Assert.assertEquals(arrayOfRows, complexRow.get(1)); + Assert.assertEquals(complexMap, complexRow.get(2)); + Assert.assertEquals(complexStruct, complexRow.get(3)); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java new file mode 100644 index 0000000000000..4fbab15931fd3 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.util.List; +import java.util.ArrayList; + +import org.junit.Assert; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; + +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.api.java.types.DataType; +import org.apache.spark.sql.api.java.types.StructField; +import org.apache.spark.sql.test.TestSQLContext; +import org.junit.rules.ExpectedException; + +public class JavaSideDataTypeConversionSuite { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + + public void checkDataType(DataType javaDataType) { + org.apache.spark.sql.catalyst.types.DataType scalaDataType = + javaSqlCtx.asScalaDataType(javaDataType); + DataType actual = javaSqlCtx.asJavaDataType(scalaDataType); + Assert.assertEquals(javaDataType, actual); + } + + @Before + public void setUp() { + javaCtx = new JavaSparkContext(TestSQLContext.sparkContext()); + javaSqlCtx = new JavaSQLContext(javaCtx); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + } + + @Test + public void createDataTypes() { + // Simple DataTypes. + checkDataType(DataType.StringType); + checkDataType(DataType.BinaryType); + checkDataType(DataType.BooleanType); + checkDataType(DataType.TimestampType); + checkDataType(DataType.DecimalType); + checkDataType(DataType.DoubleType); + checkDataType(DataType.FloatType); + checkDataType(DataType.ByteType); + checkDataType(DataType.IntegerType); + checkDataType(DataType.LongType); + checkDataType(DataType.ShortType); + + // Simple ArrayType. + DataType simpleJavaArrayType = DataType.createArrayType(DataType.StringType, true); + checkDataType(simpleJavaArrayType); + + // Simple MapType. + DataType simpleJavaMapType = DataType.createMapType(DataType.StringType, DataType.LongType); + checkDataType(simpleJavaMapType); + + // Simple StructType. + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + simpleFields.add(DataType.createStructField("d", DataType.BinaryType, false)); + DataType simpleJavaStructType = DataType.createStructType(simpleFields); + checkDataType(simpleJavaStructType); + + // Complex StructType. + List complexFields = new ArrayList(); + complexFields.add(DataType.createStructField("simpleArray", simpleJavaArrayType, true)); + complexFields.add(DataType.createStructField("simpleMap", simpleJavaMapType, true)); + complexFields.add(DataType.createStructField("simpleStruct", simpleJavaStructType, true)); + complexFields.add(DataType.createStructField("boolean", DataType.BooleanType, false)); + DataType complexJavaStructType = DataType.createStructType(complexFields); + checkDataType(complexJavaStructType); + + // Complex ArrayType. 
+ DataType complexJavaArrayType = DataType.createArrayType(complexJavaStructType, true); + checkDataType(complexJavaArrayType); + + // Complex MapType. + DataType complexJavaMapType = + DataType.createMapType(complexJavaStructType, complexJavaArrayType); + checkDataType(complexJavaMapType); + } + + @Test + public void illegalArgument() { + // ArrayType + try { + DataType.createArrayType(null, true); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + + // MapType + try { + DataType.createMapType(null, DataType.StringType); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createMapType(DataType.StringType, null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createMapType(null, null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + + // StructField + try { + DataType.createStructField(null, DataType.StringType, true); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createStructField("name", null, true); + } catch (IllegalArgumentException expectedException) { + } + try { + DataType.createStructField(null, null, true); + } catch (IllegalArgumentException expectedException) { + } + + // StructType + try { + DataType.createStructType(null); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("b", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + simpleFields.add(null); + DataType.createStructType(simpleFields); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + try { + List simpleFields = new ArrayList(); + simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); + simpleFields.add(DataType.createStructField("a", DataType.BooleanType, true)); + simpleFields.add(DataType.createStructField("c", DataType.LongType, true)); + DataType.createStructType(simpleFields); + Assert.fail(); + } catch (IllegalArgumentException expectedException) { + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala new file mode 100644 index 0000000000000..24dbe01ab2e75 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.api.java + +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql._ +import org.apache.spark.sql.test.TestSQLContext +import org.scalatest.FunSuite + +class ScalaSideDataTypeConversionSuite extends FunSuite { + val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) + val javaSqlCtx = new JavaSQLContext(javaCtx) + + def checkDataType(scalaDataType: DataType) { + val javaDataType = javaSqlCtx.asJavaDataType(scalaDataType) + val actual = javaSqlCtx.asScalaDataType(javaDataType) + assert(scalaDataType === actual, s"Converted data type ${actual} " + + s"does not equal the expected data type ${scalaDataType}") + } + + test("convert data types") { + // Simple DataTypes. + checkDataType(StringType) + checkDataType(BinaryType) + checkDataType(BooleanType) + checkDataType(TimestampType) + checkDataType(DecimalType) + checkDataType(DoubleType) + checkDataType(FloatType) + checkDataType(ByteType) + checkDataType(IntegerType) + checkDataType(LongType) + checkDataType(ShortType) + + // Simple ArrayType. + val simpleScalaArrayType = ArrayType(StringType, true) + checkDataType(simpleScalaArrayType) + + // Simple MapType. + val simpleScalaMapType = MapType(StringType, LongType) + checkDataType(simpleScalaMapType) + + // Simple StructType. + val simpleScalaStructType = StructType( + StructField("a", DecimalType, false) :: + StructField("b", BooleanType, true) :: + StructField("c", LongType, true) :: + StructField("d", BinaryType, false) :: Nil) + checkDataType(simpleScalaStructType) + + // Complex StructType. + val complexScalaStructType = StructType( + StructField("simpleArray", simpleScalaArrayType, true) :: + StructField("simpleMap", simpleScalaMapType, true) :: + StructField("simpleStruct", simpleScalaStructType, true) :: + StructField("boolean", BooleanType, false) :: Nil) + checkDataType(complexScalaStructType) + + // Complex ArrayType. + val complexScalaArrayType = ArrayType(complexScalaStructType, true) + checkDataType(complexScalaArrayType) + + // Complex MapType. + val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType) + checkDataType(complexScalaMapType) + } +} From b9f30711f42334f0c0f1790d7baeb1fd7979014a Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 24 Jul 2014 15:09:22 -0700 Subject: [PATCH 21/34] Java API for applySchema. 
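End-to-end, the new `applySchema` on `JavaSQLContext` lets Java users pair a `JavaRDD<Row>`
with a programmatically built schema and then query the result with SQL, mirroring what
`JavaApplySchemaSuite` exercises. A hypothetical usage sketch (class, table, and field names
are illustrative, and it assumes a local master for brevity):

    import java.util.Arrays;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.sql.api.java.JavaSQLContext;
    import org.apache.spark.sql.api.java.JavaSchemaRDD;
    import org.apache.spark.sql.api.java.Row;
    import org.apache.spark.sql.api.java.types.DataType;
    import org.apache.spark.sql.api.java.types.StructType;

    public class ApplySchemaExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "ApplySchemaExample");
        JavaSQLContext sqlCtx = new JavaSQLContext(sc);

        // Build the schema programmatically.
        StructType schema = DataType.createStructType(Arrays.asList(
          DataType.createStructField("name", DataType.StringType, false),
          DataType.createStructField("age", DataType.IntegerType, false)));

        // Parse raw records into Rows whose structure matches the schema.
        JavaRDD<String> people = sc.parallelize(Arrays.asList("Michael, 29", "Andy, 30"));
        JavaRDD<Row> rowRDD = people.map(new Function<String, Row>() {
          public Row call(String record) {
            String[] parts = record.split(",");
            return Row.create(parts[0], Integer.parseInt(parts[1].trim()));
          }
        });

        // Apply the schema, register the result as a table, and query it with SQL.
        JavaSchemaRDD peopleSchemaRDD = sqlCtx.applySchema(rowRDD, schema);
        peopleSchemaRDD.registerAsTable("people");
        System.out.println(sqlCtx.sql("SELECT name FROM people WHERE age >= 30").collect());
      }
    }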
--- .../spark/sql/api/java/types/DataType.java | 17 +- .../spark/sql/api/java/types/StructType.java | 6 +- .../org/apache/spark/sql/SQLContext.scala | 95 ++++++++++ .../org/apache/spark/sql/SchemaRDD.scala | 5 + .../org/apache/spark/sql/SchemaRDDLike.scala | 8 +- .../spark/sql/api/java/JavaSQLContext.scala | 60 ++++-- .../spark/sql/api/java/JavaSchemaRDD.scala | 5 + .../sql/api/java/JavaApplySchemaSuite.java | 172 ++++++++++++++++++ .../spark/sql/api/java/JavaRowSuite.java | 4 +- .../java/JavaSideDataTypeConversionSuite.java | 9 +- .../ScalaSideDataTypeConversionSuite.scala | 4 +- 11 files changed, 348 insertions(+), 37 deletions(-) create mode 100644 sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java index 7399bf9876f50..5f0fddcc94b67 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java @@ -134,11 +134,20 @@ public static StructField createStructField(String name, DataType dataType, bool } /** - * Creates a StructType with the given StructFields ({@code fields}). + * Creates a StructType with the given list of StructFields ({@code fields}). * @param fields * @return */ public static StructType createStructType(List fields) { + return createStructType(fields.toArray(new StructField[0])); + } + + /** + * Creates a StructType with the given StructField array ({@code fields}). + * @param fields + * @return + */ + public static StructType createStructType(StructField[] fields) { if (fields == null) { throw new IllegalArgumentException("fields should not be null."); } @@ -151,11 +160,11 @@ public static StructType createStructType(List fields) { distinctNames.add(field.getName()); } - if (distinctNames.size() != fields.size()) { - throw new IllegalArgumentException( - "fields should have distinct names."); + if (distinctNames.size() != fields.length) { + throw new IllegalArgumentException("fields should have distinct names."); } return new StructType(fields); } + } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java index cd397d42135f1..17142ff672822 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java @@ -22,13 +22,13 @@ /** * The data type representing Rows. - * A StructType object comprises a List of StructFields. + * A StructType object comprises an array of StructFields. 
*/ public class StructType extends DataType { private StructField[] fields; - protected StructType(List fields) { - this.fields = fields.toArray(new StructField[0]); + protected StructType(StructField[] fields) { + this.fields = fields; } public StructField[] getFields() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8433b215cb3d5..11880a80443f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag +import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions @@ -420,4 +422,97 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) } + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + protected def asJavaStructField(scalaStructField: StructField): JStructField = { + org.apache.spark.sql.api.java.types.DataType.createStructField( + scalaStructField.name, + asJavaDataType(scalaStructField.dataType), + scalaStructField.nullable) + } + + /** + * Returns the equivalent DataType in Java for the given DataType in Scala. + */ + protected[sql] def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + case StringType => + org.apache.spark.sql.api.java.types.DataType.StringType + case BinaryType => + org.apache.spark.sql.api.java.types.DataType.BinaryType + case BooleanType => + org.apache.spark.sql.api.java.types.DataType.BooleanType + case TimestampType => + org.apache.spark.sql.api.java.types.DataType.TimestampType + case DecimalType => + org.apache.spark.sql.api.java.types.DataType.DecimalType + case DoubleType => + org.apache.spark.sql.api.java.types.DataType.DoubleType + case FloatType => + org.apache.spark.sql.api.java.types.DataType.FloatType + case ByteType => + org.apache.spark.sql.api.java.types.DataType.ByteType + case IntegerType => + org.apache.spark.sql.api.java.types.DataType.IntegerType + case LongType => + org.apache.spark.sql.api.java.types.DataType.LongType + case ShortType => + org.apache.spark.sql.api.java.types.DataType.ShortType + + case arrayType: ArrayType => + org.apache.spark.sql.api.java.types.DataType.createArrayType( + asJavaDataType(arrayType.elementType), arrayType.containsNull) + case mapType: MapType => + org.apache.spark.sql.api.java.types.DataType.createMapType( + asJavaDataType(mapType.keyType), asJavaDataType(mapType.valueType)) + case structType: StructType => + org.apache.spark.sql.api.java.types.DataType.createStructType( + structType.fields.map(asJavaStructField).asJava) + } + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. 
+ */ + protected def asScalaStructField(javaStructField: JStructField): StructField = { + StructField( + javaStructField.getName, + asScalaDataType(javaStructField.getDataType), + javaStructField.isNullable) + } + + /** + * Returns the equivalent DataType in Scala for the given DataType in Java. + */ + protected[sql] def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + case stringType: org.apache.spark.sql.api.java.types.StringType => + StringType + case binaryType: org.apache.spark.sql.api.java.types.BinaryType => + BinaryType + case booleanType: org.apache.spark.sql.api.java.types.BooleanType => + BooleanType + case timestampType: org.apache.spark.sql.api.java.types.TimestampType => + TimestampType + case decimalType: org.apache.spark.sql.api.java.types.DecimalType => + DecimalType + case doubleType: org.apache.spark.sql.api.java.types.DoubleType => + DoubleType + case floatType: org.apache.spark.sql.api.java.types.FloatType => + FloatType + case byteType: org.apache.spark.sql.api.java.types.ByteType => + ByteType + case integerType: org.apache.spark.sql.api.java.types.IntegerType => + IntegerType + case longType: org.apache.spark.sql.api.java.types.LongType => + LongType + case shortType: org.apache.spark.sql.api.java.types.ShortType => + ShortType + + case arrayType: org.apache.spark.sql.api.java.types.ArrayType => + ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) + case mapType: org.apache.spark.sql.api.java.types.MapType => + MapType(asScalaDataType(mapType.getKeyType), asScalaDataType(mapType.getValueType)) + case structType: org.apache.spark.sql.api.java.types.StructType => + StructType(structType.getFields.map(asScalaStructField)) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 1c8c5276761d7..38c8c4bf5188b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -119,6 +119,11 @@ class SchemaRDD( override protected def getDependencies: Seq[Dependency[_]] = List(new OneToOneDependency(queryExecution.toRdd)) + /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). + * + * @group schema + */ + def schema: StructType = queryExecution.analyzed.schema // ======================================================================= // Query DSL diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala index cf9baa34eb2bc..af1318c557242 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala @@ -123,17 +123,11 @@ private[sql] trait SchemaRDDLike { def saveAsTable(tableName: String): Unit = sqlContext.executePlan(InsertIntoCreatedTable(None, tableName, logicalPlan)).toRdd - /** Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - def schema: StructType = queryExecution.analyzed.schema - /** Returns the schema as a string in the tree format. * * @group schema */ - def schemaString: String = schema.treeString + def schemaString: String = baseSchemaRDD.schema.treeString /** Prints out the schema. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 0f925cca07e25..2a50cee1d18a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -19,14 +19,17 @@ package org.apache.spark.sql.api.java import java.beans.Introspector +import scala.collection.JavaConverters._ + import org.apache.hadoop.conf.Configuration -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructType => JStructType} +import org.apache.spark.sql.api.java.types.{StructField => JStructField} import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} -import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} import org.apache.spark.util.Utils @@ -95,6 +98,20 @@ class JavaSQLContext(val sqlContext: SQLContext) { new JavaSchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) } + /** + * :: DeveloperApi :: + * Creates a JavaSchemaRDD from an RDD containing Rows by applying a schema to this RDD. + * It is important to make sure that the structure of every Row of the provided RDD matches the + * provided schema. Otherwise, there will be runtime exception. + */ + @DeveloperApi + def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = { + val scalaRowRDD = rowRDD.rdd.map(r => r.row) + val scalaSchema = sqlContext.asScalaDataType(schema).asInstanceOf[StructType] + val logicalPlan = SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD)) + new JavaSchemaRDD(sqlContext, logicalPlan) + } + /** * Loads a parquet file, returning the result as a [[JavaSchemaRDD]]. */ @@ -104,26 +121,45 @@ class JavaSQLContext(val sqlContext: SQLContext) { ParquetRelation(path, Some(sqlContext.sparkContext.hadoopConfiguration))) /** - * Loads a JSON file (one object per line), returning the result as a [[JavaSchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. - * - * @group userf */ def jsonFile(path: String): JavaSchemaRDD = jsonRDD(sqlContext.sparkContext.textFile(path)) + /** + * :: Experimental :: + * Loads a JSON file (one object per line) and applies the given schema, + * returning the result as a JavaSchemaRDD. + */ + @Experimental + def jsonFile(path: String, schema: JStructType): JavaSchemaRDD = + jsonRDD(sqlContext.sparkContext.textFile(path), schema) + /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[JavaSchemaRDD]]. + * [JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. 
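(An illustrative sketch of the applySchema contract described above, written in Scala against the Java API for compactness; `javaCtx: JavaSparkContext` and `javaSqlCtx: JavaSQLContext` are assumed to exist. The JavaApplySchemaSuite added later in this patch exercises the same flow from Java.)

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.api.java.Row
    import org.apache.spark.sql.api.java.types.DataType

    // Every Row must line up positionally with the schema, or evaluation fails at runtime.
    val rows = javaCtx.parallelize(Seq(
      Row.create("Michael", Int.box(29)),
      Row.create("Yin", Int.box(28))).asJava)
    val schema = DataType.createStructType(Seq(
      DataType.createStructField("name", DataType.StringType, false),
      DataType.createStructField("age", DataType.IntegerType, false)).asJava)
    val people = javaSqlCtx.applySchema(rows, schema)
    people.registerAsTable("people")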
- * - * @group userf */ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { - val schema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json, 1.0)) - val logicalPlan = - sqlContext.makeCustomRDDScan[String](json, schema, JsonRDD.jsonStringToRow(schema, _)) + val appliedScalaSchema = JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0)) + val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD)) + new JavaSchemaRDD(sqlContext, logicalPlan) + } + /** + * :: Experimental :: + * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, + * returning the result as a JavaSchemaRDD. + */ + @Experimental + def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = { + val appliedScalaSchema = + Option(sqlContext.asScalaDataType(schema)).getOrElse( + JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType] + val scalaRowRDD = JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) + val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD)) new JavaSchemaRDD(sqlContext, logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index 8fbf13b8b0150..e560648f1a46d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -22,6 +22,7 @@ import java.util.{List => JList} import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.sql.api.java.types.StructType import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.rdd.RDD @@ -53,6 +54,10 @@ class JavaSchemaRDD( override def toString: String = baseSchemaRDD.toString + /** Returns the schema of this JavaSchemaRDD (represented by a StructType). */ + def schema: StructType = + sqlContext.asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType] + // ======================================================================= // Base RDD functions that do NOT change schema // ======================================================================= diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java new file mode 100644 index 0000000000000..ead612e6d181e --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import java.io.File; +import java.io.Serializable; +import java.math.BigDecimal; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import com.google.common.io.Files; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.sql.api.java.types.DataType; +import org.apache.spark.sql.api.java.types.StructField; +import org.apache.spark.sql.api.java.types.StructType; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.test.TestSQLContext; + +// The test suite itself is Serializable so that anonymous Function implementations can be +// serialized, as an alternative to converting these anonymous classes to static inner classes; +// see http://stackoverflow.com/questions/758570/. +public class JavaApplySchemaSuite implements Serializable { + private transient JavaSparkContext javaCtx; + private transient JavaSQLContext javaSqlCtx; + private transient File tempDir; + + @Before + public void setUp() { + javaCtx = new JavaSparkContext(TestSQLContext.sparkContext()); + javaSqlCtx = new JavaSQLContext(javaCtx); + tempDir = Files.createTempDir(); + tempDir.deleteOnExit(); + } + + @After + public void tearDown() { + javaCtx.stop(); + javaCtx = null; + javaSqlCtx = null; + } + + public static class Person implements Serializable { + private String name; + private int age; + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public int getAge() { + return age; + } + + public void setAge(int age) { + this.age = age; + } + } + + @Test + public void applySchema() { + List personList = new ArrayList(2); + Person person1 = new Person(); + person1.setName("Michael"); + person1.setAge(29); + personList.add(person1); + Person person2 = new Person(); + person2.setName("Yin"); + person2.setAge(28); + personList.add(person2); + + JavaRDD rowRDD = javaCtx.parallelize(personList).map( + new Function() { + public Row call(Person person) throws Exception { + return Row.create(person.getName(), person.getAge()); + } + }); + + List fields = new ArrayList(2); + fields.add(DataType.createStructField("name", DataType.StringType, false)); + fields.add(DataType.createStructField("age", DataType.IntegerType, false)); + StructType schema = DataType.createStructType(fields); + + JavaSchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD, schema); + schemaRDD.registerAsTable("people"); + List actual = javaSqlCtx.sql("SELECT * FROM people").collect(); + + List expected = new ArrayList(2); + expected.add(Row.create("Michael", 29)); + expected.add(Row.create("Yin", 28)); + + Assert.assertEquals(expected, actual); + } + + @Test + public void applySchemaToJSON() { + JavaRDD jsonRDD = javaCtx.parallelize(Arrays.asList( + "{\"string\":\"this is a simple string.\", \"integer\":10, \"long\":21474836470, " + + "\"bigInteger\":92233720368547758070, \"double\":1.7976931348623157E308, " + + "\"boolean\":true, \"null\":null}", + "{\"string\":\"this is another simple string.\", \"integer\":11, \"long\":21474836469, " + + "\"bigInteger\":92233720368547758069, \"double\":1.7976931348623157E305, " + + "\"boolean\":false, \"null\":null}")); + List fields = new ArrayList(7); + 
fields.add(DataType.createStructField("bigInteger", DataType.DecimalType, true)); + fields.add(DataType.createStructField("boolean", DataType.BooleanType, true)); + fields.add(DataType.createStructField("double", DataType.DoubleType, true)); + fields.add(DataType.createStructField("integer", DataType.IntegerType, true)); + fields.add(DataType.createStructField("long", DataType.LongType, true)); + fields.add(DataType.createStructField("null", DataType.StringType, true)); + fields.add(DataType.createStructField("string", DataType.StringType, true)); + StructType expectedSchema = DataType.createStructType(fields); + List expectedResult = new ArrayList(1); + expectedResult.add( + Row.create( + new BigDecimal("92233720368547758070"), + true, + 1.7976931348623157E308, + 10, + 21474836470L, + null, + "this is a simple string.")); + expectedResult.add( + Row.create( + new BigDecimal("92233720368547758069"), + false, + 1.7976931348623157E305, + 11, + 21474836469L, + null, + "this is another simple string.")); + + JavaSchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD); + StructType actualSchema1 = schemaRDD1.schema(); + Assert.assertEquals(expectedSchema, actualSchema1); + schemaRDD1.registerAsTable("jsonTable1"); + List actual1 = javaSqlCtx.sql("select * from jsonTable1").collect(); + Assert.assertEquals(expectedResult, actual1); + + JavaSchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema); + StructType actualSchema2 = schemaRDD2.schema(); + Assert.assertEquals(expectedSchema, actualSchema2); + schemaRDD1.registerAsTable("jsonTable2"); + List actual2 = javaSqlCtx.sql("select * from jsonTable2").collect(); + Assert.assertEquals(expectedResult, actual2); + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java index 52d07b5425cc3..7391b2ae832eb 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java @@ -25,7 +25,7 @@ import java.util.Map; import org.junit.Assert; -import org.junit.Before; +import org.junit.BeforeClass; import org.junit.Test; public class JavaRowSuite { @@ -41,7 +41,7 @@ public class JavaRowSuite { private byte[] binaryValue; private Timestamp timestampValue; - @Before + @BeforeClass public void setUp() { byteValue = (byte)127; shortValue = (short)32767; diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java index 4fbab15931fd3..eccec65c667d2 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -38,8 +38,8 @@ public class JavaSideDataTypeConversionSuite { public void checkDataType(DataType javaDataType) { org.apache.spark.sql.catalyst.types.DataType scalaDataType = - javaSqlCtx.asScalaDataType(javaDataType); - DataType actual = javaSqlCtx.asJavaDataType(scalaDataType); + javaSqlCtx.sqlContext().asScalaDataType(javaDataType); + DataType actual = javaSqlCtx.sqlContext().asJavaDataType(scalaDataType); Assert.assertEquals(javaDataType, actual); } @@ -147,11 +147,6 @@ public void illegalArgument() { } // StructType - try { - DataType.createStructType(null); - Assert.fail(); - } catch (IllegalArgumentException expectedException) { - } try { List simpleFields = new ArrayList(); 
simpleFields.add(DataType.createStructField("a", DataType.DecimalType, false)); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index 24dbe01ab2e75..56102c2d5b8fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -27,8 +27,8 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { val javaSqlCtx = new JavaSQLContext(javaCtx) def checkDataType(scalaDataType: DataType) { - val javaDataType = javaSqlCtx.asJavaDataType(scalaDataType) - val actual = javaSqlCtx.asScalaDataType(javaDataType) + val javaDataType = javaSqlCtx.sqlContext.asJavaDataType(scalaDataType) + val actual = javaSqlCtx.sqlContext.asScalaDataType(javaDataType) assert(scalaDataType === actual, s"Converted data type ${actual} " + s"does not equal the expected data type ${scalaDataType}") } From d48fc7b6bafa85b2df02ce491de88e4de37e886c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 24 Jul 2014 20:56:52 -0700 Subject: [PATCH 22/34] Minor updates. --- .../spark/sql/api/java/JavaApplySchemaSuite.java | 10 ++-------- .../org/apache/spark/sql/api/java/JavaRowSuite.java | 4 ++-- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index ead612e6d181e..8ee4591105010 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -17,14 +17,12 @@ package org.apache.spark.sql.api.java; -import java.io.File; import java.io.Serializable; import java.math.BigDecimal; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import com.google.common.io.Files; import org.junit.After; import org.junit.Assert; import org.junit.Before; @@ -36,7 +34,6 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.test.TestSQLContext; // The test suite itself is Serializable so that anonymous Function implementations can be // serialized, as an alternative to converting these anonymous classes to static inner classes; @@ -44,14 +41,11 @@ public class JavaApplySchemaSuite implements Serializable { private transient JavaSparkContext javaCtx; private transient JavaSQLContext javaSqlCtx; - private transient File tempDir; @Before public void setUp() { - javaCtx = new JavaSparkContext(TestSQLContext.sparkContext()); + javaCtx = new JavaSparkContext("local", "JavaApplySchemaSuite"); javaSqlCtx = new JavaSQLContext(javaCtx); - tempDir = Files.createTempDir(); - tempDir.deleteOnExit(); } @After @@ -135,7 +129,7 @@ public void applySchemaToJSON() { fields.add(DataType.createStructField("null", DataType.StringType, true)); fields.add(DataType.createStructField("string", DataType.StringType, true)); StructType expectedSchema = DataType.createStructType(fields); - List expectedResult = new ArrayList(1); + List expectedResult = new ArrayList(2); expectedResult.add( Row.create( new BigDecimal("92233720368547758070"), diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java 
b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java index 7391b2ae832eb..52d07b5425cc3 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java @@ -25,7 +25,7 @@ import java.util.Map; import org.junit.Assert; -import org.junit.BeforeClass; +import org.junit.Before; import org.junit.Test; public class JavaRowSuite { @@ -41,7 +41,7 @@ public class JavaRowSuite { private byte[] binaryValue; private Timestamp timestampValue; - @BeforeClass + @Before public void setUp() { byteValue = (byte)127; shortValue = (short)32767; From 246da964be0b1cded3163dcb6ae9e230297605c5 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 24 Jul 2014 23:02:59 -0700 Subject: [PATCH 23/34] Add java data type APIs to javadoc index. --- project/SparkBuild.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5461d25d72d7e..52fd61d2234b7 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -300,7 +300,7 @@ object Unidoc { "mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration", "mllib.tree.impurity", "mllib.tree.model", "mllib.util" ), - "-group", "Spark SQL", packageList("sql.api.java", "sql.hive.api.java"), + "-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"), "-noqualifier", "java.lang" ) ) From 1d9339576f3a261ad5bda14f084cc55b321f4565 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 27 Jul 2014 17:06:24 -0700 Subject: [PATCH 24/34] Python APIs. --- python/pyspark/sql.py | 510 +++++++++++++++++- .../org/apache/spark/sql/SQLContext.scala | 22 + 2 files changed, 518 insertions(+), 14 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index cb83e89176823..45ffd0756125d 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -20,8 +20,412 @@ from py4j.protocol import Py4JError -__all__ = ["SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] +__all__ = [ + "StringType", "BinaryType", "BooleanType", "DecimalType", "DoubleType", + "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", + "ArrayType", "MapType", "StructField", "StructType", + "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] +class PrimitiveTypeSingleton(type): + _instances = {} + def __call__(cls): + if cls not in cls._instances: + cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + return cls._instances[cls] + +class StringType(object): + """Spark SQL StringType + + The data type representing string values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "StringType" + +class BinaryType(object): + """Spark SQL BinaryType + + The data type representing bytes values and bytearray values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "BinaryType" + +class BooleanType(object): + """Spark SQL BooleanType + + The data type representing bool values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "BooleanType" + +class TimestampType(object): + """Spark SQL TimestampType + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "TimestampType" + +class DecimalType(object): + """Spark SQL DecimalType + + The data type representing decimal.Decimal values. 
+ + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "DecimalType" + +class DoubleType(object): + """Spark SQL DoubleType + + The data type representing float values. Because a float value + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "DoubleType" + +class FloatType(object): + """Spark SQL FloatType + + For PySpark, please use L{DoubleType} instead of using L{FloatType}. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "FloatType" + +class ByteType(object): + """Spark SQL ByteType + + For PySpark, please use L{IntegerType} instead of using L{ByteType}. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "ByteType" + +class IntegerType(object): + """Spark SQL IntegerType + + The data type representing int values. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "IntegerType" + +class LongType(object): + """Spark SQL LongType + + The data type representing long values. If the any value is beyond the range of + [-9223372036854775808, 9223372036854775807], please use DecimalType. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "LongType" + +class ShortType(object): + """Spark SQL ShortType + + For PySpark, please use L{IntegerType} instead of using L{ShortType}. + + """ + __metaclass__ = PrimitiveTypeSingleton + + def _get_scala_type_string(self): + return "ShortType" + +class ArrayType(object): + """Spark SQL ArrayType + + The data type representing list values. + + """ + def __init__(self, elementType, containsNull): + """ + Create an ArrayType + :param elementType: the data type of elements. + :param containsNull: indicates whether the list contains null values. + :return: + >>> ArrayType(StringType, True) == ArrayType(StringType, False) + False + >>> ArrayType(StringType, True) == ArrayType(StringType, True) + True + """ + self.elementType = elementType + self.containsNull = containsNull + + def _get_scala_type_string(self): + return "ArrayType(" + self.elementType._get_scala_type_string() + "," + \ + str(self.containsNull).lower() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and \ + self.elementType == other.elementType and \ + self.containsNull == other.containsNull) + + def __ne__(self, other): + return not self.__eq__(other) + + +class MapType(object): + """Spark SQL MapType + + The data type representing dict values. + + """ + def __init__(self, keyType, valueType): + """ + Create a MapType + :param keyType: the data type of keys. + :param valueType: the data type of values. + :return: + >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType) + True + >>> MapType(StringType, IntegerType) == MapType(StringType, FloatType) + False + """ + self.keyType = keyType + self.valueType = valueType + + def _get_scala_type_string(self): + return "MapType(" + self.keyType._get_scala_type_string() + "," + \ + self.valueType._get_scala_type_string() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and \ + self.keyType == other.keyType and \ + self.valueType == other.valueType) + + def __ne__(self, other): + return not self.__eq__(other) + +class StructField(object): + """Spark SQL StructField + + Represents a field in a StructType. 
+ + """ + def __init__(self, name, dataType, nullable): + """ + Create a StructField + :param name: the name of this field. + :param dataType: the data type of this field. + :param nullable: indicates whether values of this field can be null. + :return: + >>> StructField("f1", StringType, True) == StructField("f1", StringType, True) + True + >>> StructField("f1", StringType, True) == StructField("f2", StringType, True) + False + """ + self.name = name + self.dataType = dataType + self.nullable = nullable + + def _get_scala_type_string(self): + return "StructField(" + self.name + "," + \ + self.dataType._get_scala_type_string() + "," + \ + str(self.nullable).lower() + ")" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and \ + self.name == other.name and \ + self.dataType == other.dataType and \ + self.nullable == other.nullable) + + def __ne__(self, other): + return not self.__eq__(other) + +class StructType(object): + """Spark SQL StructType + + The data type representing tuple values. + + """ + def __init__(self, fields): + """ + Create a StructType + :param fields: + :return: + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType([StructField("f1", StringType, True)]) + >>> struct2 = StructType([StructField("f1", StringType, True), + ... [StructField("f2", IntegerType, False)]]) + >>> struct1 == struct2 + False + """ + self.fields = fields + + def _get_scala_type_string(self): + return "StructType(List(" + \ + ",".join([field._get_scala_type_string() for field in self.fields]) + "))" + + def __eq__(self, other): + return (isinstance(other, self.__class__) and \ + self.fields == other.fields) + + def __ne__(self, other): + return not self.__eq__(other) + +def _parse_datatype_list(datatype_list_string): + index = 0 + datatype_list = [] + start = 0 + depth = 0 + while index < len(datatype_list_string): + if depth == 0 and datatype_list_string[index] == ",": + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + start = index + 1 + elif datatype_list_string[index] == "(": + depth += 1 + elif datatype_list_string[index] == ")": + depth -= 1 + + index += 1 + + # Handle the last data type + datatype_string = datatype_list_string[start:index].strip() + datatype_list.append(_parse_datatype_string(datatype_string)) + return datatype_list + +def _parse_datatype_string(datatype_string): + """Parses the given data type string. + :param datatype_string: + :return: + >>> def check_datatype(datatype): + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype._get_scala_type_string()) + ... python_datatype = _parse_datatype_string(scala_datatype.toString()) + ... return datatype == python_datatype + >>> check_datatype(StringType()) + True + >>> check_datatype(BinaryType()) + True + >>> check_datatype(BooleanType()) + True + >>> check_datatype(TimestampType()) + True + >>> check_datatype(DecimalType()) + True + >>> check_datatype(DoubleType()) + True + >>> check_datatype(FloatType()) + True + >>> check_datatype(ByteType()) + True + >>> check_datatype(IntegerType()) + True + >>> check_datatype(LongType()) + True + >>> check_datatype(ShortType()) + True + >>> # Simple ArrayType. + >>> simple_arraytype = ArrayType(StringType(), True) + >>> check_datatype(simple_arraytype) + True + >>> # Simple MapType. 
+ >>> simple_maptype = MapType(StringType(), LongType()) + >>> check_datatype(simple_maptype) + True + >>> # Simple StructType. + >>> simple_structtype = StructType([ + ... StructField("a", DecimalType(), False), + ... StructField("b", BooleanType(), True), + ... StructField("c", LongType(), True), + ... StructField("d", BinaryType(), False)]) + >>> check_datatype(simple_structtype) + True + >>> # Complex StructType. + >>> complex_structtype = StructType([ + ... StructField("simpleArray", simple_arraytype, True), + ... StructField("simpleMap", simple_maptype, True), + ... StructField("simpleStruct", simple_structtype, True), + ... StructField("boolean", BooleanType(), False)]) + >>> check_datatype(complex_structtype) + True + >>> # Complex ArrayType. + >>> complex_arraytype = ArrayType(complex_structtype, True) + >>> check_datatype(complex_arraytype) + True + >>> # Complex MapType. + >>> complex_maptype = MapType(complex_structtype, complex_arraytype) + >>> check_datatype(complex_maptype) + True + """ + left_bracket_index = datatype_string.find("(") + if left_bracket_index == -1: + # It is a primitive type. + left_bracket_index = len(datatype_string) + type_or_field = datatype_string[:left_bracket_index] + rest_part = datatype_string[left_bracket_index+1:len(datatype_string)-1].strip() + if type_or_field == "StringType": + return StringType() + elif type_or_field == "BinaryType": + return BinaryType() + elif type_or_field == "BooleanType": + return BooleanType() + elif type_or_field == "TimestampType": + return TimestampType() + elif type_or_field == "DecimalType": + return DecimalType() + elif type_or_field == "DoubleType": + return DoubleType() + elif type_or_field == "FloatType": + return FloatType() + elif type_or_field == "ByteType": + return ByteType() + elif type_or_field == "IntegerType": + return IntegerType() + elif type_or_field == "LongType": + return LongType() + elif type_or_field == "ShortType": + return ShortType() + elif type_or_field == "ArrayType": + last_comma_index = rest_part.rfind(",") + containsNull = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + containsNull = False + elementType = _parse_datatype_string(rest_part[:last_comma_index].strip()) + return ArrayType(elementType, containsNull) + elif type_or_field == "MapType": + keyType, valueType = _parse_datatype_list(rest_part.strip()) + return MapType(keyType, valueType) + elif type_or_field == "StructField": + first_comma_index = rest_part.find(",") + name = rest_part[:first_comma_index].strip() + last_comma_index = rest_part.rfind(",") + nullable = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + nullable = False + dataType = _parse_datatype_string( + rest_part[first_comma_index+1:last_comma_index].strip()) + return StructField(name, dataType, nullable) + elif type_or_field == "StructType": + # rest_part should be in the format like + # List(StructField(field1,IntegerType,false)). + field_list_string = rest_part[rest_part.find("(")+1:-1] + fields = _parse_datatype_list(field_list_string) + return StructType(fields) class SQLContext: """Main entry point for SparkSQL functionality. @@ -107,6 +511,24 @@ def inferSchema(self, rdd): srdd = self._ssql_ctx.inferSchema(jrdd.rdd()) return SchemaRDD(srdd, self) + def applySchema(self, rdd, schema): + """Applies the given schema to the given RDD of L{dict}s. + :param rdd: + :param schema: + :return: + >>> schema = StructType([StructField("field1", IntegerType(), False), + ... 
StructField("field2", StringType(), False)]) + >>> srdd = sqlCtx.applySchema(rdd, schema) + >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd2 = sqlCtx.sql("SELECT * from table1") + >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, + ... {"field1" : 3, "field2": "row3"}] + True + """ + jrdd = self._pythonToJavaMap(rdd._jrdd) + srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema._get_scala_type_string()) + return SchemaRDD(srdd, self) + def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -137,10 +559,11 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path): - """Loads a text file storing one JSON object per line, - returning the result as a L{SchemaRDD}. - It goes through the entire dataset once to determine the schema. + def jsonFile(self, path, schema = None): + """Loads a text file storing one JSON object per line as a L{SchemaRDD}. + + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it goes through the entire dataset once to determine the schema. >>> import tempfile, shutil >>> jsonFile = tempfile.mkdtemp() @@ -149,8 +572,8 @@ def jsonFile(self, path): >>> for json in jsonStrings: ... print>>ofn, json >>> ofn.close() - >>> srdd = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> srdd1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") >>> srdd2.collect() == [ @@ -158,16 +581,45 @@ def jsonFile(self, path): ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] True + >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) + >>> sqlCtx.registerRDDAsTable(srdd3, "table2") + >>> srdd4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") + >>> srdd4.collect() == [ + ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, + ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, + ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] + True + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + >>> srdd6.collect() == [ + ... {"f1": "row1", "f2": None, "f3": None}, + ... {"f1": None, "f2": [10, 11], "f3": 10}, + ... {"f1": "row3", "f2": [], "f3": None}] + True """ - jschema_rdd = self._ssql_ctx.jsonFile(path) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonFile(path) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema._get_scala_type_string()) + jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(jschema_rdd, self) - def jsonRDD(self, rdd): - """Loads an RDD storing one JSON object per string, returning the result as a L{SchemaRDD}. - It goes through the entire dataset once to determine the schema. 
+ def jsonRDD(self, rdd, schema = None): + """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. - >>> srdd = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + If the schema is provided, applies the given schema to this JSON dataset. + Otherwise, it goes through the entire dataset once to determine the schema. + + >>> srdd1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(srdd1, "table1") >>> srdd2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table1") >>> srdd2.collect() == [ @@ -175,6 +627,29 @@ def jsonRDD(self, rdd): ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] True + >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) + >>> sqlCtx.registerRDDAsTable(srdd3, "table2") + >>> srdd4 = sqlCtx.sql( + ... "SELECT field1 AS f1, field2 as f2, field3 as f3, field6 as f4 from table2") + >>> srdd4.collect() == [ + ... {"f1":1, "f2":"row1", "f3":{"field4":11, "field5": None}, "f4":None}, + ... {"f1":2, "f2":None, "f3":{"field4":22, "field5": [10, 11]}, "f4":[{"field7": "row2"}]}, + ... {"f1":None, "f2":"row3", "f3":{"field4":33, "field5": []}, "f4":None}] + True + >>> schema = StructType([ + ... StructField("field2", StringType(), True), + ... StructField("field3", + ... StructType([ + ... StructField("field5", ArrayType(IntegerType(), False), True)]), False)]) + >>> srdd5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(srdd5, "table3") + >>> srdd6 = sqlCtx.sql( + ... "SELECT field2 AS f1, field3.field5 as f2, field3.field5[0] as f3 from table3") + >>> srdd6.collect() == [ + ... {"f1": "row1", "f2": None, "f3": None}, + ... {"f1": None, "f2": [10, 11], "f3": 10}, + ... {"f1": "row3", "f2": [], "f3": None}] + True """ def func(split, iterator): for x in iterator: @@ -184,7 +659,11 @@ def func(split, iterator): keyed = PipelinedRDD(rdd, func) keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + if schema is None: + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) + else: + scala_datatype = self._ssql_ctx.parseDataType(schema._get_scala_type_string()) + jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(jschema_rdd, self) def sql(self, sqlQuery): @@ -387,6 +866,9 @@ def saveAsTable(self, tableName): """Creates a new table with the contents of this SchemaRDD.""" self._jschema_rdd.saveAsTable(tableName) + def schema(self): + return _parse_datatype_string(self._jschema_rdd.schema().toString()) + def schemaString(self): """Returns the output schema in the tree format.""" return self._jschema_rdd.schemaString() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 11880a80443f3..e358f00f8d852 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -105,6 +105,28 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, logicalPlan) } + /** + * Parses the data type in our internal string representation. The data type string should + * have the same format as the one generate by `toString` in scala. 
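(For reference, the internal string representation parsed here is the one produced by the catalyst DataType's `toString`, and it is what the Python `_get_scala_type_string` helpers above emit. A minimal round-trip sketch, assuming the parser behaves as the grammar shown earlier in this series; `parseDataType` itself is private[sql], so the companion object it delegates to is invoked directly.)

    import org.apache.spark.sql.catalyst.types._

    val repr = "StructType(List(StructField(name,StringType,false),StructField(age,IntegerType,false)))"
    val parsed = DataType(repr)  // the same call parseDataType delegates to
    assert(parsed == StructType(Seq(
      StructField("name", StringType, false),
      StructField("age", IntegerType, false))))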
+ */ + private[sql] def parseDataType(dataTypeString: String): DataType = { + val parser = org.apache.spark.sql.catalyst.types.DataType + parser(dataTypeString) + } + + /** + * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. + */ + private[sql] def applySchema(rdd: RDD[Map[String, _]], schemaString: String): SchemaRDD = { + val schema = parseDataType(schemaString).asInstanceOf[StructType] + val rowRdd = rdd.mapPartitions { iter => + iter.map { map => + new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row + } + } + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))) + } + /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. * From 3edb3aee7e86ecc5e1dd91ed054a3d336f5ecdfb Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sun, 27 Jul 2014 17:23:40 -0700 Subject: [PATCH 25/34] Python doc. --- python/pyspark/sql.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 45ffd0756125d..ac5d2ef15ab2f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -67,8 +67,7 @@ def _get_scala_type_string(self): return "BooleanType" class TimestampType(object): - """Spark SQL TimestampType - """ + """Spark SQL TimestampType""" __metaclass__ = PrimitiveTypeSingleton def _get_scala_type_string(self): @@ -159,11 +158,12 @@ class ArrayType(object): """ def __init__(self, elementType, containsNull): - """ - Create an ArrayType + """Creates an ArrayType + :param elementType: the data type of elements. :param containsNull: indicates whether the list contains null values. :return: + >>> ArrayType(StringType, True) == ArrayType(StringType, False) False >>> ArrayType(StringType, True) == ArrayType(StringType, True) @@ -192,11 +192,11 @@ class MapType(object): """ def __init__(self, keyType, valueType): - """ - Create a MapType + """Creates a MapType :param keyType: the data type of keys. :param valueType: the data type of values. :return: + >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType) True >>> MapType(StringType, IntegerType) == MapType(StringType, FloatType) @@ -224,12 +224,12 @@ class StructField(object): """ def __init__(self, name, dataType, nullable): - """ - Create a StructField + """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. :param nullable: indicates whether values of this field can be null. :return: + >>> StructField("f1", StringType, True) == StructField("f1", StringType, True) True >>> StructField("f1", StringType, True) == StructField("f2", StringType, True) @@ -260,10 +260,10 @@ class StructType(object): """ def __init__(self, fields): - """ - Create a StructType + """Creates a StructType :param fields: :return: + >>> struct1 = StructType([StructField("f1", StringType, True)]) >>> struct2 = StructType([StructField("f1", StringType, True)]) >>> struct1 == struct2 @@ -313,6 +313,7 @@ def _parse_datatype_string(datatype_string): """Parses the given data type string. :param datatype_string: :return: + >>> def check_datatype(datatype): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype._get_scala_type_string()) ... python_datatype = _parse_datatype_string(scala_datatype.toString()) @@ -516,6 +517,7 @@ def applySchema(self, rdd, schema): :param rdd: :param schema: :return: + >>> schema = StructType([StructField("field1", IntegerType(), False), ... 
StructField("field2", StringType(), False)]) >>> srdd = sqlCtx.applySchema(rdd, schema) @@ -867,6 +869,7 @@ def saveAsTable(self, tableName): self._jschema_rdd.saveAsTable(tableName) def schema(self): + """Returns the schema of this SchemaRDD (represented by a L{StructType}).""" return _parse_datatype_string(self._jschema_rdd.schema().toString()) def schemaString(self): From 1cb35fec81a393077aa8e2c47cf412b89b3a365c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 28 Jul 2014 15:21:49 -0700 Subject: [PATCH 26/34] Add "valueContainsNull" to MapType. --- python/pyspark/sql.py | 24 +++-- .../catalyst/expressions/complexTypes.scala | 2 +- .../sql/catalyst/expressions/generators.scala | 4 +- .../spark/sql/catalyst/types/dataTypes.scala | 15 ++- .../spark/sql/api/java/types/DataType.java | 25 ++++- .../spark/sql/api/java/types/MapType.java | 10 +- .../org/apache/spark/sql/SQLContext.scala | 94 ------------------- .../spark/sql/parquet/ParquetConverter.scala | 4 +- .../sql/parquet/ParquetTableSupport.scala | 2 +- .../spark/sql/parquet/ParquetTypes.scala | 6 +- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../spark/sql/hive/HiveInspectors.scala | 2 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 2 +- 13 files changed, 76 insertions(+), 118 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index ac5d2ef15ab2f..c0d086acbe955 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -191,28 +191,32 @@ class MapType(object): The data type representing dict values. """ - def __init__(self, keyType, valueType): + def __init__(self, keyType, valueType, valueContainsNull=True): """Creates a MapType :param keyType: the data type of keys. :param valueType: the data type of values. + :param valueContainsNull: indicates whether values contains null values. :return: - >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType) + >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True) True - >>> MapType(StringType, IntegerType) == MapType(StringType, FloatType) + >>> MapType(StringType, IntegerType, False) == MapType(StringType, FloatType) False """ self.keyType = keyType self.valueType = valueType + self.valueContainsNull = valueContainsNull def _get_scala_type_string(self): return "MapType(" + self.keyType._get_scala_type_string() + "," + \ - self.valueType._get_scala_type_string() + ")" + self.valueType._get_scala_type_string() + "," + \ + str(self.valueContainsNull).lower() + ")" def __eq__(self, other): return (isinstance(other, self.__class__) and \ self.keyType == other.keyType and \ - self.valueType == other.valueType) + self.valueType == other.valueType and \ + self.valueContainsNull == other.valueContainsNull) def __ne__(self, other): return not self.__eq__(other) @@ -369,7 +373,7 @@ def _parse_datatype_string(datatype_string): >>> check_datatype(complex_arraytype) True >>> # Complex MapType. 
- >>> complex_maptype = MapType(complex_structtype, complex_arraytype) + >>> complex_maptype = MapType(complex_structtype, complex_arraytype, False) >>> check_datatype(complex_maptype) True """ @@ -409,8 +413,12 @@ def _parse_datatype_string(datatype_string): elementType = _parse_datatype_string(rest_part[:last_comma_index].strip()) return ArrayType(elementType, containsNull) elif type_or_field == "MapType": - keyType, valueType = _parse_datatype_list(rest_part.strip()) - return MapType(keyType, valueType) + last_comma_index = rest_part.rfind(",") + valueContainsNull = True + if rest_part[last_comma_index+1:].strip().lower() == "false": + valueContainsNull = False + keyType, valueType = _parse_datatype_list(rest_part[:last_comma_index].strip()) + return MapType(keyType, valueType, valueContainsNull) elif type_or_field == "StructField": first_comma_index = rest_part.find(",") name = rest_part[:first_comma_index].strip() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index f13a6d5f98382..68e20e38a09cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -32,7 +32,7 @@ case class GetItem(child: Expression, ordinal: Expression) extends Expression { override def references = children.flatMap(_.references).toSet def dataType = child.dataType match { case ArrayType(dt, _) => dt - case MapType(_, vt) => vt + case MapType(_, vt, _) => vt } override lazy val resolved = childrenResolved && diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 0a8d4dd718329..422839dab770d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -85,7 +85,7 @@ case class Explode(attributeNames: Seq[String], child: Expression) private lazy val elementTypes = child.dataType match { case ArrayType(et, _) => et :: Nil - case MapType(kt,vt) => kt :: vt :: Nil + case MapType(kt,vt, _) => kt :: vt :: Nil } // TODO: Move this pattern into Generator. 
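(To make the new `valueContainsNull` flag concrete: a small sketch against the catalyst API as changed in dataTypes.scala below. The two-argument MapType factory added there keeps the old default of true.)

    import org.apache.spark.sql.catalyst.types._

    // Old call sites keep compiling: values are assumed nullable by default.
    val defaulted = MapType(StringType, IntegerType)
    assert(defaulted.valueContainsNull)

    // The flag can now be set explicitly when the value side is known to be non-null.
    val strict = MapType(StringType, IntegerType, valueContainsNull = false)

    // Pattern matches gain a third field, as in the GetItem and Explode changes above.
    strict match {
      case MapType(keyType, valueType, containsNull) =>
        println(s"map<${keyType.simpleString}, ${valueType.simpleString}>, " +
          s"valueContainsNull = $containsNull")
    }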
@@ -105,7 +105,7 @@ case class Explode(attributeNames: Seq[String], child: Expression) case ArrayType(_, _) => val inputArray = child.eval(input).asInstanceOf[Seq[Any]] if (inputArray == null) Nil else inputArray.map(v => new GenericRow(Array(v))) - case MapType(_, _) => + case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any,Any]] if (inputMap == null) Nil else inputMap.map { case (k,v) => new GenericRow(Array(k,v)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 93ffee327cefe..13eff8ec87bc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -53,8 +53,8 @@ object DataType extends RegexParsers { } protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType <~ ")" ^^ { - case t1 ~ _ ~ t2 => MapType(t1, t2) + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) } protected lazy val structField: Parser[StructField] = @@ -344,7 +344,16 @@ case class StructType(fields: Seq[StructField]) extends DataType { def simpleString: String = "struct" } -case class MapType(keyType: DataType, valueType: DataType) extends DataType { +object MapType { + /** + * Construct a [[MapType]] object with the given key type and value type. + * The `valueContainsNull` is true. + */ + def apply(keyType: DataType, valueType: DataType): MapType = + MapType(keyType: DataType, valueType: DataType, true) +} + +case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") builder.append(s"${prefix}-- value: ${valueType.simpleString}\n") diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java index 5f0fddcc94b67..c67287dea8ba6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java @@ -111,7 +111,30 @@ public static MapType createMapType(DataType keyType, DataType valueType) { throw new IllegalArgumentException("valueType should not be null."); } - return new MapType(keyType, valueType); + return new MapType(keyType, valueType, true); + } + + /** + * Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of + * values ({@code keyType}), and whether values contain any null value + * ({@code valueContainsNull}). 
+ * @param keyType + * @param valueType + * @param valueContainsNull + * @return + */ + public static MapType createMapType( + DataType keyType, + DataType valueType, + boolean valueContainsNull) { + if (keyType == null) { + throw new IllegalArgumentException("keyType should not be null."); + } + if (valueType == null) { + throw new IllegalArgumentException("valueType should not be null."); + } + + return new MapType(keyType, valueType, valueContainsNull); } /** diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java index d946d967a33fc..d116241a0e407 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java @@ -26,10 +26,12 @@ public class MapType extends DataType { private DataType keyType; private DataType valueType; + private boolean valueContainsNull; - protected MapType(DataType keyType, DataType valueType) { + protected MapType(DataType keyType, DataType valueType, boolean valueContainsNull) { this.keyType = keyType; this.valueType = valueType; + this.valueContainsNull = valueContainsNull; } public DataType getKeyType() { @@ -40,6 +42,10 @@ public DataType getValueType() { return valueType; } + public boolean isValueContainsNull() { + return valueContainsNull; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -47,6 +53,7 @@ public boolean equals(Object o) { MapType mapType = (MapType) o; + if (valueContainsNull != mapType.valueContainsNull) return false; if (!keyType.equals(mapType.keyType)) return false; if (!valueType.equals(mapType.valueType)) return false; @@ -57,6 +64,7 @@ public boolean equals(Object o) { public int hashCode() { int result = keyType.hashCode(); result = 31 * result + valueType.hashCode(); + result = 31 * result + (valueContainsNull ? 1 : 0); return result; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e358f00f8d852..cb48b689903c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -443,98 +443,4 @@ class SQLContext(@transient val sparkContext: SparkContext) } new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema, rowRdd))) } - - /** - * Returns the equivalent StructField in Scala for the given StructField in Java. - */ - protected def asJavaStructField(scalaStructField: StructField): JStructField = { - org.apache.spark.sql.api.java.types.DataType.createStructField( - scalaStructField.name, - asJavaDataType(scalaStructField.dataType), - scalaStructField.nullable) - } - - /** - * Returns the equivalent DataType in Java for the given DataType in Scala. 
- */ - protected[sql] def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { - case StringType => - org.apache.spark.sql.api.java.types.DataType.StringType - case BinaryType => - org.apache.spark.sql.api.java.types.DataType.BinaryType - case BooleanType => - org.apache.spark.sql.api.java.types.DataType.BooleanType - case TimestampType => - org.apache.spark.sql.api.java.types.DataType.TimestampType - case DecimalType => - org.apache.spark.sql.api.java.types.DataType.DecimalType - case DoubleType => - org.apache.spark.sql.api.java.types.DataType.DoubleType - case FloatType => - org.apache.spark.sql.api.java.types.DataType.FloatType - case ByteType => - org.apache.spark.sql.api.java.types.DataType.ByteType - case IntegerType => - org.apache.spark.sql.api.java.types.DataType.IntegerType - case LongType => - org.apache.spark.sql.api.java.types.DataType.LongType - case ShortType => - org.apache.spark.sql.api.java.types.DataType.ShortType - - case arrayType: ArrayType => - org.apache.spark.sql.api.java.types.DataType.createArrayType( - asJavaDataType(arrayType.elementType), arrayType.containsNull) - case mapType: MapType => - org.apache.spark.sql.api.java.types.DataType.createMapType( - asJavaDataType(mapType.keyType), asJavaDataType(mapType.valueType)) - case structType: StructType => - org.apache.spark.sql.api.java.types.DataType.createStructType( - structType.fields.map(asJavaStructField).asJava) - } - - /** - * Returns the equivalent StructField in Scala for the given StructField in Java. - */ - protected def asScalaStructField(javaStructField: JStructField): StructField = { - StructField( - javaStructField.getName, - asScalaDataType(javaStructField.getDataType), - javaStructField.isNullable) - } - - /** - * Returns the equivalent DataType in Scala for the given DataType in Java. 
- */ - protected[sql] def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { - case stringType: org.apache.spark.sql.api.java.types.StringType => - StringType - case binaryType: org.apache.spark.sql.api.java.types.BinaryType => - BinaryType - case booleanType: org.apache.spark.sql.api.java.types.BooleanType => - BooleanType - case timestampType: org.apache.spark.sql.api.java.types.TimestampType => - TimestampType - case decimalType: org.apache.spark.sql.api.java.types.DecimalType => - DecimalType - case doubleType: org.apache.spark.sql.api.java.types.DoubleType => - DoubleType - case floatType: org.apache.spark.sql.api.java.types.FloatType => - FloatType - case byteType: org.apache.spark.sql.api.java.types.ByteType => - ByteType - case integerType: org.apache.spark.sql.api.java.types.IntegerType => - IntegerType - case longType: org.apache.spark.sql.api.java.types.LongType => - LongType - case shortType: org.apache.spark.sql.api.java.types.ShortType => - ShortType - - case arrayType: org.apache.spark.sql.api.java.types.ArrayType => - ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) - case mapType: org.apache.spark.sql.api.java.types.MapType => - MapType(asScalaDataType(mapType.getKeyType), asScalaDataType(mapType.getValueType)) - case structType: org.apache.spark.sql.api.java.types.StructType => - StructType(structType.getFields.map(asScalaStructField)) - } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index caf0e040fa74e..0a3b59cbc233a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -85,11 +85,11 @@ private[sql] object CatalystConverter { case StructType(fields: Seq[StructField]) => { new CatalystStructConverter(fields.toArray, fieldIndex, parent) } - case MapType(keyType: DataType, valueType: DataType) => { + case MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) => { new CatalystMapConverter( Array( new FieldType(MAP_KEY_SCHEMA_NAME, keyType, false), - new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, true)), + new FieldType(MAP_VALUE_SCHEMA_NAME, valueType, valueContainsNull)), fieldIndex, parent) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 4d2a92ac779cd..6d4ce32ac5bfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -175,7 +175,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging { case t @ ArrayType(_, false) => writeArray( t, value.asInstanceOf[CatalystConverter.ArrayScalaType[_]]) - case t @ MapType(_, _) => writeMap( + case t @ MapType(_, _, _) => writeMap( t, value.asInstanceOf[CatalystConverter.MapScalaType[_, _]]) case t @ StructType(_) => writeStruct( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 9aea982110d57..46f5d81a755ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -130,6 +130,8 @@ private[parquet] object 
ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) + // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true + // at here. MapType(keyType, valueType) } case _ => { @@ -140,6 +142,8 @@ private[parquet] object ParquetTypesConverter extends Logging { assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) val valueType = toDataType(keyValueGroup.getFields.apply(1)) assert(keyValueGroup.getFields.apply(1).getRepetition == Repetition.REQUIRED) + // TODO: set valueContainsNull explicitly instead of assuming valueContainsNull is true + // at here. MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType val elementType = toDataType(groupType.getFields.apply(0)) @@ -248,7 +252,7 @@ private[parquet] object ParquetTypesConverter extends Logging { } new ParquetGroupType(repetition, name, fields) } - case MapType(keyType, valueType) => { + case MapType(keyType, valueType, _) => { val parquetKeyType = fromDataType( keyType, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 69383f2e2d86d..b681f1c239279 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -261,7 +261,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType)) => + case (map: Map[_,_], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) @@ -280,7 +280,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { }.mkString("{", ",", "}") case (seq: Seq[_], ArrayType(typ, _)) => seq.map(v => (v, typ)).map(toHiveStructString).mkString("[", ",", "]") - case (map: Map[_,_], MapType(kType, vType)) => + case (map: Map[_,_], MapType(kType, vType, _)) => map.map { case (key, value) => toHiveStructString((key, kType)) + ":" + toHiveStructString((value, vType)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 28b1a43d85773..354fcd53f303b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -154,7 +154,7 @@ private[hive] trait HiveInspectors { def toInspector(dataType: DataType): ObjectInspector = dataType match { case ArrayType(tpe, _) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe)) - case MapType(keyType, valueType) => + case MapType(keyType, valueType, _) => ObjectInspectorFactory.getStandardMapObjectInspector( toInspector(keyType), toInspector(valueType)) case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 57738aabff176..0af661a744a7b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -234,7 +234,7 @@ object HiveMetastoreTypes extends RegexParsers { case ArrayType(elementType, _) => s"array<${toMetastoreType(elementType)}>" case StructType(fields) => s"struct<${fields.map(f => s"${f.name}:${toMetastoreType(f.dataType)}").mkString(",")}>" - case MapType(keyType, valueType) => + case MapType(keyType, valueType, _) => s"map<${toMetastoreType(keyType)},${toMetastoreType(valueType)}>" case StringType => "string" case FloatType => "float" From 991f860116bc1215f04c5153128da036112bb506 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 28 Jul 2014 15:22:25 -0700 Subject: [PATCH 27/34] Move "asJavaDataType" and "asScalaDataType" to DataTypeConversions.scala. --- .../spark/sql/api/java/JavaSQLContext.scala | 11 +- .../spark/sql/api/java/JavaSchemaRDD.scala | 4 +- .../scala/org/apache/spark/sql/package.scala | 18 ++- .../sql/types/util/DataTypeConversions.scala | 124 ++++++++++++++++++ .../java/JavaSideDataTypeConversionSuite.java | 28 +--- .../org/apache/spark/sql/DataTypeSuite.scala | 6 + .../ScalaSideDataTypeConversionSuite.scala | 15 +-- 7 files changed, 162 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index 2a50cee1d18a6..b9f6c6fbe9a35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -19,19 +19,18 @@ package org.apache.spark.sql.api.java import java.beans.Introspector -import scala.collection.JavaConverters._ - import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} -import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructType => JStructType} -import org.apache.spark.sql.api.java.types.{StructField => JStructField} +import org.apache.spark.sql.api.java.types.{StructType => JStructType} import org.apache.spark.sql.json.JsonRDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericRow, Row => ScalaRow} import org.apache.spark.sql.parquet.ParquetRelation import org.apache.spark.sql.execution.{ExistingRdd, SparkLogicalPlan} +import org.apache.spark.sql.types.util.DataTypeConversions +import DataTypeConversions.asScalaDataType; import org.apache.spark.util.Utils /** @@ -107,7 +106,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { @DeveloperApi def applySchema(rowRDD: JavaRDD[Row], schema: JStructType): JavaSchemaRDD = { val scalaRowRDD = rowRDD.rdd.map(r => r.row) - val scalaSchema = sqlContext.asScalaDataType(schema).asInstanceOf[StructType] + val scalaSchema = asScalaDataType(schema).asInstanceOf[StructType] val logicalPlan = SparkLogicalPlan(ExistingRdd(scalaSchema.toAttributes, scalaRowRDD)) new JavaSchemaRDD(sqlContext, logicalPlan) } @@ -156,7 +155,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { @Experimental def jsonRDD(json: JavaRDD[String], schema: JStructType): JavaSchemaRDD = { val appliedScalaSchema = - Option(sqlContext.asScalaDataType(schema)).getOrElse( + Option(asScalaDataType(schema)).getOrElse( JsonRDD.nullTypeToStringType(JsonRDD.inferSchema(json.rdd, 1.0))).asInstanceOf[StructType] val scalaRowRDD = 
JsonRDD.jsonStringToRow(json.rdd, appliedScalaSchema) val logicalPlan = SparkLogicalPlan(ExistingRdd(appliedScalaSchema.toAttributes, scalaRowRDD)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala index e560648f1a46d..824574149858c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSchemaRDD.scala @@ -23,8 +23,10 @@ import org.apache.spark.Partitioner import org.apache.spark.api.java.{JavaRDDLike, JavaRDD} import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.sql.api.java.types.StructType +import org.apache.spark.sql.types.util.DataTypeConversions import org.apache.spark.sql.{SQLContext, SchemaRDD, SchemaRDDLike} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import DataTypeConversions._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -56,7 +58,7 @@ class JavaSchemaRDD( /** Returns the schema of this JavaSchemaRDD (represented by a StructType). */ def schema: StructType = - sqlContext.asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType] + asJavaDataType(baseSchemaRDD.schema).asInstanceOf[StructType] // ======================================================================= // Base RDD functions that do NOT change schema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 8f87e6e19baf0..819e36bbf9c02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,7 +17,10 @@ package org.apache.spark +import scala.collection.JavaConverters._ + import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField} /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -243,8 +246,7 @@ package object sql { * The data type representing `Seq`s. * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and * `containsNull: Boolean`. The field of `elementType` is used to specify the type of - * array elements. The field of `containsNull` is used to specify if the array can have - * any `null` value. + * array elements. The field of `containsNull` is used to specify if the array has `null` valus. * * @group dataType */ @@ -271,10 +273,11 @@ package object sql { /** * :: DeveloperApi :: * - * The data type representing `Map`s. A [[MapType]] object comprises two fields, - * `keyType: [[DataType]]` and `valueType: [[DataType]]`. + * The data type representing `Map`s. A [[MapType]] object comprises three fields, + * `keyType: [[DataType]]`, `valueType: [[DataType]]` and `valueContainsNull: Boolean`. * The field of `keyType` is used to specify the type of keys in the map. * The field of `valueType` is used to specify the type of values in the map. + * The field of `valueContainsNull` is used to specify if values of this map has `null` values. 
* * @group dataType */ @@ -284,10 +287,15 @@ package object sql { /** * :: DeveloperApi :: * - * A [[MapType]] can be constructed by + * A [[MapType]] object can be constructed with two ways, + * {{{ + * MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) + * }}} and * {{{ * MapType(keyType: DataType, valueType: DataType) * }}} + * For `MapType(keyType: DataType, valueType: DataType)`, + * the field of `valueContainsNull` is set to `true`. * * @group dataType */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala new file mode 100644 index 0000000000000..a51383a431cf5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types.util + +import org.apache.spark.sql._ +import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField} + +import scala.collection.JavaConverters._ + +protected[sql] object DataTypeConversions { + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + def asJavaStructField(scalaStructField: StructField): JStructField = { + org.apache.spark.sql.api.java.types.DataType.createStructField( + scalaStructField.name, + asJavaDataType(scalaStructField.dataType), + scalaStructField.nullable) + } + + /** + * Returns the equivalent DataType in Java for the given DataType in Scala. 
+ */ + def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { + case StringType => + org.apache.spark.sql.api.java.types.DataType.StringType + case BinaryType => + org.apache.spark.sql.api.java.types.DataType.BinaryType + case BooleanType => + org.apache.spark.sql.api.java.types.DataType.BooleanType + case TimestampType => + org.apache.spark.sql.api.java.types.DataType.TimestampType + case DecimalType => + org.apache.spark.sql.api.java.types.DataType.DecimalType + case DoubleType => + org.apache.spark.sql.api.java.types.DataType.DoubleType + case FloatType => + org.apache.spark.sql.api.java.types.DataType.FloatType + case ByteType => + org.apache.spark.sql.api.java.types.DataType.ByteType + case IntegerType => + org.apache.spark.sql.api.java.types.DataType.IntegerType + case LongType => + org.apache.spark.sql.api.java.types.DataType.LongType + case ShortType => + org.apache.spark.sql.api.java.types.DataType.ShortType + + case arrayType: ArrayType => + org.apache.spark.sql.api.java.types.DataType.createArrayType( + asJavaDataType(arrayType.elementType), arrayType.containsNull) + case mapType: MapType => + org.apache.spark.sql.api.java.types.DataType.createMapType( + asJavaDataType(mapType.keyType), + asJavaDataType(mapType.valueType), + mapType.valueContainsNull) + case structType: StructType => + org.apache.spark.sql.api.java.types.DataType.createStructType( + structType.fields.map(asJavaStructField).asJava) + } + + /** + * Returns the equivalent StructField in Scala for the given StructField in Java. + */ + def asScalaStructField(javaStructField: JStructField): StructField = { + StructField( + javaStructField.getName, + asScalaDataType(javaStructField.getDataType), + javaStructField.isNullable) + } + + /** + * Returns the equivalent DataType in Scala for the given DataType in Java. 
+ */ + def asScalaDataType(javaDataType: JDataType): DataType = javaDataType match { + case stringType: org.apache.spark.sql.api.java.types.StringType => + StringType + case binaryType: org.apache.spark.sql.api.java.types.BinaryType => + BinaryType + case booleanType: org.apache.spark.sql.api.java.types.BooleanType => + BooleanType + case timestampType: org.apache.spark.sql.api.java.types.TimestampType => + TimestampType + case decimalType: org.apache.spark.sql.api.java.types.DecimalType => + DecimalType + case doubleType: org.apache.spark.sql.api.java.types.DoubleType => + DoubleType + case floatType: org.apache.spark.sql.api.java.types.FloatType => + FloatType + case byteType: org.apache.spark.sql.api.java.types.ByteType => + ByteType + case integerType: org.apache.spark.sql.api.java.types.IntegerType => + IntegerType + case longType: org.apache.spark.sql.api.java.types.LongType => + LongType + case shortType: org.apache.spark.sql.api.java.types.ShortType => + ShortType + + case arrayType: org.apache.spark.sql.api.java.types.ArrayType => + ArrayType(asScalaDataType(arrayType.getElementType), arrayType.isContainsNull) + case mapType: org.apache.spark.sql.api.java.types.MapType => + MapType( + asScalaDataType(mapType.getKeyType), + asScalaDataType(mapType.getValueType), + mapType.isValueContainsNull) + case structType: org.apache.spark.sql.api.java.types.StructType => + StructType(structType.getFields.map(asScalaStructField)) + } +} diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java index eccec65c667d2..96a503962f7d1 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaSideDataTypeConversionSuite.java @@ -21,40 +21,20 @@ import java.util.ArrayList; import org.junit.Assert; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.sql.types.util.DataTypeConversions; import org.apache.spark.sql.api.java.types.DataType; import org.apache.spark.sql.api.java.types.StructField; -import org.apache.spark.sql.test.TestSQLContext; -import org.junit.rules.ExpectedException; public class JavaSideDataTypeConversionSuite { - private transient JavaSparkContext javaCtx; - private transient JavaSQLContext javaSqlCtx; - public void checkDataType(DataType javaDataType) { org.apache.spark.sql.catalyst.types.DataType scalaDataType = - javaSqlCtx.sqlContext().asScalaDataType(javaDataType); - DataType actual = javaSqlCtx.sqlContext().asJavaDataType(scalaDataType); + DataTypeConversions.asScalaDataType(javaDataType); + DataType actual = DataTypeConversions.asJavaDataType(scalaDataType); Assert.assertEquals(javaDataType, actual); } - @Before - public void setUp() { - javaCtx = new JavaSparkContext(TestSQLContext.sparkContext()); - javaSqlCtx = new JavaSQLContext(javaCtx); - } - - @After - public void tearDown() { - javaCtx.stop(); - javaCtx = null; - } - @Test public void createDataTypes() { // Simple DataTypes. @@ -102,7 +82,7 @@ public void createDataTypes() { // Complex MapType. 
DataType complexJavaMapType = - DataType.createMapType(complexJavaStructType, complexJavaArrayType); + DataType.createMapType(complexJavaStructType, complexJavaArrayType, false); checkDataType(complexJavaMapType); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index c5bd7b391db41..cf7d79f42db1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -27,6 +27,12 @@ class DataTypeSuite extends FunSuite { assert(ArrayType(StringType, false) === array) } + test("construct an MapType") { + val map = MapType(StringType, IntegerType) + + assert(MapType(StringType, IntegerType, true) === map) + } + test("extract fields from a StructType") { val struct = StructType( StructField("a", IntegerType, true) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index 56102c2d5b8fb..46de6fe239228 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -17,18 +17,17 @@ package org.apache.spark.sql.api.java -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.sql._ -import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types.util.DataTypeConversions import org.scalatest.FunSuite +import org.apache.spark.sql._ +import DataTypeConversions._ + class ScalaSideDataTypeConversionSuite extends FunSuite { - val javaCtx = new JavaSparkContext(TestSQLContext.sparkContext) - val javaSqlCtx = new JavaSQLContext(javaCtx) def checkDataType(scalaDataType: DataType) { - val javaDataType = javaSqlCtx.sqlContext.asJavaDataType(scalaDataType) - val actual = javaSqlCtx.sqlContext.asScalaDataType(javaDataType) + val javaDataType = asJavaDataType(scalaDataType) + val actual = asScalaDataType(javaDataType) assert(scalaDataType === actual, s"Converted data type ${actual} " + s"does not equal the expected data type ${scalaDataType}") } @@ -76,7 +75,7 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { checkDataType(complexScalaArrayType) // Complex MapType. - val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType) + val complexScalaMapType = MapType(complexScalaStructType, complexScalaArrayType, false) checkDataType(complexScalaMapType) } } From bd40a33d06797b13ccbe60705e6e6df01ba3ccc4 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 28 Jul 2014 17:16:09 -0700 Subject: [PATCH 28/34] Address comments. 
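
For reference, a rough sketch of how the conversions moved into DataTypeConversions by the
previous patch are expected to compose. The object name below is purely illustrative, and the
code has to live under org.apache.spark.sql because DataTypeConversions is protected[sql]:

    package org.apache.spark.sql.api.java

    import org.apache.spark.sql._
    import org.apache.spark.sql.types.util.DataTypeConversions._

    // Round-tripping a schema through the Java representation should be lossless,
    // including the new valueContainsNull flag on MapType (this mirrors what the
    // Scala/Java-side conversion suites check).
    object DataTypeRoundTripExample {
      def main(args: Array[String]): Unit = {
        val scalaSchema = StructType(
          StructField("name", StringType, nullable = false) ::
          StructField("scores", MapType(StringType, IntegerType, valueContainsNull = true),
            nullable = true) :: Nil)

        assert(asScalaDataType(asJavaDataType(scalaSchema)) == scalaSchema)
      }
    }
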
--- python/pyspark/sql.py | 88 +++++++++++-------- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +- .../catalyst/expressions/BoundAttribute.scala | 5 +- .../sql/catalyst/planning/QueryPlanner.scala | 3 +- .../sql/catalyst/planning/patterns.scala | 7 +- .../spark/sql/catalyst/rules/Rule.scala | 3 +- .../sql/catalyst/rules/RuleExecutor.scala | 13 ++- .../spark/sql/api/java/types/ArrayType.java | 9 +- .../spark/sql/api/java/types/BinaryType.java | 2 + .../spark/sql/api/java/types/BooleanType.java | 5 ++ .../spark/sql/api/java/types/ByteType.java | 4 +- .../spark/sql/api/java/types/DataType.java | 21 ++++- .../spark/sql/api/java/types/DecimalType.java | 2 + .../spark/sql/api/java/types/DoubleType.java | 4 +- .../spark/sql/api/java/types/FloatType.java | 4 +- .../spark/sql/api/java/types/IntegerType.java | 4 +- .../spark/sql/api/java/types/LongType.java | 4 +- .../spark/sql/api/java/types/MapType.java | 9 +- .../spark/sql/api/java/types/ShortType.java | 4 +- .../spark/sql/api/java/types/StringType.java | 2 + .../spark/sql/api/java/types/StructField.java | 4 + .../spark/sql/api/java/types/StructType.java | 5 ++ .../sql/api/java/types/TimestampType.java | 2 + .../sql/api/java/types/package-info.java | 2 +- .../org/apache/spark/sql/SQLContext.scala | 3 +- .../apache/spark/sql/execution/Exchange.scala | 2 +- .../scala/org/apache/spark/sql/package.scala | 9 +- 28 files changed, 154 insertions(+), 78 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index c0d086acbe955..1e91fdc60b46f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -41,18 +41,18 @@ class StringType(object): """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "StringType" class BinaryType(object): """Spark SQL BinaryType - The data type representing bytes values and bytearray values. + The data type representing bytearray values. """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "BinaryType" class BooleanType(object): @@ -63,14 +63,18 @@ class BooleanType(object): """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "BooleanType" class TimestampType(object): - """Spark SQL TimestampType""" + """Spark SQL TimestampType + + The data type representing datetime.datetime values. + + """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "TimestampType" class DecimalType(object): @@ -81,40 +85,48 @@ class DecimalType(object): """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "DecimalType" class DoubleType(object): """Spark SQL DoubleType - The data type representing float values. Because a float value + The data type representing float values. """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "DoubleType" class FloatType(object): """Spark SQL FloatType - For PySpark, please use L{DoubleType} instead of using L{FloatType}. + For now, please use L{DoubleType} instead of using L{FloatType}. + Because query evaluation is done in Scala, java.lang.Double will be be used + for Python float numbers. Because the underlying JVM type of FloatType is + java.lang.Float (in Java) and Float (in scala), there will be a java.lang.ClassCastException + if FloatType (Python) used. 
""" __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "FloatType" class ByteType(object): """Spark SQL ByteType - For PySpark, please use L{IntegerType} instead of using L{ByteType}. + For now, please use L{IntegerType} instead of using L{ByteType}. + Because query evaluation is done in Scala, java.lang.Integer will be be used + for Python int numbers. Because the underlying JVM type of ByteType is + java.lang.Byte (in Java) and Byte (in scala), there will be a java.lang.ClassCastException + if ByteType (Python) used. """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "ByteType" class IntegerType(object): @@ -125,7 +137,7 @@ class IntegerType(object): """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "IntegerType" class LongType(object): @@ -137,18 +149,22 @@ class LongType(object): """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "LongType" class ShortType(object): """Spark SQL ShortType - For PySpark, please use L{IntegerType} instead of using L{ShortType}. + For now, please use L{IntegerType} instead of using L{ShortType}. + Because query evaluation is done in Scala, java.lang.Integer will be be used + for Python int numbers. Because the underlying JVM type of ShortType is + java.lang.Short (in Java) and Short (in scala), there will be a java.lang.ClassCastException + if ShortType (Python) used. """ __metaclass__ = PrimitiveTypeSingleton - def _get_scala_type_string(self): + def __repr__(self): return "ShortType" class ArrayType(object): @@ -157,23 +173,23 @@ class ArrayType(object): The data type representing list values. """ - def __init__(self, elementType, containsNull): + def __init__(self, elementType, containsNull=False): """Creates an ArrayType :param elementType: the data type of elements. :param containsNull: indicates whether the list contains null values. 
:return: - >>> ArrayType(StringType, True) == ArrayType(StringType, False) - False - >>> ArrayType(StringType, True) == ArrayType(StringType, True) + >>> ArrayType(StringType) == ArrayType(StringType, False) True + >>> ArrayType(StringType, True) == ArrayType(StringType) + False """ self.elementType = elementType self.containsNull = containsNull - def _get_scala_type_string(self): - return "ArrayType(" + self.elementType._get_scala_type_string() + "," + \ + def __repr__(self): + return "ArrayType(" + self.elementType.__repr__() + "," + \ str(self.containsNull).lower() + ")" def __eq__(self, other): @@ -207,9 +223,9 @@ def __init__(self, keyType, valueType, valueContainsNull=True): self.valueType = valueType self.valueContainsNull = valueContainsNull - def _get_scala_type_string(self): - return "MapType(" + self.keyType._get_scala_type_string() + "," + \ - self.valueType._get_scala_type_string() + "," + \ + def __repr__(self): + return "MapType(" + self.keyType.__repr__() + "," + \ + self.valueType.__repr__() + "," + \ str(self.valueContainsNull).lower() + ")" def __eq__(self, other): @@ -243,9 +259,9 @@ def __init__(self, name, dataType, nullable): self.dataType = dataType self.nullable = nullable - def _get_scala_type_string(self): + def __repr__(self): return "StructField(" + self.name + "," + \ - self.dataType._get_scala_type_string() + "," + \ + self.dataType.__repr__() + "," + \ str(self.nullable).lower() + ")" def __eq__(self, other): @@ -280,9 +296,9 @@ def __init__(self, fields): """ self.fields = fields - def _get_scala_type_string(self): + def __repr__(self): return "StructType(List(" + \ - ",".join([field._get_scala_type_string() for field in self.fields]) + "))" + ",".join([field.__repr__() for field in self.fields]) + "))" def __eq__(self, other): return (isinstance(other, self.__class__) and \ @@ -319,7 +335,7 @@ def _parse_datatype_string(datatype_string): :return: >>> def check_datatype(datatype): - ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype._get_scala_type_string()) + ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__()) ... python_datatype = _parse_datatype_string(scala_datatype.toString()) ... return datatype == python_datatype >>> check_datatype(StringType()) @@ -536,7 +552,7 @@ def applySchema(self, rdd, schema): True """ jrdd = self._pythonToJavaMap(rdd._jrdd) - srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema._get_scala_type_string()) + srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema.__repr__()) return SchemaRDD(srdd, self) def registerRDDAsTable(self, rdd, tableName): @@ -569,7 +585,7 @@ def parquetFile(self, path): jschema_rdd = self._ssql_ctx.parquetFile(path) return SchemaRDD(jschema_rdd, self) - def jsonFile(self, path, schema = None): + def jsonFile(self, path, schema=None): """Loads a text file storing one JSON object per line as a L{SchemaRDD}. If the schema is provided, applies the given schema to this JSON dataset. @@ -618,11 +634,11 @@ def jsonFile(self, path, schema = None): if schema is None: jschema_rdd = self._ssql_ctx.jsonFile(path) else: - scala_datatype = self._ssql_ctx.parseDataType(schema._get_scala_type_string()) + scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) jschema_rdd = self._ssql_ctx.jsonFile(path, scala_datatype) return SchemaRDD(jschema_rdd, self) - def jsonRDD(self, rdd, schema = None): + def jsonRDD(self, rdd, schema=None): """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. If the schema is provided, applies the given schema to this JSON dataset. 
@@ -672,7 +688,7 @@ def func(split, iterator): if schema is None: jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd()) else: - scala_datatype = self._ssql_ctx.parseDataType(schema._get_scala_type_string()) + scala_datatype = self._ssql_ctx.parseDataType(schema.__repr__()) jschema_rdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) return SchemaRDD(jschema_rdd, self) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 02bdb64f308a5..f847355a43537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -109,12 +109,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan if q.childrenResolved => - logger.trace(s"Attempting to resolve ${q.simpleString}") + logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolve(name).getOrElse(u) - logger.debug(s"Resolving $u to $result") + logDebug(s"Resolving $u to $result") result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 47c7ad076ad07..e94f2a3bea63e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -75,7 +75,7 @@ trait HiveTypeCoercion { // Leave the same if the dataTypes match. 
case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") newType } } @@ -154,7 +154,7 @@ trait HiveTypeCoercion { (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => - logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() @@ -170,7 +170,7 @@ trait HiveTypeCoercion { val newLeft = if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") + logDebug(s"Widening numeric types in union $castedLeft ${left.output}") Project(castedLeft, left) } else { left @@ -178,7 +178,7 @@ trait HiveTypeCoercion { val newRight = if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logger.debug(s"Widening numeric types in union $castedRight ${right.output}") + logDebug(s"Widening numeric types in union $castedRight ${right.output}") Project(castedRight, right) } else { right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index cbc214d442064..92a30810c736d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import com.typesafe.scalalogging.slf4j.Logging - +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.plans.QueryPlan @@ -80,7 +79,7 @@ object BindReferences extends Logging { // produce new attributes that can't be bound. Likely the right thing to do is remove // this rule and require all operators to explicitly bind to the input schema that // they specify. 
- logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + logDebug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") a } else { BoundReference(ordinal, a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 4ff5791635f4c..5839c9f7c43ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import com.typesafe.scalalogging.slf4j.Logging - +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index b8ae326be6fab..820eaeb75f768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec -import com.typesafe.scalalogging.slf4j.Logging - +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -114,7 +113,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - logger.debug(s"Considering join on: $condition") + logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. 
val (joinPredicates, otherPredicates) = @@ -132,7 +131,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index f39bff8c25164..03414b2301e81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import com.typesafe.scalalogging.slf4j.Logging - +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e70ce66cb745f..e73515ff29377 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import com.typesafe.scalalogging.slf4j.Logging - +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -61,7 +60,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { case (plan, rule) => val result = rule(plan) if (!result.fastEquals(plan)) { - logger.trace( + logTrace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} @@ -72,25 +71,25 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { } iteration += 1 if (iteration > batch.strategy.maxIterations) { - logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}") + logInfo(s"Max iterations ($iteration) reached for batch ${batch.name}") continue = false } if (curPlan.fastEquals(lastPlan)) { - logger.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") + logTrace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") continue = false } lastPlan = curPlan } if (!batchStartPlan.fastEquals(curPlan)) { - logger.debug( + logDebug( s""" |=== Result of Batch ${batch.name} === |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { - logger.trace(s"Batch ${batch.name} has no effect.") + logTrace(s"Batch ${batch.name} has no effect.") } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java index 61f52055842e6..17334ca31b2b7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ArrayType.java @@ -21,8 +21,13 @@ * The data type representing Lists. * An ArrayType object comprises two fields, {@code DataType elementType} and * {@code boolean containsNull}. The field of {@code elementType} is used to specify the type of - * array elements. 
The field of {@code containsNull} is used to specify if the array can have - * any {@code null} value. + * array elements. The field of {@code containsNull} is used to specify if the array has + * {@code null} values. + * + * To create an {@link ArrayType}, + * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType)} or + * {@link org.apache.spark.sql.api.java.types.DataType#createArrayType(DataType, boolean)} + * should be used. */ public class ArrayType extends DataType { private DataType elementType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java index c33ee5e25cd32..61703179850e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BinaryType.java @@ -19,6 +19,8 @@ /** * The data type representing byte[] values. + * + * {@code BinaryType} is represented by the singleton object {@link DataType#BinaryType}. */ public class BinaryType extends DataType { protected BinaryType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java index 38981a21da58d..8fa24d85d1238 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/BooleanType.java @@ -17,6 +17,11 @@ package org.apache.spark.sql.api.java.types; +/** + * The data type representing boolean and Boolean values. + * + * {@code BooleanType} is represented by the singleton object {@link DataType#BooleanType}. + */ public class BooleanType extends DataType { protected BooleanType() {} } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java index 16b0d9ecf688c..2de32978e2705 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ByteType.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.api.java.types; /** - * The data type representing Byte values. + * The data type representing byte and Byte values. + * + * {@code ByteType} is represented by the singleton object {@link DataType#ByteType}. */ public class ByteType extends DataType { protected ByteType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java index c67287dea8ba6..6fd04aa2c6c9c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java @@ -23,6 +23,9 @@ /** * The base type of all Spark SQL data types. + * + * To get/create specific data type, users should use singleton objects and factory methods + * provided by this class. */ public abstract class DataType { @@ -81,6 +84,21 @@ public abstract class DataType { */ public static final ShortType ShortType = new ShortType(); + /** + * Creates an ArrayType by specifying the data type of elements ({@code elementType}). + * The field of {@code containsNull} is set to {@code false}. 
+ * + * @param elementType + * @return + */ + public static ArrayType createArrayType(DataType elementType) { + if (elementType == null) { + throw new IllegalArgumentException("elementType should not be null."); + } + + return new ArrayType(elementType, false); + } + /** * Creates an ArrayType by specifying the data type of elements ({@code elementType}) and * whether the array contains null values ({@code containsNull}). @@ -98,7 +116,8 @@ public static ArrayType createArrayType(DataType elementType, boolean containsNu /** * Creates a MapType by specifying the data type of keys ({@code keyType}) and values - * ({@code keyType}). + * ({@code keyType}). The field of {@code valueContainsNull} is set to {@code true}. + * * @param keyType * @param valueType * @return diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java index d483824999e85..9250491a2d2ca 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DecimalType.java @@ -19,6 +19,8 @@ /** * The data type representing java.math.BigDecimal values. + * + * {@code DecimalType} is represented by the singleton object {@link DataType#DecimalType}. */ public class DecimalType extends DataType { protected DecimalType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java index 13a7bf6bbb5ed..3e86917fddc4b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DoubleType.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.api.java.types; /** - * The data type representing Double values. + * The data type representing double and Double values. + * + * {@code DoubleType} is represented by the singleton object {@link DataType#DoubleType}. */ public class DoubleType extends DataType { protected DoubleType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java index bf47d4fc1fa07..fa860d40176ef 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/FloatType.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.api.java.types; /** - * The data type representing Float values. + * The data type representing float and Float values. + * + * {@code FloatType} is represented by the singleton object {@link DataType#FloatType}. */ public class FloatType extends DataType { protected FloatType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java index f41ec2260df6b..bd973eca2c3ce 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/IntegerType.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.api.java.types; /** - * The data type representing Int values. + * The data type representing int and Integer values. + * + * {@code IntegerType} is represented by the singleton object {@link DataType#IntegerType}. 
*/ public class IntegerType extends DataType { protected IntegerType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java index 7c73a7b506a2b..e00233304cefa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/LongType.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.api.java.types; /** - * The data type representing Long values. + * The data type representing long and Long values. + * + * {@code LongType} is represented by the singleton object {@link DataType#LongType}. */ public class LongType extends DataType { protected LongType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java index d116241a0e407..d2270d4b6ff9c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java @@ -19,9 +19,16 @@ /** * The data type representing Maps. A MapType object comprises two fields, - * {@code DataType keyType} and {@code DataType valueType}. + * {@code DataType keyType}, {@code DataType valueType}, and {@code boolean valueContainsNull}. * The field of {@code keyType} is used to specify the type of keys in the map. * The field of {@code valueType} is used to specify the type of values in the map. + * The field of {@code valueContainsNull} is used to specify if map values have + * {@code null} values. + * + * To create a {@link MapType}, + * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType)} or + * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType, boolean)} + * should be used. */ public class MapType extends DataType { private DataType keyType; diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java index 8ffa75a835e63..98f9507acf121 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/ShortType.java @@ -18,7 +18,9 @@ package org.apache.spark.sql.api.java.types; /** - * The data type representing Short values. + * The data type representing short and Short values. + * + * {@code ShortType} is represented by the singleton object {@link DataType#ShortType}. */ public class ShortType extends DataType { protected ShortType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java index dd9be52f8c53b..b8e7dbe646071 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StringType.java @@ -19,6 +19,8 @@ /** * The data type representing String values. + * + * {@code StringType} is represented by the singleton object {@link DataType#StringType}. 
*/ public class StringType extends DataType { protected StringType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java index 25c82de9641c5..54e9c11ea415e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructField.java @@ -24,6 +24,10 @@ * The field of {@code dataType} specifies the data type of a StructField. * The field of {@code nullable} specifies if values of a StructField can contain {@code null} * values. + * + * To create a {@link StructField}, + * {@link org.apache.spark.sql.api.java.types.DataType#createStructField(String, DataType, boolean)} + * should be used. */ public class StructField { private String name; diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java index 17142ff672822..33a42f4b16265 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/StructType.java @@ -23,6 +23,11 @@ /** * The data type representing Rows. * A StructType object comprises an array of StructFields. + * + * To create an {@link StructType}, + * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(java.util.List)} or + * {@link org.apache.spark.sql.api.java.types.DataType#createStructType(StructField[])} + * should be used. */ public class StructType extends DataType { private StructField[] fields; diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java index 8c2f203d950c4..65295779f71ec 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/TimestampType.java @@ -19,6 +19,8 @@ /** * The data type representing java.sql.Timestamp values. + * + * {@code TimestampType} is represented by the singleton object {@link DataType#TimestampType}. */ public class TimestampType extends DataType { protected TimestampType() {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java index a1c6fcf1430f5..f169ac65e226f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/package-info.java @@ -19,4 +19,4 @@ /** * Allows users to get and create Spark SQL data types. */ -package org.apache.spark.sql.api.java.types; \ No newline at end of file +package org.apache.spark.sql.api.java.types; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cb48b689903c7..f93839c66c5d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -107,7 +107,8 @@ class SQLContext(@transient val sparkContext: SparkContext) /** * Parses the data type in our internal string representation. The data type string should - * have the same format as the one generate by `toString` in scala. + * have the same format as the one generated by `toString` in scala. + * It is only used by PySpark. 
*/ private[sql] def parseDataType(dataTypeString: String): DataType = { val parser = org.apache.spark.sql.catalyst.types.DataType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 00010ef6e798a..67d70d599c3f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -99,7 +99,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) - logger.debug( + logDebug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 819e36bbf9c02..e7732fd7d8336 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -17,10 +17,7 @@ package org.apache.spark -import scala.collection.JavaConverters._ - import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.api.java.types.{DataType => JDataType, StructField => JStructField} /** * Allows the execution of relational queries, including those expressed in SQL using Spark. @@ -243,10 +240,12 @@ package object sql { /** * :: DeveloperApi :: * - * The data type representing `Seq`s. + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * * An [[ArrayType]] object comprises two fields, `elementType: [[DataType]]` and * `containsNull: Boolean`. The field of `elementType` is used to specify the type of - * array elements. The field of `containsNull` is used to specify if the array has `null` valus. + * array elements. The field of `containsNull` is used to specify if the array has `null` values. * * @group dataType */ From ab71f21b2d892f320ba2bc2441b41853a0064d6f Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 28 Jul 2014 17:26:03 -0700 Subject: [PATCH 29/34] Format. 
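
As a reminder of the Scala-side construction defaults documented in the earlier patches
(ArrayType without a flag means containsNull = false, MapType without a flag means
valueContainsNull = true), a small illustrative check; the object name is made up:

    import org.apache.spark.sql._

    // Short constructor forms fall back to containsNull = false for arrays and
    // valueContainsNull = true for maps, matching DataTypeSuite.
    object NullabilityDefaultsExample {
      def main(args: Array[String]): Unit = {
        assert(ArrayType(StringType) == ArrayType(StringType, containsNull = false))
        assert(MapType(StringType, IntegerType) ==
          MapType(StringType, IntegerType, valueContainsNull = true))
      }
    }
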
--- .../org/apache/spark/sql/catalyst/types/dataTypes.scala | 5 ++++- .../src/main/scala/org/apache/spark/sql/json/JsonRDD.scala | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 13eff8ec87bc1..afc8e786aaac7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -353,7 +353,10 @@ object MapType { MapType(keyType: DataType, valueType: DataType, true) } -case class MapType(keyType: DataType, valueType: DataType, valueContainsNull: Boolean) extends DataType { +case class MapType( + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") builder.append(s"${prefix}-- value: ${valueType.simpleString}\n") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 85396f26142e4..0c12f74658ec5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -264,7 +264,7 @@ private[sql] object JsonRDD extends Logging { // So, right now, we will have Infinity for those BigDecimal number. // TODO: Support BigDecimal. json.mapPartitions(iter => { - // Also, when there is a key appearing multiple times (a duplicate key), + // When there is a key appearing multiple times (a duplicate key), // the ObjectMapper will take the last value associated with this duplicate key. // For example: for {"key": 1, "key":2}, we will get "key"->2. val mapper = new ObjectMapper() From 2476ed01fe0af83f6f9bd85e6d354d4a887f3a9e Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 29 Jul 2014 00:16:40 -0700 Subject: [PATCH 30/34] Minor updates. 
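
The PySpark changes below rely on SQLContext.parseDataType accepting the repr() strings
generated in sql.py. A rough, hypothetical sketch of that round trip from the Scala side;
the object name and the hard-coded string are illustrative only, and since parseDataType is
private[sql] the code has to sit under org.apache.spark.sql:

    package org.apache.spark.sql.test

    import org.apache.spark.sql._

    object ParseDataTypeExample {
      def main(args: Array[String]): Unit = {
        // A string in the format produced by StructType.__repr__ on the Python side.
        val pythonRepr =
          "StructType(List(StructField(name,StringType,false),StructField(age,IntegerType,true)))"
        val parsed = TestSQLContext.parseDataType(pythonRepr)
        val expected = StructType(
          StructField("name", StringType, nullable = false) ::
          StructField("age", IntegerType, nullable = true) :: Nil)
        assert(parsed == expected)
      }
    }
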
--- python/pyspark/sql.py | 36 +++++++++++---- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 8 ++-- .../catalyst/expressions/BoundAttribute.scala | 4 +- .../apache/spark/sql/catalyst/package.scala | 23 ++++++++++ .../sql/catalyst/planning/QueryPlanner.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 6 +-- .../spark/sql/catalyst/rules/Rule.scala | 2 +- .../sql/catalyst/rules/RuleExecutor.scala | 12 ++--- .../spark/sql/catalyst/types/dataTypes.scala | 3 +- .../spark/sql/api/java/types/MapType.java | 1 + .../spark/sql/api/java/JavaSQLContext.scala | 2 +- .../org/apache/spark/sql/api/java/Row.scala | 3 ++ .../apache/spark/sql/execution/Exchange.scala | 2 +- .../scala/org/apache/spark/sql/package.scala | 1 + .../org/apache/spark/sql/SQLQuerySuite.scala | 44 +++++++++++++++---- 16 files changed, 115 insertions(+), 38 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1e91fdc60b46f..0569f3a62c244 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -21,9 +21,9 @@ from py4j.protocol import Py4JError __all__ = [ - "StringType", "BinaryType", "BooleanType", "DecimalType", "DoubleType", - "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", - "ArrayType", "MapType", "StructField", "StructType", + "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", + "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", + "ShortType", "ArrayType", "MapType", "StructField", "StructType", "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] class PrimitiveTypeSingleton(type): @@ -106,7 +106,7 @@ class FloatType(object): Because query evaluation is done in Scala, java.lang.Double will be be used for Python float numbers. Because the underlying JVM type of FloatType is java.lang.Float (in Java) and Float (in scala), there will be a java.lang.ClassCastException - if FloatType (Python) used. + if FloatType (Python) is used. """ __metaclass__ = PrimitiveTypeSingleton @@ -121,7 +121,7 @@ class ByteType(object): Because query evaluation is done in Scala, java.lang.Integer will be be used for Python int numbers. Because the underlying JVM type of ByteType is java.lang.Byte (in Java) and Byte (in scala), there will be a java.lang.ClassCastException - if ByteType (Python) used. + if ByteType (Python) is used. """ __metaclass__ = PrimitiveTypeSingleton @@ -159,7 +159,7 @@ class ShortType(object): Because query evaluation is done in Scala, java.lang.Integer will be be used for Python int numbers. Because the underlying JVM type of ShortType is java.lang.Short (in Java) and Short (in scala), there will be a java.lang.ClassCastException - if ShortType (Python) used. + if ShortType (Python) is used. """ __metaclass__ = PrimitiveTypeSingleton @@ -171,13 +171,16 @@ class ArrayType(object): """Spark SQL ArrayType The data type representing list values. + An ArrayType object comprises two fields, elementType (a DataType) and containsNull (a bool). + The field of elementType is used to specify the type of array elements. + The field of containsNull is used to specify if the array has None values. """ def __init__(self, elementType, containsNull=False): """Creates an ArrayType :param elementType: the data type of elements. - :param containsNull: indicates whether the list contains null values. + :param containsNull: indicates whether the list contains None values. 
:return: >>> ArrayType(StringType) == ArrayType(StringType, False) @@ -205,6 +208,12 @@ class MapType(object): """Spark SQL MapType The data type representing dict values. + A MapType object comprises three fields, + keyType (a DataType), valueType (a DataType) and valueContainsNull (a bool). + The field of keyType is used to specify the type of keys in the map. + The field of valueType is used to specify the type of values in the map. + The field of valueContainsNull is used to specify if values of this map has None values. + For values of a MapType column, keys are not allowed to have None values. """ def __init__(self, keyType, valueType, valueContainsNull=True): @@ -241,6 +250,10 @@ class StructField(object): """Spark SQL StructField Represents a field in a StructType. + A StructField object comprises three fields, name (a string), dataType (a DataType), + and nullable (a bool). The field of name is the name of a StructField. The field of + dataType specifies the data type of a StructField. + The field of nullable specifies if values of a StructField can contain None values. """ def __init__(self, name, dataType, nullable): @@ -276,7 +289,8 @@ def __ne__(self, other): class StructType(object): """Spark SQL StructType - The data type representing tuple values. + The data type representing namedtuple values. + A StructType object comprises a list of L{StructField}s. """ def __init__(self, fields): @@ -308,6 +322,11 @@ def __ne__(self, other): return not self.__eq__(other) def _parse_datatype_list(datatype_list_string): + """Parses a list of comma separated data types. + + :param datatype_list_string: + :return: + """ index = 0 datatype_list = [] start = 0 @@ -331,6 +350,7 @@ def _parse_datatype_list(datatype_list_string): def _parse_datatype_string(datatype_string): """Parses the given data type string. + :param datatype_string: :return: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index f847355a43537..02bdb64f308a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -109,12 +109,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case q: LogicalPlan if q.childrenResolved => - logTrace(s"Attempting to resolve ${q.simpleString}") + logger.trace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. val result = q.resolve(name).getOrElse(u) - logDebug(s"Resolving $u to $result") + logger.debug(s"Resolving $u to $result") result } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e94f2a3bea63e..47c7ad076ad07 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -75,7 +75,7 @@ trait HiveTypeCoercion { // Leave the same if the dataTypes match. 
case Some(newType) if a.dataType == newType.dataType => a case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + logger.debug(s"Promoting $a to $newType in ${q.simpleString}}") newType } } @@ -154,7 +154,7 @@ trait HiveTypeCoercion { (Alias(Cast(l, StringType), l.name)(), r) case (l, r) if l.dataType != r.dataType => - logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") + logger.debug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}") findTightestCommonType(l.dataType, r.dataType).map { widestType => val newLeft = if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)() @@ -170,7 +170,7 @@ trait HiveTypeCoercion { val newLeft = if (castedLeft.map(_.dataType) != left.output.map(_.dataType)) { - logDebug(s"Widening numeric types in union $castedLeft ${left.output}") + logger.debug(s"Widening numeric types in union $castedLeft ${left.output}") Project(castedLeft, left) } else { left @@ -178,7 +178,7 @@ trait HiveTypeCoercion { val newRight = if (castedRight.map(_.dataType) != right.output.map(_.dataType)) { - logDebug(s"Widening numeric types in union $castedRight ${right.output}") + logger.debug(s"Widening numeric types in union $castedRight ${right.output}") Project(castedRight, right) } else { right diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 92a30810c736d..fc398e685121d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.plans.QueryPlan @@ -79,7 +79,7 @@ object BindReferences extends Logging { // produce new attributes that can't be bound. Likely the right thing to do is remove // this rule and require all operators to explicitly bind to the input schema that // they specify. - logDebug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") + logger.debug(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") a } else { BoundReference(ordinal, a) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala new file mode 100644 index 0000000000000..a2d85ea7136b3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/package.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +package object catalyst { + protected[catalyst] type Logging = com.typesafe.scalalogging.slf4j.Logging +} + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala index 5839c9f7c43ef..781ba489b44c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/QueryPlanner.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.TreeNode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 820eaeb75f768..3da23907008f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.planning import scala.annotation.tailrec -import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -113,7 +113,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - logDebug(s"Considering join on: $condition") + logger.debug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. 
val (joinPredicates, otherPredicates) = @@ -131,7 +131,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { val rightKeys = joinKeys.map(_._2) if (joinKeys.nonEmpty) { - logDebug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") + logger.debug(s"leftKeys:${leftKeys} | rightKeys:${rightKeys}") Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) } else { None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala index 03414b2301e81..f8960b3fe7a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/Rule.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees.TreeNode abstract class Rule[TreeType <: TreeNode[_]] extends Logging { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index e73515ff29377..b14f8ed2e33e0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.rules -import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.Logging import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.sideBySide @@ -60,7 +60,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { case (plan, rule) => val result = rule(plan) if (!result.fastEquals(plan)) { - logTrace( + logger.trace( s""" |=== Applying Rule ${rule.ruleName} === |${sideBySide(plan.treeString, result.treeString).mkString("\n")} @@ -71,25 +71,25 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { } iteration += 1 if (iteration > batch.strategy.maxIterations) { - logInfo(s"Max iterations ($iteration) reached for batch ${batch.name}") + logger.info(s"Max iterations ($iteration) reached for batch ${batch.name}") continue = false } if (curPlan.fastEquals(lastPlan)) { - logTrace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") + logger.trace(s"Fixed point reached for batch ${batch.name} after $iteration iterations.") continue = false } lastPlan = curPlan } if (!batchStartPlan.fastEquals(curPlan)) { - logDebug( + logger.debug( s""" |=== Result of Batch ${batch.name} === |${sideBySide(plan.treeString, curPlan.treeString).mkString("\n")} """.stripMargin) } else { - logTrace(s"Batch ${batch.name} has no effect.") + logger.trace(s"Batch ${batch.name} has no effect.") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index afc8e786aaac7..c006f82e9031f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -359,7 +359,8 @@ case class MapType( valueContainsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"${prefix}-- key: ${keyType.simpleString}\n") - builder.append(s"${prefix}-- value: 
${valueType.simpleString}\n") + builder.append(s"${prefix}-- value: ${valueType.simpleString} " + + s"(valueContainsNull = ${valueContainsNull})\n") DataType.buildFormattedString(keyType, s"$prefix |", builder) DataType.buildFormattedString(valueType, s"$prefix |", builder) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java index d2270d4b6ff9c..94936e2e4ee7a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/MapType.java @@ -24,6 +24,7 @@ * The field of {@code valueType} is used to specify the type of values in the map. * The field of {@code valueContainsNull} is used to specify if map values have * {@code null} values. + * For values of a MapType column, keys are not allowed to have {@code null} values. * * To create a {@link MapType}, * {@link org.apache.spark.sql.api.java.types.DataType#createMapType(DataType, DataType)} or diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala index b9f6c6fbe9a35..3325782604b36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/JavaSQLContext.scala @@ -137,7 +137,7 @@ class JavaSQLContext(val sqlContext: SQLContext) { /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [JavaSchemaRDD. + * JavaSchemaRDD. * It goes through the entire dataset once to determine the schema. */ def jsonRDD(json: JavaRDD[String]): JavaSchemaRDD = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala index a87d6f25f6130..6c67934bda5b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala @@ -110,6 +110,8 @@ class Row(private[spark] val row: ScalaRow) extends Serializable { object Row { private def toJavaValue(value: Any): Any = value match { + // For values of this ScalaRow, we will do the conversion when + // they are actually accessed. case row: ScalaRow => new Row(row) case map: scala.collection.Map[_, _] => JavaConversions.mapAsJavaMap( @@ -125,6 +127,7 @@ object Row { // TODO: Consolidate the toScalaValue at here with the scalafy in JsonRDD? private def toScalaValue(value: Any): Any = value match { + // Values of this row have been converted to Scala values. 
case row: Row => row.row case map: java.util.Map[_, _] => JMapWrapper(map).map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 67d70d599c3f7..00010ef6e798a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -99,7 +99,7 @@ private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPl !operator.requiredChildDistribution.zip(operator.children).map { case (required, child) => val valid = child.outputPartitioning.satisfies(required) - logDebug( + logger.debug( s"${if (valid) "Valid" else "Invalid"} distribution," + s"required: $required current: ${child.outputPartitioning}") valid diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index e7732fd7d8336..0995a4eb6299f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -277,6 +277,7 @@ package object sql { * The field of `keyType` is used to specify the type of keys in the map. * The field of `valueType` is used to specify the type of values in the map. * The field of `valueContainsNull` is used to specify if values of this map has `null` values. + * For values of a MapType column, keys are not allowed to have `null` values. * * @group dataType */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c07753c40b656..7c6e7b3abd4b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.test._ /* Implicits */ @@ -448,13 +446,13 @@ class SQLQuerySuite extends QueryTest { } test("apply schema") { - val schema = StructType( + val schema1 = StructType( StructField("f1", IntegerType, false) :: StructField("f2", StringType, false) :: StructField("f3", BooleanType, false) :: StructField("f4", IntegerType, true) :: Nil) - val rowRDD = unparsedStrings.map { r => + val rowRDD1 = unparsedStrings.map { r => val values = r.split(",").map(_.trim) val v4 = try values(3).toInt catch { case _: NumberFormatException => null @@ -462,17 +460,47 @@ class SQLQuerySuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerAsTable("applySchema") + val schemaRDD1 = applySchema(rowRDD1, schema1) + schemaRDD1.registerAsTable("applySchema1") checkAnswer( - sql("SELECT * FROM applySchema"), + sql("SELECT * FROM applySchema1"), (1, "A1", true, null) :: (2, "B2", false, null) :: (3, "C3", true, null) :: (4, "D4", true, 2147483644) :: Nil) checkAnswer( - sql("SELECT f1, f4 FROM applySchema"), + sql("SELECT f1, f4 FROM applySchema1"), + (1, null) :: + (2, null) :: + (3, null) :: + (4, 2147483644) :: Nil) + + val schema2 = StructType( + StructField("f1", StructType( + StructField("f11", IntegerType, false) :: + StructField("f12", BooleanType, false) :: Nil), false) :: + StructField("f2", MapType(StringType, IntegerType, 
true), false) :: Nil) + + val rowRDD2 = unparsedStrings.map { r => + val values = r.split(",").map(_.trim) + val v4 = try values(3).toInt catch { + case _: NumberFormatException => null + } + Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) + } + + val schemaRDD2 = applySchema(rowRDD2, schema2) + schemaRDD2.registerAsTable("applySchema2") + checkAnswer( + sql("SELECT * FROM applySchema2"), + (Seq(1, true), Map("A1" -> null)) :: + (Seq(2, false), Map("B2" -> null)) :: + (Seq(3, true), Map("C3" -> null)) :: + (Seq(4, true), Map("D4" -> 2147483644)) :: Nil) + + checkAnswer( + sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), (1, null) :: (2, null) :: (3, null) :: From 122d1e7f4b664cd460c1bc9da72025b6494aaa07 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 29 Jul 2014 13:03:22 -0700 Subject: [PATCH 31/34] Address comments. --- python/pyspark/sql.py | 54 +++++++++++++------ .../spark/sql/api/java/types/DataType.java | 22 -------- .../org/apache/spark/sql/SQLContext.scala | 20 +++++++ .../org/apache/spark/sql/json/JsonRDD.scala | 7 ++- .../spark/sql/parquet/ParquetTypes.scala | 4 +- .../sql/types/util/DataTypeConversions.scala | 44 ++++++--------- 6 files changed, 78 insertions(+), 73 deletions(-) diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 0569f3a62c244..2c62156102986 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -26,13 +26,16 @@ "ShortType", "ArrayType", "MapType", "StructField", "StructType", "SQLContext", "HiveContext", "LocalHiveContext", "TestHiveContext", "SchemaRDD", "Row"] + class PrimitiveTypeSingleton(type): _instances = {} + def __call__(cls): if cls not in cls._instances: cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() return cls._instances[cls] + class StringType(object): """Spark SQL StringType @@ -44,6 +47,7 @@ class StringType(object): def __repr__(self): return "StringType" + class BinaryType(object): """Spark SQL BinaryType @@ -55,6 +59,7 @@ class BinaryType(object): def __repr__(self): return "BinaryType" + class BooleanType(object): """Spark SQL BooleanType @@ -66,6 +71,7 @@ class BooleanType(object): def __repr__(self): return "BooleanType" + class TimestampType(object): """Spark SQL TimestampType @@ -77,6 +83,7 @@ class TimestampType(object): def __repr__(self): return "TimestampType" + class DecimalType(object): """Spark SQL DecimalType @@ -88,6 +95,7 @@ class DecimalType(object): def __repr__(self): return "DecimalType" + class DoubleType(object): """Spark SQL DoubleType @@ -99,13 +107,15 @@ class DoubleType(object): def __repr__(self): return "DoubleType" + class FloatType(object): """Spark SQL FloatType For now, please use L{DoubleType} instead of using L{FloatType}. Because query evaluation is done in Scala, java.lang.Double will be be used for Python float numbers. Because the underlying JVM type of FloatType is - java.lang.Float (in Java) and Float (in scala), there will be a java.lang.ClassCastException + java.lang.Float (in Java) and Float (in scala), and we are trying to cast the type, + there will be a java.lang.ClassCastException if FloatType (Python) is used. """ @@ -114,13 +124,15 @@ class FloatType(object): def __repr__(self): return "FloatType" + class ByteType(object): """Spark SQL ByteType For now, please use L{IntegerType} instead of using L{ByteType}. Because query evaluation is done in Scala, java.lang.Integer will be be used for Python int numbers. 
Because the underlying JVM type of ByteType is - java.lang.Byte (in Java) and Byte (in scala), there will be a java.lang.ClassCastException + java.lang.Byte (in Java) and Byte (in scala), and we are trying to cast the type, + there will be a java.lang.ClassCastException if ByteType (Python) is used. """ @@ -129,6 +141,7 @@ class ByteType(object): def __repr__(self): return "ByteType" + class IntegerType(object): """Spark SQL IntegerType @@ -140,6 +153,7 @@ class IntegerType(object): def __repr__(self): return "IntegerType" + class LongType(object): """Spark SQL LongType @@ -152,13 +166,15 @@ class LongType(object): def __repr__(self): return "LongType" + class ShortType(object): """Spark SQL ShortType For now, please use L{IntegerType} instead of using L{ShortType}. Because query evaluation is done in Scala, java.lang.Integer will be be used for Python int numbers. Because the underlying JVM type of ShortType is - java.lang.Short (in Java) and Short (in scala), there will be a java.lang.ClassCastException + java.lang.Short (in Java) and Short (in scala), and we are trying to cast the type, + there will be a java.lang.ClassCastException if ShortType (Python) is used. """ @@ -167,6 +183,7 @@ class ShortType(object): def __repr__(self): return "ShortType" + class ArrayType(object): """Spark SQL ArrayType @@ -196,9 +213,9 @@ def __repr__(self): str(self.containsNull).lower() + ")" def __eq__(self, other): - return (isinstance(other, self.__class__) and \ - self.elementType == other.elementType and \ - self.containsNull == other.containsNull) + return (isinstance(other, self.__class__) and + self.elementType == other.elementType and + self.containsNull == other.containsNull) def __ne__(self, other): return not self.__eq__(other) @@ -238,14 +255,15 @@ def __repr__(self): str(self.valueContainsNull).lower() + ")" def __eq__(self, other): - return (isinstance(other, self.__class__) and \ - self.keyType == other.keyType and \ - self.valueType == other.valueType and \ - self.valueContainsNull == other.valueContainsNull) + return (isinstance(other, self.__class__) and + self.keyType == other.keyType and + self.valueType == other.valueType and + self.valueContainsNull == other.valueContainsNull) def __ne__(self, other): return not self.__eq__(other) + class StructField(object): """Spark SQL StructField @@ -278,14 +296,15 @@ def __repr__(self): str(self.nullable).lower() + ")" def __eq__(self, other): - return (isinstance(other, self.__class__) and \ - self.name == other.name and \ - self.dataType == other.dataType and \ - self.nullable == other.nullable) + return (isinstance(other, self.__class__) and + self.name == other.name and + self.dataType == other.dataType and + self.nullable == other.nullable) def __ne__(self, other): return not self.__eq__(other) + class StructType(object): """Spark SQL StructType @@ -315,12 +334,13 @@ def __repr__(self): ",".join([field.__repr__() for field in self.fields]) + "))" def __eq__(self, other): - return (isinstance(other, self.__class__) and \ - self.fields == other.fields) + return (isinstance(other, self.__class__) and + self.fields == other.fields) def __ne__(self, other): return not self.__eq__(other) + def _parse_datatype_list(datatype_list_string): """Parses a list of comma separated data types. @@ -348,6 +368,7 @@ def _parse_datatype_list(datatype_list_string): datatype_list.append(_parse_datatype_string(datatype_string)) return datatype_list + def _parse_datatype_string(datatype_string): """Parses the given data type string. 
@@ -472,6 +493,7 @@ def _parse_datatype_string(datatype_string): fields = _parse_datatype_list(field_list_string) return StructType(fields) + class SQLContext: """Main entry point for SparkSQL functionality. diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java index 6fd04aa2c6c9c..f84e5a490a905 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/types/DataType.java @@ -87,9 +87,6 @@ public abstract class DataType { /** * Creates an ArrayType by specifying the data type of elements ({@code elementType}). * The field of {@code containsNull} is set to {@code false}. - * - * @param elementType - * @return */ public static ArrayType createArrayType(DataType elementType) { if (elementType == null) { @@ -102,9 +99,6 @@ public static ArrayType createArrayType(DataType elementType) { /** * Creates an ArrayType by specifying the data type of elements ({@code elementType}) and * whether the array contains null values ({@code containsNull}). - * @param elementType - * @param containsNull - * @return */ public static ArrayType createArrayType(DataType elementType, boolean containsNull) { if (elementType == null) { @@ -117,10 +111,6 @@ public static ArrayType createArrayType(DataType elementType, boolean containsNu /** * Creates a MapType by specifying the data type of keys ({@code keyType}) and values * ({@code keyType}). The field of {@code valueContainsNull} is set to {@code true}. - * - * @param keyType - * @param valueType - * @return */ public static MapType createMapType(DataType keyType, DataType valueType) { if (keyType == null) { @@ -137,10 +127,6 @@ public static MapType createMapType(DataType keyType, DataType valueType) { * Creates a MapType by specifying the data type of keys ({@code keyType}), the data type of * values ({@code keyType}), and whether values contain any null value * ({@code valueContainsNull}). - * @param keyType - * @param valueType - * @param valueContainsNull - * @return */ public static MapType createMapType( DataType keyType, @@ -159,10 +145,6 @@ public static MapType createMapType( /** * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}) and * whether values of this field can be null values ({@code nullable}). - * @param name - * @param dataType - * @param nullable - * @return */ public static StructField createStructField(String name, DataType dataType, boolean nullable) { if (name == null) { @@ -177,8 +159,6 @@ public static StructField createStructField(String name, DataType dataType, bool /** * Creates a StructType with the given list of StructFields ({@code fields}). - * @param fields - * @return */ public static StructType createStructType(List fields) { return createStructType(fields.toArray(new StructField[0])); @@ -186,8 +166,6 @@ public static StructType createStructType(List fields) { /** * Creates a StructType with the given StructField array ({@code fields}). 
- * @param fields - * @return */ public static StructType createStructType(StructField[] fields) { if (fields == null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index f93839c66c5d8..79293a70896f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -94,6 +94,26 @@ class SQLContext(@transient val sparkContext: SparkContext) * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. + * Example: + * {{{ + * import org.apache.spark.sql._ + * val sqlContext = new org.apache.spark.sql.SQLContext(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = sc.textFile("examples/src/main/resources/people.txt").map(_.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val peopleSchemaRDD = sqlContext. applySchema(people, schema) + * peopleSchemaRDD.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * peopleSchemaRDD.registerAsTable("people") + * sqlContext.sql("select name from people").collect.foreach(println) + * }}} * * @group userf */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 0c12f74658ec5..bd29ee421bbc4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -268,10 +268,9 @@ private[sql] object JsonRDD extends Logging { // the ObjectMapper will take the last value associated with this duplicate key. // For example: for {"key": 1, "key":2}, we will get "key"->2. val mapper = new ObjectMapper() - iter.map { - record => - val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) - parsed.asInstanceOf[Map[String, Any]] + iter.map { record => + val parsed = scalafy(mapper.readValue(record, classOf[java.util.Map[String, Any]])) + parsed.asInstanceOf[Map[String, Any]] } }) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 46f5d81a755ce..aaef1a1d474fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -116,7 +116,7 @@ private[parquet] object ParquetTypesConverter extends Logging { case ParquetOriginalType.LIST => { // TODO: check enums! 
assert(groupType.getFieldCount == 1) val field = groupType.getFields.apply(0) - ArrayType(toDataType(field), false) + ArrayType(toDataType(field), containsNull = false) } case ParquetOriginalType.MAP => { assert( @@ -147,7 +147,7 @@ private[parquet] object ParquetTypesConverter extends Logging { MapType(keyType, valueType) } else if (correspondsToArray(groupType)) { // ArrayType val elementType = toDataType(groupType.getFields.apply(0)) - ArrayType(elementType, false) + ArrayType(elementType, containsNull = false) } else { // everything else: StructType val fields = groupType .getFields diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index a51383a431cf5..d1aa3c8d53757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -28,7 +28,7 @@ protected[sql] object DataTypeConversions { * Returns the equivalent StructField in Scala for the given StructField in Java. */ def asJavaStructField(scalaStructField: StructField): JStructField = { - org.apache.spark.sql.api.java.types.DataType.createStructField( + JDataType.createStructField( scalaStructField.name, asJavaDataType(scalaStructField.dataType), scalaStructField.nullable) @@ -38,39 +38,25 @@ protected[sql] object DataTypeConversions { * Returns the equivalent DataType in Java for the given DataType in Scala. */ def asJavaDataType(scalaDataType: DataType): JDataType = scalaDataType match { - case StringType => - org.apache.spark.sql.api.java.types.DataType.StringType - case BinaryType => - org.apache.spark.sql.api.java.types.DataType.BinaryType - case BooleanType => - org.apache.spark.sql.api.java.types.DataType.BooleanType - case TimestampType => - org.apache.spark.sql.api.java.types.DataType.TimestampType - case DecimalType => - org.apache.spark.sql.api.java.types.DataType.DecimalType - case DoubleType => - org.apache.spark.sql.api.java.types.DataType.DoubleType - case FloatType => - org.apache.spark.sql.api.java.types.DataType.FloatType - case ByteType => - org.apache.spark.sql.api.java.types.DataType.ByteType - case IntegerType => - org.apache.spark.sql.api.java.types.DataType.IntegerType - case LongType => - org.apache.spark.sql.api.java.types.DataType.LongType - case ShortType => - org.apache.spark.sql.api.java.types.DataType.ShortType + case StringType => JDataType.StringType + case BinaryType => JDataType.BinaryType + case BooleanType => JDataType.BooleanType + case TimestampType => JDataType.TimestampType + case DecimalType => JDataType.DecimalType + case DoubleType => JDataType.DoubleType + case FloatType => JDataType.FloatType + case ByteType => JDataType.ByteType + case IntegerType => JDataType.IntegerType + case LongType => JDataType.LongType + case ShortType => JDataType.ShortType - case arrayType: ArrayType => - org.apache.spark.sql.api.java.types.DataType.createArrayType( + case arrayType: ArrayType => JDataType.createArrayType( asJavaDataType(arrayType.elementType), arrayType.containsNull) - case mapType: MapType => - org.apache.spark.sql.api.java.types.DataType.createMapType( + case mapType: MapType => JDataType.createMapType( asJavaDataType(mapType.keyType), asJavaDataType(mapType.valueType), mapType.valueContainsNull) - case structType: StructType => - org.apache.spark.sql.api.java.types.DataType.createStructType( + case structType: StructType => 
JDataType.createStructType( structType.fields.map(asJavaStructField).asJava) } From e5f8df5d8b576ae1d34e905a26e27a6756f838d7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 29 Jul 2014 13:17:24 -0700 Subject: [PATCH 32/34] Scaladoc. --- .../catalyst/expressions/WrapDynamic.scala | 11 +++++++++++ .../spark/sql/catalyst/types/dataTypes.scala | 19 +++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala index c7f8e383ec868..eb8898900d6a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WrapDynamic.scala @@ -21,10 +21,16 @@ import scala.language.dynamics import org.apache.spark.sql.catalyst.types.DataType +/** + * The data type representing [[DynamicRow]] values. + */ case object DynamicType extends DataType { def simpleString: String = "dynamic" } +/** + * Wrap a [[Row]] as a [[DynamicRow]]. + */ case class WrapDynamic(children: Seq[Attribute]) extends Expression { type EvaluatedType = DynamicRow @@ -39,6 +45,11 @@ case class WrapDynamic(children: Seq[Attribute]) extends Expression { } } +/** + * DynamicRows use scala's Dynamic trait to emulate an ORM of in a dynamically typed language. + * Since the type of the column is not known at compile time, all attributes are converted to + * strings before being passed to the function. + */ class DynamicRow(val schema: Seq[Attribute], values: Array[Any]) extends GenericRow(values) with Dynamic { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index c006f82e9031f..e6eb5a0744d16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -266,6 +266,13 @@ object ArrayType { def apply(elementType: DataType): ArrayType = ArrayType(elementType, false) } +/** + * The data type for collections of multiple values. + * Internally these are represented as columns that contain a ``scala.collection.Seq``. + * + * @param elementType The data type of values. + * @param containsNull Indicates if values have `null` values + */ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataType { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append( @@ -276,6 +283,12 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT def simpleString: String = "array" } +/** + * A field inside a StructType. + * @param name The name of this field. + * @param dataType The data type of this field. + * @param nullable Indicates if values of this field can be `null` values. + */ case class StructField(name: String, dataType: DataType, nullable: Boolean) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { @@ -353,6 +366,12 @@ object MapType { MapType(keyType: DataType, valueType: DataType, true) } +/** + * The data type for Maps. Keys in a map are not allowed to have `null` values. + * @param keyType The data type of map keys. + * @param valueType The data type of map values. + * @param valueContainsNull Indicates if map values have `null` values. 
+ */ case class MapType( keyType: DataType, valueType: DataType, From c712fbf998a0eefbb1acbe433cc3e7dcd7fec6d7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 29 Jul 2014 22:18:07 -0700 Subject: [PATCH 33/34] Converts types of values based on defined schema. --- .../apache/spark/api/python/PythonRDD.scala | 3 +- python/pyspark/sql.py | 58 +++---- .../spark/sql/catalyst/types/dataTypes.scala | 5 +- .../org/apache/spark/sql/SQLContext.scala | 160 +++++++++++------- .../org/apache/spark/sql/SchemaRDD.scala | 4 +- 5 files changed, 134 insertions(+), 96 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0d8453fb184a3..f551a59ee3fe8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -544,7 +544,8 @@ private[spark] object PythonRDD extends Logging { } /** - * Convert an RDD of serialized Python dictionaries to Scala Maps + * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions). + * It is only used by pyspark.sql. * TODO: Support more Python types. */ def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = { diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 3f2f4dad49a83..13f0ed4e35490 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -111,12 +111,7 @@ def __repr__(self): class FloatType(object): """Spark SQL FloatType - For now, please use L{DoubleType} instead of using L{FloatType}. - Because query evaluation is done in Scala, java.lang.Double will be be used - for Python float numbers. Because the underlying JVM type of FloatType is - java.lang.Float (in Java) and Float (in scala), and we are trying to cast the type, - there will be a java.lang.ClassCastException - if FloatType (Python) is used. + The data type representing single precision floating-point values. """ __metaclass__ = PrimitiveTypeSingleton @@ -128,12 +123,7 @@ def __repr__(self): class ByteType(object): """Spark SQL ByteType - For now, please use L{IntegerType} instead of using L{ByteType}. - Because query evaluation is done in Scala, java.lang.Integer will be be used - for Python int numbers. Because the underlying JVM type of ByteType is - java.lang.Byte (in Java) and Byte (in scala), and we are trying to cast the type, - there will be a java.lang.ClassCastException - if ByteType (Python) is used. + The data type representing int values with 1 singed byte. """ __metaclass__ = PrimitiveTypeSingleton @@ -170,12 +160,7 @@ def __repr__(self): class ShortType(object): """Spark SQL ShortType - For now, please use L{IntegerType} instead of using L{ShortType}. - Because query evaluation is done in Scala, java.lang.Integer will be be used - for Python int numbers. Because the underlying JVM type of ShortType is - java.lang.Short (in Java) and Short (in scala), and we are trying to cast the type, - there will be a java.lang.ClassCastException - if ShortType (Python) is used. + The data type representing int values with 2 signed bytes. """ __metaclass__ = PrimitiveTypeSingleton @@ -198,7 +183,6 @@ def __init__(self, elementType, containsNull=False): :param elementType: the data type of elements. :param containsNull: indicates whether the list contains None values. 
- :return: >>> ArrayType(StringType) == ArrayType(StringType, False) True @@ -238,7 +222,6 @@ def __init__(self, keyType, valueType, valueContainsNull=True): :param keyType: the data type of keys. :param valueType: the data type of values. :param valueContainsNull: indicates whether values contains null values. - :return: >>> MapType(StringType, IntegerType) == MapType(StringType, IntegerType, True) True @@ -279,7 +262,6 @@ def __init__(self, name, dataType, nullable): :param name: the name of this field. :param dataType: the data type of this field. :param nullable: indicates whether values of this field can be null. - :return: >>> StructField("f1", StringType, True) == StructField("f1", StringType, True) True @@ -314,8 +296,6 @@ class StructType(object): """ def __init__(self, fields): """Creates a StructType - :param fields: - :return: >>> struct1 = StructType([StructField("f1", StringType, True)]) >>> struct2 = StructType([StructField("f1", StringType, True)]) @@ -342,11 +322,7 @@ def __ne__(self, other): def _parse_datatype_list(datatype_list_string): - """Parses a list of comma separated data types. - - :param datatype_list_string: - :return: - """ + """Parses a list of comma separated data types.""" index = 0 datatype_list = [] start = 0 @@ -372,9 +348,6 @@ def _parse_datatype_list(datatype_list_string): def _parse_datatype_string(datatype_string): """Parses the given data type string. - :param datatype_string: - :return: - >>> def check_datatype(datatype): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.__repr__()) ... python_datatype = _parse_datatype_string(scala_datatype.toString()) @@ -582,9 +555,6 @@ def inferSchema(self, rdd): def applySchema(self, rdd, schema): """Applies the given schema to the given RDD of L{dict}s. - :param rdd: - :param schema: - :return: >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) @@ -594,9 +564,27 @@ def applySchema(self, rdd, schema): >>> srdd2.collect() == [{"field1" : 1, "field2" : "row1"}, {"field1" : 2, "field2": "row2"}, ... {"field1" : 3, "field2": "row3"}] True + >>> from datetime import datetime + >>> rdd = sc.parallelize([{"byte": 127, "short": -32768, "float": 1.0, + ... "time": datetime(2010, 1, 1, 1, 1, 1), "map": {"a": 1}, "struct": {"b": 2}, + ... "list": [1, 2, 3]}]) + >>> schema = StructType([ + ... StructField("byte", ByteType(), False), + ... StructField("short", ShortType(), False), + ... StructField("float", FloatType(), False), + ... StructField("time", TimestampType(), False), + ... StructField("map", MapType(StringType(), IntegerType(), False), False), + ... StructField("struct", StructType([StructField("b", ShortType(), False)]), False), + ... StructField("list", ArrayType(ByteType(), False), False), + ... StructField("null", DoubleType(), True)]) + >>> srdd = sqlCtx.applySchema(rdd, schema).map( + ... lambda x: ( + ... 
x.byte, x.short, x.float, x.time, x.map["a"], x.struct["b"], x.list, x.null)) + >>> srdd.collect()[0] + (127, -32768, 1.0, datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) """ jrdd = self._pythonToJavaMap(rdd._jrdd) - srdd = self._ssql_ctx.applySchema(jrdd.rdd(), schema.__repr__()) + srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.__repr__()) return SchemaRDD(srdd, self) def registerRDDAsTable(self, rdd, tableName): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index e6eb5a0744d16..ea7120022c51d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -313,13 +313,13 @@ case class StructType(fields: Seq[StructField]) extends DataType { */ lazy val fieldNames: Seq[String] = fields.map(_.name) private lazy val fieldNamesSet: Set[String] = fieldNames.toSet - + private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a name matching the given name, `null` will be returned. */ def apply(name: String): StructField = { - fields.find(f => f.name == name).getOrElse( + nameToField.get(name).getOrElse( throw new IllegalArgumentException(s"Field ${name} does not exist.")) } @@ -333,6 +333,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { throw new IllegalArgumentException( s"Field ${nonExistFields.mkString(",")} does not exist.") } + // Preserve the original order of fields. StructType(fields.filter(f => names.contains(f.name))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index cc2bf7059ca7a..61aa0882c476a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -125,29 +125,6 @@ class SQLContext(@transient val sparkContext: SparkContext) new SchemaRDD(this, logicalPlan) } - /** - * Parses the data type in our internal string representation. The data type string should - * have the same format as the one generated by `toString` in scala. - * It is only used by PySpark. - */ - private[sql] def parseDataType(dataTypeString: String): DataType = { - val parser = org.apache.spark.sql.catalyst.types.DataType - parser(dataTypeString) - } - - /** - * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. - */ - private[sql] def applySchema(rdd: RDD[Map[String, _]], schemaString: String): SchemaRDD = { - val schema = parseDataType(schemaString).asInstanceOf[StructType] - val rowRdd = rdd.mapPartitions { iter => - iter.map { map => - new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row - } - } - new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))) - } - /** * Loads a Parquet file, returning the result as a [[SchemaRDD]]. 
* @@ -438,6 +415,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ private[sql] def inferSchema(rdd: RDD[Map[String, _]]): SchemaRDD = { import scala.collection.JavaConversions._ + def typeOfComplexValue: PartialFunction[Any, DataType] = { case c: java.util.Calendar => TimestampType case c: java.util.List[_] => @@ -453,48 +431,116 @@ class SQLContext(@transient val sparkContext: SparkContext) def typeOfObject = ScalaReflection.typeOfObject orElse typeOfComplexValue val firstRow = rdd.first() - val schema = StructType( - firstRow.map { case (fieldName, obj) => - StructField(fieldName, typeOfObject(obj), true) - }.toSeq) - - def needTransform(obj: Any): Boolean = obj match { - case c: java.util.List[_] => true - case c: java.util.Map[_, _] => true - case c if c.getClass.isArray => true - case c: java.util.Calendar => true - case c => false + val fields = firstRow.map { + case (fieldName, obj) => StructField(fieldName, typeOfObject(obj), true) + }.toSeq + + applySchemaToPythonRDD(rdd, StructType(fields)) + } + + /** + * Parses the data type in our internal string representation. The data type string should + * have the same format as the one generated by `toString` in scala. + * It is only used by PySpark. + */ + private[sql] def parseDataType(dataTypeString: String): DataType = { + val parser = org.apache.spark.sql.catalyst.types.DataType + parser(dataTypeString) + } + + /** + * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. + */ + private[sql] def applySchemaToPythonRDD( + rdd: RDD[Map[String, _]], + schemaString: String): SchemaRDD = { + val schema = parseDataType(schemaString).asInstanceOf[StructType] + applySchemaToPythonRDD(rdd, schema) + } + + /** + * Apply a schema defined by the schema to an RDD. It is only used by PySpark. + */ + private[sql] def applySchemaToPythonRDD( + rdd: RDD[Map[String, _]], + schema: StructType): SchemaRDD = { + import scala.collection.JavaConversions._ + import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper} + + def needsConversion(dataType: DataType): Boolean = dataType match { + case ByteType => true + case ShortType => true + case FloatType => true + case TimestampType => true + case ArrayType(_, _) => true + case MapType(_, _, _) => true + case StructType(_) => true + case other => false } - // convert JList, JArray into Seq, convert JMap into Map - // convert Calendar into Timestamp - def transform(obj: Any): Any = obj match { - case c: java.util.List[_] => c.map(transform).toSeq - case c: java.util.Map[_, _] => c.map { - case (key, value) => (key, transform(value)) - }.toMap - case c if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(transform).toSeq - case c: java.util.Calendar => - new java.sql.Timestamp(c.getTime().getTime()) - case c => c + // Converts value to the type specified by the data type. + // Because Python does not have data types for TimestampType, FloatType, ShortType, and + // ByteType, we need to explicitly convert values in columns of these data types to the desired + // JVM data types. 
+ def convert(obj: Any, dataType: DataType): Any = (obj, dataType) match { + // TODO: We should check nullable + case (null, _) => null + + case (c: java.util.List[_], ArrayType(elementType, _)) => + val converted = c.map { e => convert(e, elementType)} + JListWrapper(converted) + + case (c: java.util.Map[_, _], struct: StructType) => + val row = new GenericMutableRow(struct.fields.length) + struct.fields.zipWithIndex.foreach { + case (field, i) => + val value = convert(c.get(field.name), field.dataType) + row.update(i, value) + } + row + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => + val converted = c.map { + case (key, value) => + (convert(key, keyType), convert(value, valueType)) + } + JMapWrapper(converted) + + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + val converted = c.asInstanceOf[Array[_]].map(e => convert(e, elementType)) + converted: Seq[Any] + + case (c: java.util.Calendar, TimestampType) => new java.sql.Timestamp(c.getTime().getTime()) + case (c: Int, ByteType) => c.toByte + case (c: Int, ShortType) => c.toShort + case (c: Double, FloatType) => c.toFloat + + case (c, _) => c + } + + val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) { + rdd.map(m => m.map { case (key, value) => (key, convert(value, schema(key).dataType)) }) + } else { + rdd } - val need = firstRow.exists { case (key, value) => needTransform(value) } - val transformed = if (need) { - rdd.mapPartitions { iter => - iter.map { - m => m.map {case (key, value) => (key, transform(value))} + val rowRdd = convertedRdd.mapPartitions { iter => + val row = new GenericMutableRow(schema.fields.length) + val fieldsWithIndex = schema.fields.zipWithIndex + iter.map { m => + // We cannot use m.values because the order of values returned by m.values may not + // match fields order. + fieldsWithIndex.foreach { + case (field, i) => + val value = + m.get(field.name).flatMap(v => Option(v)).map(v => convert(v, field.dataType)).orNull + row.update(i, value) } - } - } else rdd - val rowRdd = transformed.mapPartitions { iter => - iter.map { map => - new GenericRow(map.values.toArray.asInstanceOf[Array[Any]]): Row + row: Row } } + new SchemaRDD(this, SparkLogicalPlan(ExistingRdd(schema.toAttributes, rowRdd))) } - } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala index 0940300a72983..2a79abb92d247 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.{Map => JMap, List => JList, Set => JSet} +import java.util.{Map => JMap, List => JList} import scala.collection.JavaConversions._ import scala.collection.JavaConverters._ @@ -380,6 +380,8 @@ class SchemaRDD( * Converts a JavaRDD to a PythonRDD. It is used by pyspark. */ private[sql] def javaToPython: JavaRDD[Array[Byte]] = { + import scala.collection.Map + def toJava(obj: Any, dataType: DataType): Any = dataType match { case struct: StructType => rowToMap(obj.asInstanceOf[Row], struct) case array: ArrayType => obj match { From 1d45977bd69e5569444a6086562e5528b886123b Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 29 Jul 2014 22:46:36 -0700 Subject: [PATCH 34/34] Clean up. 
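The value conversion added in the previous patch is worth keeping in mind while reading this clean-up: records coming from Python arrive as boxed Java objects, and applySchemaToPythonRDD coerces each value to the JVM type its field's data type requires. A condensed sketch of the primitive cases of that dispatch (names simplified for illustration; the recursive ArrayType, MapType, and StructType cases from the previous patch are omitted):

{{{
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.types._

// Simplified restatement of the per-value conversion; not the full implementation.
def convertPrimitive(obj: Any, dataType: DataType): Any = (obj, dataType) match {
  case (null, _)                              => null
  case (c: java.util.Calendar, TimestampType) => new Timestamp(c.getTime().getTime())
  case (c: Int, ByteType)                     => c.toByte   // Python ints arrive as java.lang.Integer
  case (c: Int, ShortType)                    => c.toShort
  case (c: Double, FloatType)                 => c.toFloat  // Python floats arrive as java.lang.Double
  case (c, _)                                 => c
}
}}}

When no field of the schema needs such a conversion, the incoming RDD is used as-is, so the common case stays cheap.
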
--- .../scala/org/apache/spark/sql/SQLContext.scala | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 662e920463e8d..86338752a21c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -24,15 +24,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental} import org.apache.spark.rdd.RDD -<<<<<<< HEAD -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.dsl.ExpressionConversions -import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.ScalaReflection -======= import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.dsl.ExpressionConversions @@ -40,8 +31,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types._ ->>>>>>> upstream/master import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.SparkStrategies @@ -477,6 +466,8 @@ class SQLContext(@transient val sparkContext: SparkContext) private[sql] def applySchemaToPythonRDD( rdd: RDD[Map[String, _]], schema: StructType): SchemaRDD = { + // TODO: We should have a better implementation once we do not turn a Python side record + // to a Map. import scala.collection.JavaConversions._ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}