[SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD
Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and Python. With this PR, we can easily map an RDD[LabeledPoint] to a SchemaRDD, and then select columns or save to a Parquet file. Examples in Scala/Python are attached. The Scala code was copied from jkbradley.
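
For reference, a minimal sketch of the workflow this enables, mirroring the attached DatasetExample.scala; an existing SparkContext `sc` and the sample LIBSVM data that ships with Spark are assumed.

```scala
// Sketch only (mirrors the attached DatasetExample.scala); assumes an existing SparkContext `sc`.
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}

val sqlContext = new SQLContext(sc)
import sqlContext._  // implicit conversions, including RDD[LabeledPoint] -> SchemaRDD

val points: RDD[LabeledPoint] = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
val dataset: SchemaRDD = points  // the "features" column is typed by the new VectorUDT
val features: RDD[Vector] = dataset.select('features).map { case Row(v: Vector) => v }
```

The PySpark path goes through `SQLContext.inferSchema`, as in the attached Python example.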

~~This PR contains the changes from #3068. I will rebase after #3068 is merged.~~

marmbrus jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #3070 from mengxr/SPARK-3573 and squashes the following commits:

3a0b6e5 [Xiangrui Meng] organize imports
236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples
mengxr committed Nov 4, 2014
1 parent 04450d1 commit 1a9c6cd
Showing 8 changed files with 353 additions and 6 deletions.
2 changes: 1 addition & 1 deletion dev/run-tests
@@ -180,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
  if [ -n "$_SQL_TESTS_ONLY" ]; then
    # This must be an array of individual arguments. Otherwise, having one long string
    #+ will be interpreted as a single test, which doesn't work.
    SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test")
    SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
  else
    SBT_MAVEN_TEST_ARGS=("test")
  fi
62 changes: 62 additions & 0 deletions examples/src/main/python/mllib/dataset_example.py
@@ -0,0 +1,62 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
An example of how to use SchemaRDD as a dataset for ML. Run with::
    bin/spark-submit examples/src/main/python/mllib/dataset_example.py
"""

import os
import sys
import tempfile
import shutil

from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.util import MLUtils
from pyspark.mllib.stat import Statistics


def summarize(dataset):
    print "schema: %s" % dataset.schema().json()
    labels = dataset.map(lambda r: r.label)
    print "label average: %f" % labels.mean()
    features = dataset.map(lambda r: r.features)
    summary = Statistics.colStats(features)
    print "features average: %r" % summary.mean()

if __name__ == "__main__":
    if len(sys.argv) > 2:
        print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
        exit(-1)
    sc = SparkContext(appName="DatasetExample")
    sqlCtx = SQLContext(sc)
    if len(sys.argv) == 2:
        input = sys.argv[1]
    else:
        input = "data/mllib/sample_libsvm_data.txt"
    points = MLUtils.loadLibSVMFile(sc, input)
    dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
    summarize(dataset0)
    tempdir = tempfile.NamedTemporaryFile(delete=False).name
    os.unlink(tempdir)
    print "Save dataset as a Parquet file to %s." % tempdir
    dataset0.saveAsParquetFile(tempdir)
    print "Load it back and summarize it again."
    dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
    summarize(dataset1)
    shutil.rmtree(tempdir)
121 changes: 121 additions & 0 deletions examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -0,0 +1,121 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.mllib

import java.io.File

import com.google.common.io.Files
import scopt.OptionParser

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}

/**
 * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
 * {{{
 * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object DatasetExample {

  case class Params(
      input: String = "data/mllib/sample_libsvm_data.txt",
      dataFormat: String = "libsvm") extends AbstractParams[Params]

  def main(args: Array[String]) {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("DatasetExample") {
      head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
      opt[String]("input")
        .text(s"input path to dataset")
        .action((x, c) => c.copy(input = x))
      opt[String]("dataFormat")
        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
        .action((x, c) => c.copy(dataFormat = x))
      checkConfig { params =>
        success
      }
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      sys.exit(1)
    }
  }

  def run(params: Params) {

    val conf = new SparkConf().setAppName(s"DatasetExample with $params")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext._ // for implicit conversions

    // Load input data
    val origData: RDD[LabeledPoint] = params.dataFormat match {
      case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
      case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
    }
    println(s"Loaded ${origData.count()} instances from file: ${params.input}")

    // Convert input data to SchemaRDD explicitly.
    val schemaRDD: SchemaRDD = origData
    println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
    println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")

    // Select columns, using implicit conversion to SchemaRDD.
    val labelsSchemaRDD: SchemaRDD = origData.select('label)
    val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
    val numLabels = labels.count()
    val meanLabel = labels.fold(0.0)(_ + _) / numLabels
    println(s"Selected label column with average value $meanLabel")

    val featuresSchemaRDD: SchemaRDD = origData.select('features)
    val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
    val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
      (summary, feat) => summary.add(feat),
      (sum1, sum2) => sum1.merge(sum2))
    println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")

    val tmpDir = Files.createTempDir()
    tmpDir.deleteOnExit()
    val outputDir = new File(tmpDir, "dataset").toString
    println(s"Saving to $outputDir as Parquet file.")
    schemaRDD.saveAsParquetFile(outputDir)

    println(s"Loading Parquet file with UDT from $outputDir.")
    val newDataset = sqlContext.parquetFile(outputDir)

    println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
    val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
    val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
      (summary, feat) => summary.add(feat),
      (sum1, sum2) => sum1.merge(sum2))
    println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")

    sc.stop()
  }

}
5 changes: 5 additions & 0 deletions mllib/pom.xml
@@ -45,6 +45,11 @@
      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>org.eclipse.jetty</groupId>
      <artifactId>jetty-server</artifactId>
69 changes: 67 additions & 2 deletions mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -17,22 +17,26 @@

package org.apache.spark.mllib.linalg

import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import java.util
import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}

import scala.annotation.varargs
import scala.collection.JavaConverters._

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}

import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
import org.apache.spark.sql.catalyst.types._

/**
 * Represents a numeric vector, whose index type is Int and value type is Double.
 *
 * Note: Users should not implement this interface.
 */
@SQLUserDefinedType(udt = classOf[VectorUDT])
sealed trait Vector extends Serializable {

  /**
@@ -74,6 +78,65 @@ sealed trait Vector extends Serializable {
  }
}

/**
 * User-defined type for [[Vector]] which allows easy interaction with SQL
 * via [[org.apache.spark.sql.SchemaRDD]].
 */
private[spark] class VectorUDT extends UserDefinedType[Vector] {

  override def sqlType: StructType = {
    // type: 0 = sparse, 1 = dense
    // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
    // vectors. The "values" field is nullable because we might want to add binary vectors later,
    // which use "size" and "indices", but not "values".
    StructType(Seq(
      StructField("type", ByteType, nullable = false),
      StructField("size", IntegerType, nullable = true),
      StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
      StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
  }
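  // Illustration only (not in the patch): under this schema,
  //   Vectors.dense(1.0, 2.0)                 serializes to (type = 1, null, null, [1.0, 2.0])
  //   Vectors.sparse(2, Array(1), Array(2.0)) serializes to (type = 0, size = 2, [1], [2.0])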

  override def serialize(obj: Any): Row = {
    val row = new GenericMutableRow(4)
    obj match {
      case sv: SparseVector =>
        row.setByte(0, 0)
        row.setInt(1, sv.size)
        row.update(2, sv.indices.toSeq)
        row.update(3, sv.values.toSeq)
      case dv: DenseVector =>
        row.setByte(0, 1)
        row.setNullAt(1)
        row.setNullAt(2)
        row.update(3, dv.values.toSeq)
    }
    row
  }

  override def deserialize(datum: Any): Vector = {
    datum match {
      case row: Row =>
        require(row.length == 4,
          s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
        val tpe = row.getByte(0)
        tpe match {
          case 0 =>
            val size = row.getInt(1)
            val indices = row.getAs[Iterable[Int]](2).toArray
            val values = row.getAs[Iterable[Double]](3).toArray
            new SparseVector(size, indices, values)
          case 1 =>
            val values = row.getAs[Iterable[Double]](3).toArray
            new DenseVector(values)
        }
    }
  }

  override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"

  override def userClass: Class[Vector] = classOf[Vector]
}

/**
 * Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
 * We don't use the name `Vector` because Scala imports
Expand Down Expand Up @@ -191,6 +254,7 @@ object Vectors {
/**
 * A dense vector represented by a value array.
 */
@SQLUserDefinedType(udt = classOf[VectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {

  override def size: Int = values.length
@@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
 * @param indices index array, assumed to be strictly increasing.
 * @param values value array, must have the same length as the index array.
 */
@SQLUserDefinedType(udt = classOf[VectorUDT])
class SparseVector(
    override val size: Int,
    val indices: Array[Int],
11 changes: 11 additions & 0 deletions mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite {
        throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.")
    }
  }

  test("VectorUDT") {
    val dv0 = Vectors.dense(Array.empty[Double])
    val dv1 = Vectors.dense(1.0, 2.0)
    val sv0 = Vectors.sparse(2, Array.empty, Array.empty)
    val sv1 = Vectors.sparse(2, Array(1), Array(2.0))
    val udt = new VectorUDT()
    for (v <- Seq(dv0, dv1, sv0, sv1)) {
      assert(v === udt.deserialize(udt.serialize(v)))
    }
  }
}
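
The test above covers only the in-memory serialize/deserialize round trip. The Parquet round trip exercised by DatasetExample.scala looks roughly like the sketch below; `sqlContext` (with its implicits imported), the SchemaRDD `dataset` with a vector-valued `features` column, and the output directory are all assumed rather than defined here.

```scala
// Sketch only: `sqlContext`, `dataset`, and the output directory are assumptions.
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.Row

dataset.saveAsParquetFile("/tmp/dataset-parquet")   // the vector column is written via VectorUDT
val reloaded = sqlContext.parquetFile("/tmp/dataset-parquet")
val features = reloaded.select('features).map { case Row(v: Vector) => v }  // vectors come back typed
```
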
50 changes: 50 additions & 0 deletions python/pyspark/mllib/linalg.py
@@ -29,6 +29,9 @@

import numpy as np

from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
    IntegerType, ByteType, Row


__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']

@@ -106,7 +109,54 @@ def _format_float(f, digits=4):
    return s


class VectorUDT(UserDefinedType):
    """
    SQL user-defined type (UDT) for Vector.
    """

    @classmethod
    def sqlType(cls):
        return StructType([
            StructField("type", ByteType(), False),
            StructField("size", IntegerType(), True),
            StructField("indices", ArrayType(IntegerType(), False), True),
            StructField("values", ArrayType(DoubleType(), False), True)])

    @classmethod
    def module(cls):
        return "pyspark.mllib.linalg"

    @classmethod
    def scalaUDT(cls):
        return "org.apache.spark.mllib.linalg.VectorUDT"

    def serialize(self, obj):
        if isinstance(obj, SparseVector):
            indices = [int(i) for i in obj.indices]
            values = [float(v) for v in obj.values]
            return (0, obj.size, indices, values)
        elif isinstance(obj, DenseVector):
            values = [float(v) for v in obj]
            return (1, None, None, values)
        else:
            raise ValueError("cannot serialize %r of type %r" % (obj, type(obj)))

    def deserialize(self, datum):
        assert len(datum) == 4, \
            "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
        tpe = datum[0]
        if tpe == 0:
            return SparseVector(datum[1], datum[2], datum[3])
        elif tpe == 1:
            return DenseVector(datum[3])
        else:
            raise ValueError("do not recognize type %r" % tpe)


class Vector(object):

    __UDT__ = VectorUDT()

    """
    Abstract class for DenseVector and SparseVector
    """
