Changing UDT to annotation
jkbradley committed Nov 2, 2014
1 parent fea04af commit b226b9e
Showing 8 changed files with 188 additions and 115 deletions.
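Summary: this commit replaces the explicit registration API (SQLContext.registerType and its udtRegistry map, removed below) with a class-level annotation. The catalyst DataType previously named UserDefinedType is renamed to UserDefinedTypeType, freeing the UserDefinedType name for a new @UserDefinedType annotation; ScalaReflection.schemaFor now detects the annotation and lazily registers the UDT in a global UDTRegistry. A minimal sketch of the new user-facing flow, with illustrative names (MyPoint, MyPointUDT) that are not part of this commit:

    import org.apache.spark.sql.catalyst.annotation.UserDefinedType

    // Annotating the class is now sufficient; no sqlContext.registerType call is needed.
    @UserDefinedType(udt = classOf[MyPointUDT])
    case class MyPoint(x: Double, y: Double)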
@@ -17,7 +17,7 @@

package org.apache.spark.annotation;

import java.lang.annotation.*;

/**
* A lower-level, unstable API intended for developers.
@@ -31,5 +31,5 @@
*/
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.FIELD, ElementType.METHOD, ElementType.PARAMETER,
    ElementType.CONSTRUCTOR, ElementType.LOCAL_VARIABLE, ElementType.PACKAGE})
public @interface DeveloperApi {}
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst

import java.sql.{Date, Timestamp}

+import org.apache.spark.sql.catalyst.annotation.UserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.catalyst.types._
@@ -35,7 +36,11 @@ object ScalaReflection {

case class Schema(dataType: DataType, nullable: Boolean)

-  /** Converts Scala objects to catalyst rows / types */
+  /**
+   * Converts Scala objects to catalyst rows / types.
+   * Note: This is always called after schemaFor has been called.
+   * This ordering is important for UDT registration.
+   */
def convertToCatalyst(a: Any, dataType: DataType): Any = (a, dataType) match {
case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType))
@@ -48,7 +53,7 @@ object ScalaReflection {
convertToCatalyst(elem, field.dataType)
}.toArray)
case (d: BigDecimal, _) => Decimal(d)
-    case (udt, udtType: UserDefinedType[_]) => udtType.serialize(udt)
+    case (udt, udtType: UserDefinedTypeType[_]) => udtType.serialize(udt)
case (other, _) => other
}

@@ -59,7 +64,7 @@ object ScalaReflection {
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
}
case (d: Decimal, DecimalType) => d.toBigDecimal
-    case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt)
+    case (udt: Row, udtType: UserDefinedTypeType[_]) => udtType.deserialize(udt)
case (other, _) => other
}
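The two UserDefinedTypeType cases above are the only places where serialization crosses the user-type/catalyst boundary. A hedged round-trip sketch, reusing the illustrative MyPoint/MyPointUDT from earlier and relying on a UDT instance being itself a DataType:

    val pointUDT = new MyPointUDT()  // UserDefinedTypeType[MyPoint] <: DataType
    val asRow = ScalaReflection.convertToCatalyst(MyPoint(1.0, 2.0), pointUDT)  // calls pointUDT.serialize
    val back  = ScalaReflection.convertToScala(asRow, pointUDT)                 // calls pointUDT.deserialize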

@@ -70,69 +75,69 @@
}

/** Returns a Sequence of attributes for the given case class type. */
-  def attributesFor[T: TypeTag](
-      udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Seq[Attribute] = {
-    schemaFor[T](udtRegistry) match {
+  def attributesFor[T: TypeTag]: Seq[Attribute] = {
+    schemaFor[T] match {
case Schema(s: StructType, _) =>
s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
}
}

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
-  def schemaFor[T: TypeTag](udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema =
-    schemaFor(typeOf[T], udtRegistry)
+  def schemaFor[T: TypeTag]: Schema = schemaFor(typeOf[T])

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
-  def schemaFor(
-      tpe: `Type`,
-      udtRegistry: scala.collection.Map[Any, UserDefinedType[_]]): Schema = tpe match {
-    case t if t <:< typeOf[Option[_]] =>
-      val TypeRef(_, _, Seq(optType)) = t
-      Schema(schemaFor(optType, udtRegistry).dataType, nullable = true)
-    case t if t <:< typeOf[Product] =>
-      val formalTypeArgs = t.typeSymbol.asClass.typeParams
-      val TypeRef(_, _, actualTypeArgs) = t
-      val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
-      Schema(StructType(
-        params.head.map { p =>
-          val Schema(dataType, nullable) =
-            schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs), udtRegistry)
-          StructField(p.name.toString, dataType, nullable)
-        }), nullable = true)
-    // Need to decide if we actually need a special type here.
-    case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
-    case t if t <:< typeOf[Array[_]] =>
-      sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
-    case t if t <:< typeOf[Seq[_]] =>
-      val TypeRef(_, _, Seq(elementType)) = t
-      val Schema(dataType, nullable) = schemaFor(elementType, udtRegistry)
-      Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
-    case t if t <:< typeOf[Map[_, _]] =>
-      val TypeRef(_, _, Seq(keyType, valueType)) = t
-      val Schema(valueDataType, valueNullable) = schemaFor(valueType, udtRegistry)
-      Schema(MapType(schemaFor(keyType, udtRegistry).dataType,
-        valueDataType, valueContainsNull = valueNullable), nullable = true)
-    case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
-    case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
-    case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
-    case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
-    case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
-    case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
-    case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
-    case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
-    case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
-    case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
-    case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
-    case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
-    case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
-    case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
-    case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
-    case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
-    case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
-    case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
-    case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
-    case t if udtRegistry.contains(t) =>
-      Schema(udtRegistry(t), nullable = true)
+  def schemaFor(tpe: `Type`): Schema = {
+    println(s"schemaFor: tpe = $tpe, tpe.getClass = ${tpe.getClass}, classOf[UserDefinedType] = ${classOf[UserDefinedType]}")
+    tpe match {
+      case t if t <:< typeOf[Option[_]] =>
+        val TypeRef(_, _, Seq(optType)) = t
+        Schema(schemaFor(optType).dataType, nullable = true)
+      case t if t <:< typeOf[Product] =>
+        val formalTypeArgs = t.typeSymbol.asClass.typeParams
+        val TypeRef(_, _, actualTypeArgs) = t
+        val params = t.member(nme.CONSTRUCTOR).asMethod.paramss
+        Schema(StructType(
+          params.head.map { p =>
+            val Schema(dataType, nullable) =
+              schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
+            StructField(p.name.toString, dataType, nullable)
+          }), nullable = true)
+      // Need to decide if we actually need a special type here.
+      case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
+      case t if t <:< typeOf[Array[_]] =>
+        sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+      case t if t <:< typeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(dataType, nullable) = schemaFor(elementType)
+        Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+      case t if t <:< typeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+        Schema(MapType(schemaFor(keyType).dataType,
+          valueDataType, valueContainsNull = valueNullable), nullable = true)
+      case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
+      case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
+      case t if t <:< typeOf[Date] => Schema(DateType, nullable = true)
+      case t if t <:< typeOf[BigDecimal] => Schema(DecimalType.Unlimited, nullable = true)
+      case t if t <:< typeOf[Decimal] => Schema(DecimalType.Unlimited, nullable = true)
+      case t if t <:< typeOf[java.lang.Integer] => Schema(IntegerType, nullable = true)
+      case t if t <:< typeOf[java.lang.Long] => Schema(LongType, nullable = true)
+      case t if t <:< typeOf[java.lang.Double] => Schema(DoubleType, nullable = true)
+      case t if t <:< typeOf[java.lang.Float] => Schema(FloatType, nullable = true)
+      case t if t <:< typeOf[java.lang.Short] => Schema(ShortType, nullable = true)
+      case t if t <:< typeOf[java.lang.Byte] => Schema(ByteType, nullable = true)
+      case t if t <:< typeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true)
+      case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false)
+      case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false)
+      case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false)
+      case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false)
+      case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
+      case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
+      case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+      case t if t.getClass.isAnnotationPresent(classOf[UserDefinedType]) =>
+        UDTRegistry.registerType(t)
+        Schema(UDTRegistry.udtRegistry(t), nullable = true)
+    }
+  }
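For reference, a sketch of what the rules above produce for a plain case class with no UDTs: a String field is nullable, a primitive Int is not, and the Product itself becomes a nullable StructType (Person is an illustrative name):

    case class Person(name: String, age: Int)

    // ScalaReflection.schemaFor[Person] ==
    //   Schema(
    //     StructType(Seq(
    //       StructField("name", StringType, nullable = true),
    //       StructField("age", IntegerType, nullable = false))),
    //     nullable = true)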

def typeOfObject: PartialFunction[Any, DataType] = {
@@ -166,7 +171,7 @@ object ScalaReflection {
def asRelation: LocalRelation = {
// Pass empty map to attributesFor since this method is only used for debugging Catalyst,
// not used with SparkSQL.
-    val output = attributesFor[A](Map.empty)
+    val output = attributesFor[A]
LocalRelation(output, data)
}
}
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.annotation.UserDefinedType

import scala.collection.mutable

import org.apache.spark.sql.catalyst.types.UserDefinedTypeType

import scala.reflect.runtime.universe._

/**
* Global registry for user-defined types (UDTs).
*/
private[sql] object UDTRegistry {
/** Map: UserType --> UserDefinedType */
val udtRegistry = new mutable.HashMap[Any, UserDefinedTypeType[_]]()

/**
* Register a user-defined type and its serializer, to allow automatic conversion between
* RDDs of user types and SchemaRDDs.
* Fails if this type has been registered already.
*/
/*
def registerType[UserType](implicit userType: TypeTag[UserType]): Unit = {
// TODO: Check to see if type is built-in. Throw exception?
val udt: UserDefinedTypeType[_] =
userType.getClass.getAnnotation(classOf[UserDefinedType]).udt().newInstance()
UDTRegistry.udtRegistry(userType.tpe) = udt
}*/

def registerType[UserType](implicit userType: Type): Unit = {
// TODO: Check to see if type is built-in. Throw exception?
if (!UDTRegistry.udtRegistry.contains(userType)) {
val udt: UserDefinedTypeType[_] =
userType.getClass.getAnnotation(classOf[UserDefinedType]).udt().newInstance()
UDTRegistry.udtRegistry(userType) = udt
}
// TODO: Else: Should we check (assert) that udt is the same as what is in the registry?
}

//def getUDT
}
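How schemaFor is expected to drive this registry, as a sketch; here tpe stands for the runtime-universe Type of a class carrying @UserDefinedType:

    UDTRegistry.registerType(tpe)  // idempotent: skips types already in the map
    val udt: UserDefinedTypeType[_] = UDTRegistry.udtRegistry(tpe)  // then used as a DataType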
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.annotation;

import org.apache.spark.sql.catalyst.types.UserDefinedTypeType;

import java.lang.annotation.*;

/**
* A user-defined type which can be automatically recognized by a SQLContext and registered.
*/
// TODO: Should I use @Documented?
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface UserDefinedType {
Class<? extends UserDefinedTypeType<?> > udt();
}
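Because the annotation has RUNTIME retention and a TYPE target, the UDT class it names can be recovered reflectively from the annotated class, which is what UDTRegistry.registerType relies on. A sketch with the illustrative MyPoint:

    val ann = classOf[MyPoint].getAnnotation(classOf[UserDefinedType])
    val udtInstance: UserDefinedTypeType[_] = ann.udt().newInstance()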
@@ -570,7 +570,7 @@ case class MapType(
/**
* The data type for User Defined Types.
*/
-abstract class UserDefinedType[UserType] extends DataType {
+abstract class UserDefinedTypeType[UserType] extends DataType {

/** Underlying storage type for this UDT used by SparkSQL */
def sqlType: DataType
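Only sqlType appears in this excerpt, so the following concrete subclass is a hedged sketch; the serialize/deserialize signatures are assumptions inferred from how convertToCatalyst and convertToScala call them in ScalaReflection above:

    class MyPointUDT extends UserDefinedTypeType[MyPoint] {
      // Underlying storage type for this UDT used by SparkSQL.
      override def sqlType: DataType = StructType(Seq(
        StructField("x", DoubleType, nullable = false),
        StructField("y", DoubleType, nullable = false)))

      // Assumed signature: convertToCatalyst passes the raw user object.
      def serialize(obj: Any): Row = obj match {
        case MyPoint(x, y) => new GenericRow(Array[Any](x, y))
      }

      // Assumed signature: convertToScala passes the stored Row back.
      def deserialize(row: Row): MyPoint = MyPoint(row.getDouble(0), row.getDouble(1))
    }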
23 changes: 2 additions & 21 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -17,9 +17,6 @@

package org.apache.spark.sql

-import org.apache.spark.sql.catalyst.types.UserDefinedType
-
-import scala.collection.mutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}

@@ -104,7 +101,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]) = {
SparkPlan.currentContext.set(self)
println(s"createSchemaRDD called")
-    val attributeSeq = ScalaReflection.attributesFor[A](udtRegistry)
+    val attributeSeq = ScalaReflection.attributesFor[A]
val schema = StructType.fromAttributes(attributeSeq)
val rowRDD = RDDConversions.productToRowRdd(rdd, schema)
println("done with productToRowRdd")
@@ -259,7 +256,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
new SchemaRDD(
this,
ParquetRelation.createEmpty(
-        path, ScalaReflection.attributesFor[A](udtRegistry), allowExisting, conf, this))
+        path, ScalaReflection.attributesFor[A], allowExisting, conf, this))
}

/**
@@ -290,22 +287,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
def table(tableName: String): SchemaRDD =
new SchemaRDD(this, catalog.lookupRelation(None, tableName))

-  /**
-   * Register a user-defined type and its serializer, to allow automatic conversion between
-   * RDDs of user types and SchemaRDDs.
-   * Fails if this type has been registered already.
-   */
-  def registerType[UserType](
-      udt: UserDefinedType[UserType])(implicit userType: TypeTag[UserType]): Unit = {
-    require(!udtRegistry.contains(userType.tpe),
-      "registerUserType called on type which was already registered.")
-    // TODO: Check to see if type is built-in. Throw exception?
-    udtRegistry(userType.tpe) = udt
-  }
-
-  /** Map: UserType --> UserDefinedType */
-  protected[sql] val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]()

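With the registration API removed from SQLContext, the user-facing contrast is roughly the following sketch (sqlContext and rdd: RDD[MyPoint] are illustrative):

    // Before this commit: explicit, per-context registration was required.
    //   sqlContext.registerType(new MyPointUDT())
    // After: the @UserDefinedType annotation on the class alone is enough.
    val schemaRDD = sqlContext.createSchemaRDD(rdd)  // UDT discovered via annotation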
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext
