# [SPARK-14215] [SQL] [PYSPARK] Support chained Python UDFs
## What changes were proposed in this pull request?

This PR adds support for chained Python UDFs, for example:

```sql
select udf1(udf2(a))
select udf1(udf2(a) + 3)
select udf1(udf2(a) + udf3(b))
```
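
For context, the queries above assume Python UDFs registered on the SQL context. A minimal, hypothetical sketch (the names `udf1`/`udf2`/`udf3` and their bodies are placeholders; `registerFunction` is the API exercised by the tests in this patch):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import IntegerType

sc = SparkContext(appName="chained-udf-example")  # or reuse an existing context
sqlContext = SQLContext(sc)

# Placeholder UDFs; any Python callable with a declared return type works.
sqlContext.registerFunction("udf1", lambda x: x + 1, IntegerType())
sqlContext.registerFunction("udf2", lambda x: x * 2, IntegerType())
sqlContext.registerFunction("udf3", lambda x: x - 1, IntegerType())

sqlContext.sql("select udf1(udf2(1))").collect()  # chained call, evaluates to 3
```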

Directly chained unary Python UDFs are put into a single batch of Python UDFs; other combinations may require multiple batches.

For example,
```python
>>> sqlContext.sql("select double(double(1))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#10 AS double(double(1))#9]
:     +- INPUT
+- !BatchPythonEvaluation double(double(1)), [pythonUDF#10]
   +- Scan OneRowRelation[]
>>> sqlContext.sql("select double(double(1) + double(2))").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [pythonUDF#19 AS double((double(1) + double(2)))#16]
:     +- INPUT
+- !BatchPythonEvaluation double((pythonUDF#17 + pythonUDF#18)), [pythonUDF#17,pythonUDF#18,pythonUDF#19]
   +- !BatchPythonEvaluation double(2), [pythonUDF#17,pythonUDF#18]
      +- !BatchPythonEvaluation double(1), [pythonUDF#17]
         +- Scan OneRowRelation[]
```
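
Directly chained unary UDFs are composed into a single Python function on the worker side, which is why `double(double(1))` above needs only one `BatchPythonEvaluation`. A standalone sketch of that composition, mirroring the `chain` helper added to `python/pyspark/worker.py` below:

```python
def chain(f, g):
    """Compose two unary functions: chain(f, g)(x) == g(f(x))."""
    return lambda x: g(f(x))

double = lambda x: x + x
composed = chain(double, double)  # behaves like double(double(x))
assert composed(1) == 4
```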

TODO: support multiple unrelated Python UDFs in a single batch (in another PR).

## How was this patch tested?

Added new unit tests for chained UDFs.

Author: Davies Liu <davies@databricks.com>

Closes #12014 from davies/py_udfs.
Davies Liu authored and davies committed Mar 29, 2016
Parent: e58c4cb · Commit: a7a93a1
Showing 6 changed files with 116 additions and 35 deletions.
38 changes: 22 additions & 16 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -59,7 +59,7 @@ private[spark] class PythonRDD(
val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)

override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val runner = new PythonRunner(func, bufferSize, reuse_worker)
val runner = new PythonRunner(Seq(func), bufferSize, reuse_worker, false)
runner.compute(firstParent.iterator(split, context), split.index, context)
}
}
@@ -81,14 +81,18 @@ private[spark] case class PythonFunction(
* A helper class to run Python UDFs in Spark.
*/
private[spark] class PythonRunner(
func: PythonFunction,
funcs: Seq[PythonFunction],
bufferSize: Int,
reuse_worker: Boolean)
reuse_worker: Boolean,
rowBased: Boolean)
extends Logging {

private val envVars = func.envVars
private val pythonExec = func.pythonExec
private val accumulator = func.accumulator
// All the Python functions should have the same exec, version and envvars.
private val envVars = funcs.head.envVars
private val pythonExec = funcs.head.pythonExec
private val pythonVer = funcs.head.pythonVer

private val accumulator = funcs.head.accumulator // TODO: support accumulator in multiple UDF

def compute(
inputIterator: Iterator[_],
@@ -228,10 +232,8 @@ private[spark] class PythonRunner(

@volatile private var _exception: Exception = null

private val pythonVer = func.pythonVer
private val pythonIncludes = func.pythonIncludes
private val broadcastVars = func.broadcastVars
private val command = func.command
private val pythonIncludes = funcs.flatMap(_.pythonIncludes.asScala).toSet
private val broadcastVars = funcs.flatMap(_.broadcastVars.asScala)

setDaemon(true)

@@ -256,13 +258,13 @@
// sparkFilesDir
PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut)
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.size())
for (include <- pythonIncludes.asScala) {
dataOut.writeInt(pythonIncludes.size)
for (include <- pythonIncludes) {
PythonRDD.writeUTF(include, dataOut)
}
// Broadcast variables
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
val newBids = broadcastVars.asScala.map(_.id).toSet
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
val toRemove = oldBids.diff(newBids)
val cnt = toRemove.size + newBids.diff(oldBids).size
@@ -272,7 +274,7 @@
dataOut.writeLong(- bid - 1) // bid >= 0
oldBids.remove(bid)
}
for (broadcast <- broadcastVars.asScala) {
for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
@@ -282,8 +284,12 @@
}
dataOut.flush()
// Serialized command:
dataOut.writeInt(command.length)
dataOut.write(command)
dataOut.writeInt(if (rowBased) 1 else 0)
dataOut.writeInt(funcs.length)
funcs.foreach { f =>
dataOut.writeInt(f.command.length)
dataOut.write(f.command)
}
// Data values
PythonRDD.writeIteratorToStream(inputIterator, dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
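
To summarize the wire change in this file: instead of one serialized command, `PythonRunner` now writes a row-based flag, the number of chained functions, and then each pickled command, and `worker.py` reads them back in the same order. A hedged, standalone sketch of just that framing (big-endian 4-byte ints, as `DataOutputStream` writes them); the `struct`/`io` helpers and function names here are illustrative, not Spark code:

```python
import io
import struct

def write_commands(out, pickled_commands, row_based):
    """Mirror of what PythonRunner now writes: flag, count, then each command."""
    out.write(struct.pack(">i", 1 if row_based else 0))
    out.write(struct.pack(">i", len(pickled_commands)))
    for command in pickled_commands:
        out.write(struct.pack(">i", len(command)))
        out.write(command)

def read_commands(inp):
    """Mirror of what pyspark/worker.py now reads back."""
    def read_int():
        return struct.unpack(">i", inp.read(4))[0]
    row_based = read_int() == 1
    commands = []
    for _ in range(read_int()):
        commands.append(inp.read(read_int()))
    return row_based, commands

buf = io.BytesIO()
write_commands(buf, [b"pickled-udf-1", b"pickled-udf-2"], row_based=True)
buf.seek(0)
print(read_commands(buf))  # (True, [b'pickled-udf-1', b'pickled-udf-2'])
```
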
16 changes: 11 additions & 5 deletions python/pyspark/sql/functions.py
@@ -25,7 +25,7 @@
from itertools import imap as map

from pyspark import since, SparkContext
from pyspark.rdd import _wrap_function, ignore_unicode_prefix
from pyspark.rdd import _prepare_for_python_RDD, ignore_unicode_prefix
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
from pyspark.sql.types import StringType
from pyspark.sql.column import Column, _to_java_column, _to_seq
@@ -1648,6 +1648,14 @@ def sort_array(col, asc=True):

# ---------------------------- User Defined Function ----------------------------------

def _wrap_function(sc, func, returnType):
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, returnType, ser)
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command)
return sc._jvm.PythonFunction(bytearray(pickled_command), env, includes, sc.pythonExec,
sc.pythonVer, broadcast_vars, sc._javaAccumulator)


class UserDefinedFunction(object):
"""
User defined function in Python
@@ -1662,14 +1670,12 @@ def __init__(self, func, returnType, name=None):

def _create_judf(self, name):
from pyspark.sql import SQLContext
f, returnType = self.func, self.returnType # put them in closure `func`
func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
ser = AutoBatchedSerializer(PickleSerializer())
sc = SparkContext.getOrCreate()
wrapped_func = _wrap_function(sc, func, ser, ser)
wrapped_func = _wrap_function(sc, self.func, self.returnType)
ctx = SQLContext.getOrCreate(sc)
jdt = ctx._ssql_ctx.parseDataType(self.returnType.json())
if name is None:
f = self.func
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
name, wrapped_func, jdt)
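
The refactor above folds the serializer setup into a SQL-specific `_wrap_function`; the public API is unchanged. A hedged usage sketch of that entry point (`df` is an assumed DataFrame with an integer column `a`):

```python
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# udf() builds a UserDefinedFunction, which wraps the Python callable via
# _wrap_function and hands it to the JVM as a UserDefinedPythonFunction.
double = udf(lambda x: x + x, IntegerType())

df.select(double(df.a)).show()  # `df` is assumed to exist
```
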
9 changes: 9 additions & 0 deletions python/pyspark/sql/tests.py
@@ -305,6 +305,15 @@ def test_udf2(self):
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])

def test_chained_python_udf(self):
self.sqlCtx.registerFunction("double", lambda x: x + x, IntegerType())
[row] = self.sqlCtx.sql("SELECT double(1)").collect()
self.assertEqual(row[0], 2)
[row] = self.sqlCtx.sql("SELECT double(double(1))").collect()
self.assertEqual(row[0], 4)
[row] = self.sqlCtx.sql("SELECT double(double(1) + 1)").collect()
self.assertEqual(row[0], 6)

def test_udf_with_array_type(self):
d = [Row(l=list(range(3)), d={"key": list(range(5))})]
rdd = self.sc.parallelize(d)
33 changes: 29 additions & 4 deletions python/pyspark/worker.py
@@ -50,6 +50,18 @@ def add_path(path):
sys.path.insert(1, path)


def read_command(serializer, file):
command = serializer._read_with_length(file)
if isinstance(command, Broadcast):
command = serializer.loads(command.value)
return command


def chain(f, g):
"""chain two function together """
return lambda x: g(f(x))


def main(infile, outfile):
try:
boot_time = time.time()
@@ -95,10 +107,23 @@ def main(infile, outfile):
_broadcastRegistry.pop(bid)

_accumulatorRegistry.clear()
command = pickleSer._read_with_length(infile)
if isinstance(command, Broadcast):
command = pickleSer.loads(command.value)
func, profiler, deserializer, serializer = command
row_based = read_int(infile)
num_commands = read_int(infile)
if row_based:
profiler = None # profiling is not supported for UDF
row_func = None
for i in range(num_commands):
f, returnType, deserializer = read_command(pickleSer, infile)
if row_func is None:
row_func = f
else:
row_func = chain(row_func, f)
serializer = deserializer
func = lambda _, it: map(lambda x: returnType.toInternal(row_func(*x)), it)
else:
assert num_commands == 1
func, profiler, deserializer, serializer = read_command(pickleSer, infile)

init_time = time.time()

def process():
sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala
@@ -22,10 +22,10 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle.{Pickler, Unpickler}

import org.apache.spark.TaskContext
import org.apache.spark.api.python.PythonRunner
import org.apache.spark.api.python.{PythonFunction, PythonRunner}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow, JoinedRow, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.types.{StructField, StructType}

@@ -45,6 +45,18 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:

def children: Seq[SparkPlan] = child :: Nil

private def collectFunctions(udf: PythonUDF): (Seq[PythonFunction], Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
val (fs, children) = collectFunctions(u)
(fs ++ Seq(udf.func), children)
case children =>
// There should not be any other UDFs, or the children can't be evaluated directly.
assert(children.forall(_.find(_.isInstanceOf[PythonUDF]).isEmpty))
(Seq(udf.func), udf.children)
}
}

protected override def doExecute(): RDD[InternalRow] = {
val inputRDD = child.execute().map(_.copy())
val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536)
@@ -57,9 +69,11 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
// combine input with output from Python.
val queue = new java.util.concurrent.ConcurrentLinkedQueue[InternalRow]()

val (pyFuncs, children) = collectFunctions(udf)

val pickle = new Pickler
val currentRow = newMutableProjection(udf.children, child.output)()
val fields = udf.children.map(_.dataType)
val currentRow = newMutableProjection(children, child.output)()
val fields = children.map(_.dataType)
val schema = new StructType(fields.map(t => new StructField("", t, true)).toArray)

// Input iterator to Python: input rows are grouped so we send them in batches to Python.
@@ -75,11 +89,8 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child:
val context = TaskContext.get()

// Output iterator for results from Python.
val outputIterator = new PythonRunner(
udf.func,
bufferSize,
reuseWorker
).compute(inputIterator, context.partitionId(), context)
val outputIterator = new PythonRunner(pyFuncs, bufferSize, reuseWorker, true)
.compute(inputIterator, context.partitionId(), context)

val unpickle = new Unpickler
val row = new GenericMutableRow(1)
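
The `collectFunctions` helper above flattens a chain of UDF calls into the Python functions (innermost first) plus the real input expressions, which is what gets handed to the multi-function `PythonRunner`. A toy, hedged model of that traversal, using made-up stand-ins rather than Spark's expression classes:

```python
from collections import namedtuple

# Made-up stand-ins for expression nodes, for illustration only.
Udf = namedtuple("Udf", ["name", "children"])
Col = namedtuple("Col", ["name"])

def collect_functions(udf):
    """Return ([functions innermost-first], [non-UDF inputs of the chain])."""
    if len(udf.children) == 1 and isinstance(udf.children[0], Udf):
        funcs, inputs = collect_functions(udf.children[0])
        return funcs + [udf.name], inputs
    return [udf.name], list(udf.children)

# udf1(udf2(a)) -> (['udf2', 'udf1'], [Col(name='a')])
print(collect_functions(Udf("udf1", [Udf("udf2", [Col("a")])])))
```
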
sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.python

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -25,17 +26,40 @@ import org.apache.spark.sql.catalyst.rules.Rule
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
* alone in a batch.
*
* Only extracts the PythonUDFs that could be evaluated in Python (the single child is PythonUDFs
* or all the children could be evaluated in JVM).
*
* This has the limitation that the input to the Python UDF is not allowed include attributes from
* multiple child operators.
*/
private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {

private def hasPythonUDF(e: Expression): Boolean = {
e.find(_.isInstanceOf[PythonUDF]).isDefined
}

private def canEvaluateInPython(e: PythonUDF): Boolean = {
e.children match {
// single PythonUDF child could be chained and evaluated in Python
case Seq(u: PythonUDF) => canEvaluateInPython(u)
// Python UDF can't be evaluated directly in JVM
case children => !children.exists(hasPythonUDF)
}
}

private def collectEvaluatableUDF(expr: Expression): Seq[PythonUDF] = {
expr.collect {
case udf: PythonUDF if canEvaluateInPython(udf) => udf
}
}

def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Skip EvaluatePython nodes.
case plan: EvaluatePython => plan

case plan: LogicalPlan if plan.resolved =>
// Extract any PythonUDFs from the current operator.
val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
if (udfs.isEmpty) {
// If there aren't any, we are done.
plan
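
Taken together, `canEvaluateInPython` admits either a pure chain of unary Python UDFs or a UDF whose inputs contain no other Python UDFs; a query like `udf1(udf2(a) + udf3(b))` therefore has its inner UDFs extracted in earlier batches, as the explain output at the top shows. A toy, hedged model of the predicate, again with made-up expression stand-ins:

```python
from collections import namedtuple

# Made-up stand-ins for expression nodes, for illustration only.
Udf = namedtuple("Udf", ["name", "children"])
Add = namedtuple("Add", ["children"])
Col = namedtuple("Col", ["name"])

def has_udf(expr):
    return isinstance(expr, Udf) or any(
        has_udf(c) for c in getattr(expr, "children", []))

def can_evaluate_in_python(udf):
    if len(udf.children) == 1 and isinstance(udf.children[0], Udf):
        return can_evaluate_in_python(udf.children[0])  # a pure unary chain
    return not any(has_udf(c) for c in udf.children)    # no UDFs hidden in the inputs

chain_expr = Udf("udf1", [Udf("udf2", [Col("a")])])
mixed = Udf("udf1", [Add([Udf("udf2", [Col("a")]), Udf("udf3", [Col("b")])])])
print(can_evaluate_in_python(chain_expr))  # True: evaluated as one batch
print(can_evaluate_in_python(mixed))       # False: inner UDFs extracted first
```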
