support UDT in parquet
mengxr authored and jkbradley committed Nov 2, 2014
1 parent db16139 commit cfbc321
Showing 10 changed files with 81 additions and 64 deletions.
@@ -79,6 +79,7 @@ object DatasetExample {

// Convert input data to SchemaRDD explicitly.
val schemaRDD: SchemaRDD = origData
println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")

// Select columns, using implicit conversion to SchemaRDD.
@@ -95,6 +96,16 @@ object DatasetExample {
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

schemaRDD.saveAsParquetFile("/tmp/dataset")
val newDataset = sqlContext.parquetFile("/tmp/dataset")

println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
(sum1, sum2) => sum1.merge(sum2))
println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

sc.stop()
}

67 changes: 23 additions & 44 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -89,8 +89,6 @@ object Vectors {
// Note: Explicit registration is only needed for Vector and SparseVector;
// the annotation works for DenseVector.
UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[Vector], new VectorUDT())
UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[DenseVector],
new DenseVectorUDT())
UDTRegistry.registerType(scala.reflect.runtime.universe.typeOf[SparseVector],
new SparseVectorUDT())

@@ -204,7 +202,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
@SQLUserDefinedType(udt = classOf[DenseVectorUDT])
@SQLUserDefinedType(serdes = classOf[DenseVectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {

override def size: Int = values.length
@@ -261,16 +259,16 @@ class SparseVector(
* User-defined type for [[Vector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class VectorUDT extends UserDefinedType[Vector] {
private[spark] class VectorUDT extends UserDefinedTypeSerDes[Vector] {

/**
* vectorType: 0 = dense, 1 = sparse.
* dense, sparse: One element holds the vector, and the other is null.
*/
override def sqlType: StructType = StructType(Seq(
StructField("vectorType", ByteType, nullable = false),
StructField("dense", new DenseVectorUDT(), nullable = true),
StructField("sparse", new SparseVectorUDT(), nullable = true)))
StructField("dense", new UserDefinedType(new DenseVectorUDT), nullable = true),
StructField("sparse", new UserDefinedType(new SparseVectorUDT), nullable = true)))

override def serialize(obj: Any): Row = {
val row = new GenericMutableRow(3)
@@ -298,82 +296,63 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
new SparseVectorUDT().deserialize(row.getAs[Row](2))
}
}

override def userType: Class[Vector] = classOf[Vector]
}
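
For illustration only (not part of this commit, and VectorUDT is private[spark]): given the layout described above and the serialize/deserialize code in this file, the two vector kinds map to rows roughly as follows.

// vectorType 0 = dense:  field 1 holds the DenseVectorUDT row, field 2 is null.
// vectorType 1 = sparse: field 2 holds the SparseVectorUDT row, field 1 is null.
//   Vectors.dense(1.0, 2.0)                          ~> Row(0, Row(Seq(1.0, 2.0)), null)
//   Vectors.sparse(4, Array(1, 3), Array(0.5, 0.7))  ~> Row(1, null, Row(4, Seq(1, 3), Seq(0.5, 0.7)))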

/**
* User-defined type for [[DenseVector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class DenseVectorUDT extends UserDefinedType[DenseVector] {
private[spark] class DenseVectorUDT extends UserDefinedTypeSerDes[DenseVector] {

override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false)

override def serialize(obj: Any): Row = obj match {
case v: DenseVector =>
val row: GenericMutableRow = new GenericMutableRow(v.size)
var i = 0
while (i < v.size) {
row.setDouble(i, v(i))
i += 1
}
val row: GenericMutableRow = new GenericMutableRow(1)
row.update(0, v.values.toSeq)
row
}

override def deserialize(row: Row): DenseVector = {
val values = new Array[Double](row.length)
var i = 0
while (i < row.length) {
values(i) = row.getDouble(i)
i += 1
}
val values = row.getAs[Seq[Double]](0).toArray
new DenseVector(values)
}

override def userType: Class[DenseVector] = classOf[DenseVector]
}

/**
* User-defined type for [[SparseVector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.SchemaRDD]].
*/
private[spark] class SparseVectorUDT extends UserDefinedType[SparseVector] {
private[spark] class SparseVectorUDT extends UserDefinedTypeSerDes[SparseVector] {

override def sqlType: StructType = StructType(Seq(
StructField("size", IntegerType, nullable = false),
StructField("indices", ArrayType(DoubleType, containsNull = false), nullable = false),
StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = false),
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = false)))

override def serialize(obj: Any): Row = obj match {
case v: SparseVector =>
val nnz = v.indices.size
val row: GenericMutableRow = new GenericMutableRow(1 + 2 * nnz)
val row: GenericMutableRow = new GenericMutableRow(3)
row.setInt(0, v.size)
var i = 0
while (i < nnz) {
row.setInt(1 + i, v.indices(i))
i += 1
}
i = 0
while (i < nnz) {
row.setDouble(1 + nnz + i, v.values(i))
i += 1
}
row.update(1, v.indices.toSeq)
row.update(2, v.values.toSeq)
row
case row: Row =>
row
}

override def deserialize(row: Row): SparseVector = {
require(row.length >= 1,
s"SparseVectorUDT.deserialize given row with length ${row.length} but requires length >= 1")
val vSize = row.getInt(0)
val nnz: Int = (row.length - 1) / 2
require(nnz * 2 + 1 == row.length,
s"SparseVectorUDT.deserialize given row with non-matching indices, values lengths")
val indices = new Array[Int](nnz)
val values = new Array[Double](nnz)
var i = 0
while (i < nnz) {
indices(i) = row.getInt(1 + i)
values(i) = row.getDouble(1 + nnz + i)
i += 1
}
val indices = row.getAs[Seq[Int]](1).toArray
val values = row.getAs[Seq[Double]](2).toArray
new SparseVector(vSize, indices, values)
}

override def userType: Class[SparseVector] = classOf[SparseVector]
}
@@ -54,7 +54,7 @@ object ScalaReflection {
convertToCatalyst(elem, field.dataType)
}.toArray)
case (d: BigDecimal, _) => Decimal(d)
case (udt, udtType: UserDefinedType[_]) => udtType.serialize(udt)
case (udt, udtType: UserDefinedTypeSerDes[_]) => udtType.serialize(udt)
case (other, _) => other
}

@@ -64,8 +64,8 @@ object ScalaReflection {
case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
convertToScala(k, mapType.keyType) -> convertToScala(v, mapType.valueType)
}
case (d: Decimal, DecimalType) => d.toBigDecimal
case (udt: Row, udtType: UserDefinedType[_]) => udtType.deserialize(udt)
case (d: Decimal, _: DecimalType) => d.toBigDecimal
case (r: Row, udt: UserDefinedType[_]) => udt.serdes.deserialize(r)
case (other, _) => other
}

@@ -94,12 +94,12 @@ object ScalaReflection {
// whereas className is from Scala reflection. This can make it hard to find classes
// in some cases, such as when a class is enclosed in an object (in which case
// Java appends a '$' to the object name but Scala does not).
val udt = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
UDTRegistry.registerType(t, udt)
Schema(udt, nullable = true)
val serdes = Utils.classForName(className)
.getAnnotation(classOf[SQLUserDefinedType]).serdes().newInstance()
UDTRegistry.registerType(t, serdes)
Schema(new UserDefinedType(serdes), nullable = true)
case t if UDTRegistry.udtRegistry.contains(t) =>
Schema(UDTRegistry.udtRegistry(t), nullable = true)
Schema(new UserDefinedType(UDTRegistry.udtRegistry(t)), nullable = true)
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
@@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.reflect.runtime.universe._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes

/**
* ::DeveloperApi::
@@ -32,14 +32,14 @@ import org.apache.spark.sql.catalyst.types.UserDefinedType
@DeveloperApi
object UDTRegistry {
/** Map: UserType --> UserDefinedType */
val udtRegistry = new mutable.HashMap[Any, UserDefinedType[_]]()
val udtRegistry = new mutable.HashMap[Any, UserDefinedTypeSerDes[_]]()

/**
* Register a user-defined type and its serializer, to allow automatic conversion between
* RDDs of user types and SchemaRDDs.
* If this type has already been registered, this does nothing.
*/
def registerType(userType: Type, udt: UserDefinedType[_]): Unit = {
def registerType(userType: Type, udt: UserDefinedTypeSerDes[_]): Unit = {
// TODO: Check to see if type is built-in. Throw exception?
UDTRegistry.udtRegistry(userType) = udt
}
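
A minimal usage sketch (not from this commit; MyType and MyTypeUDT are hypothetical) of explicit registration, which is needed when the annotation cannot be applied, e.g. for a trait such as Vector, as done in Vectors.scala above:

import scala.reflect.runtime.universe.typeOf
// Register a hypothetical user class with its serializer so schema inference can handle it.
UDTRegistry.registerType(typeOf[MyType], new MyTypeUDT())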
@@ -20,7 +20,7 @@
import java.lang.annotation.*;

import org.apache.spark.annotation.DeveloperApi;
import org.apache.spark.sql.catalyst.types.UserDefinedType;
import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes;

/**
* ::DeveloperApi::
@@ -38,5 +38,5 @@
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.TYPE)
public @interface SQLUserDefinedType {
Class<? extends UserDefinedType<?> > udt();
Class<? extends UserDefinedTypeSerDes<?> > serdes();
}
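
A usage sketch of the renamed annotation parameter (MyPoint and MyPointUDT are hypothetical; the commit's own end-to-end example is MyDenseVector in the test file at the end of this diff):

// The annotation argument is now `serdes` rather than `udt`.
@SQLUserDefinedType(serdes = classOf[MyPointUDT])
class MyPoint(val x: Double, val y: Double) extends Serializable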
@@ -30,7 +30,7 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.ScalaReflectionLock
import org.apache.spark.sql.catalyst.{UDTRegistry, ScalaReflectionLock}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression, Row}
import org.apache.spark.sql.catalyst.types.decimal._
import org.apache.spark.sql.catalyst.util.Metadata
@@ -68,6 +68,13 @@ object DataType {
("fields", JArray(fields)),
("type", JString("struct"))) =>
StructType(fields.map(parseStructField))

case JSortedObject(
("serdes", JString(serdesClass)),
("type", JString("udt"))) => {
val serdes = Class.forName(serdesClass).newInstance().asInstanceOf[UserDefinedTypeSerDes[_]]
new UserDefinedType(serdes)
}
}

private def parseStructField(json: JValue): StructField = json match {
@@ -573,7 +580,7 @@ case class MapType(
* The data type for User Defined Types (UDTs).
*
* This interface allows a user to make their own classes more interoperable with SparkSQL;
* e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create a SchemaRDD
* e.g., by creating a [[UserDefinedTypeSerDes]] for a class X, it becomes possible to create a SchemaRDD
* which has class X in the schema.
*
* For SparkSQL to recognize UDTs, the UDT must be registered in
@@ -586,7 +593,9 @@ case class MapType(
* The conversion via `deserialize` occurs when reading from a `SchemaRDD`.
*/
@DeveloperApi
abstract class UserDefinedType[UserType] extends DataType with Serializable {
abstract class UserDefinedTypeSerDes[UserType] extends Serializable {

def userType: Class[UserType]

/** Underlying storage type for this UDT used by SparkSQL */
def sqlType: DataType
@@ -598,6 +607,12 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {

/** Convert a Row object to the user type */
def deserialize(row: Row): UserType
}

def simpleString: String = "udt"
case class UserDefinedType[UserType](serdes: UserDefinedTypeSerDes[UserType])
extends DataType with Serializable {
override private[sql] def jsonValue: JValue = {
("type" -> "udt") ~
("serdes" -> serdes.getClass.getName)
}
}
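
For reference (an illustrative sketch, not part of the commit), the JSON emitted by jsonValue above, and matched by the JSortedObject case added to DataType.fromJson earlier in this file, has roughly this shape:

// {"type": "udt", "serdes": "org.apache.spark.mllib.linalg.VectorUDT"}
// On the read side, the named serdes class is instantiated reflectively via
// Class.forName(serdesClass).newInstance() and wrapped in UserDefinedType(serdes).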
@@ -99,6 +99,9 @@ private[sql] object CatalystConverter {
fieldIndex,
parent)
}
case UserDefinedType(serdes) => {
createConverter(field.copy(dataType = serdes.sqlType), fieldIndex, parent)
}
// Strings, Shorts and Bytes do not have a corresponding type in Parquet
// so we need to treat them separately
case StringType => {
@@ -183,6 +183,10 @@ private[parquet] class RowWriteSupport extends WriteSupport[Row] with Logging {
case t @ StructType(_) => writeStruct(
t,
value.asInstanceOf[CatalystConverter.StructScalaType[_]])
case UserDefinedType(serdes) => {
println(value.getClass)
writeValue(serdes.sqlType, serdes.serialize(value))
}
case _ => writePrimitive(schema.asInstanceOf[PrimitiveType], value)
}
}
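
An illustrative trace of the new case (not part of the commit): a UDT value is first converted by its serdes into the underlying sqlType representation, and that row is then written like any other value of that type.

// writeValue(UserDefinedType(vectorSerdes), vector)
//   -> writeValue(vectorSerdes.sqlType, vectorSerdes.serialize(vector))
//   -> writeStruct(...)   // VectorUDT.sqlType is the StructType(vectorType, dense, sparse) shown above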
@@ -337,6 +337,9 @@ private[parquet] object ParquetTypesConverter extends Logging {
parquetKeyType,
parquetValueType)
}
case UserDefinedType(serdes) => {
fromDataType(serdes.sqlType, name, nullable, inArray)
}
case _ => sys.error(s"Unsupported datatype $ctype")
}
}
@@ -21,10 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.UDTRegistry
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.types.UserDefinedType
import org.apache.spark.sql.catalyst.types.UserDefinedTypeSerDes
import org.apache.spark.sql.test.TestSQLContext._

@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT])
@SQLUserDefinedType(serdes = classOf[MyDenseVectorUDT])
class MyDenseVector(val data: Array[Double]) extends Serializable {
override def equals(other: Any): Boolean = other match {
case v: MyDenseVector =>
@@ -35,7 +35,9 @@ class MyDenseVector(val data: Array[Double]) extends Serializable {

case class MyLabeledPoint(label: Double, features: MyDenseVector)

class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
class MyDenseVectorUDT extends UserDefinedTypeSerDes[MyDenseVector] {

override def userType: Class[MyDenseVector] = classOf[MyDenseVector]

override def sqlType: ArrayType = ArrayType(DoubleType, containsNull = false)

