[SPARK-4785][SQL] Initialize Hive UDFs on the driver and serialize them with a wrapper

Unlike in Hive 0.12.0, UDF/UDAF/UDTF (i.e. Hive function) objects in Hive 0.13.1 should be initialized only once, on the driver side, and then serialized to executors. However, not all function objects are serializable (e.g. GenericUDF doesn't implement Serializable). Hive 0.13.1 solves this with a Kryo or XML serializer, and provides several utility ser/de methods in class o.a.h.h.q.e.Utilities for the purpose. This PR uses Kryo for efficiency. The Kryo serializer used here is the one created by Hive; Spark's own Kryo serializer isn't used because no SparkConf instance is available at that point.
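
Editorial aside: the Scala sketch below shows, in isolation, the reflective access to Hive's private Kryo helpers that this commit relies on. The method names, their signatures, and the Utilities.runtimeSerializationKryo thread-local are the ones invoked by the HiveFunctionWrapper added in Shim13.scala further down; the object name KryoUdfSerDeSketch and the roundTrip helper are made-up names for illustration only, and the snippet assumes Hive 0.13.1 (and its Kryo dependency) is on the classpath.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream, OutputStream}

import com.esotericsoftware.kryo.Kryo
import org.apache.hadoop.hive.ql.exec.Utilities

object KryoUdfSerDeSketch {
  // Hive keeps these helpers private, so they are looked up and opened via reflection,
  // the same way the HiveFunctionWrapper added by this PR does.
  private val serializeMethod = {
    val m = classOf[Utilities].getDeclaredMethod(
      "serializeObjectByKryo", classOf[Kryo], classOf[Object], classOf[OutputStream])
    m.setAccessible(true)
    m
  }

  private val deserializeMethod = {
    val m = classOf[Utilities].getDeclaredMethod(
      "deserializeObjectByKryo", classOf[Kryo], classOf[InputStream], classOf[Class[_]])
    m.setAccessible(true)
    m
  }

  // Round-trips any object (Serializable or not) through Hive's runtime Kryo instance.
  def roundTrip[T <: AnyRef](obj: T, clazz: Class[_]): T = {
    val out = new ByteArrayOutputStream()
    serializeMethod.invoke(null, Utilities.runtimeSerializationKryo.get(), obj, out)
    val in = new ByteArrayInputStream(out.toByteArray)
    deserializeMethod.invoke(null, Utilities.runtimeSerializationKryo.get(), in, clazz)
      .asInstanceOf[T]
  }
}

The wrapper added by this commit embeds exactly this trick inside an Externalizable class, so that Spark's ordinary Java serialization can ship initialized, non-Serializable UDF instances from the driver to executors.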

Author: Cheng Hao <hao.cheng@intel.com>
Author: Cheng Lian <lian@databricks.com>

Closes #3640 from chenghao-intel/udf_serde and squashes the following commits:

8e13756 [Cheng Hao] Update the comment
74466a3 [Cheng Hao] refactor as feedbacks
396c0e1 [Cheng Hao] avoid Simple UDF to be serialized
e9c3212 [Cheng Hao] update the comment
19cbd46 [Cheng Hao] support udf instance ser/de after initialization

(cherry picked from commit 383c555)
Signed-off-by: Michael Armbrust <michael@databricks.com>
chenghao-intel authored and marmbrus committed Dec 9, 2014
1 parent 31a6d4f commit e686742
Showing 5 changed files with 173 additions and 50 deletions.
@@ -1128,7 +1128,10 @@ private[hive] object HiveQl {
Explode(attributes, nodeToExpr(child))

case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
HiveGenericUdtf(functionName, attributes, children.map(nodeToExpr))
HiveGenericUdtf(
new HiveFunctionWrapper(functionName),
attributes,
children.map(nodeToExpr))

case a: ASTNode =>
throw new NotImplementedError(
93 changes: 44 additions & 49 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -54,46 +54,31 @@ private[hive] abstract class HiveFunctionRegistry
val functionClassName = functionInfo.getFunctionClass.getName

if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveSimpleUdf(functionClassName, children)
HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdf(functionClassName, children)
HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
} else if (
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdaf(functionClassName, children)
HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUdaf(functionClassName, children)
HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveGenericUdtf(functionClassName, Nil, children)
HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
}
}

private[hive] trait HiveFunctionFactory {
val functionClassName: String

def createFunction[UDFType]() =
getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
}

private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
self: Product =>

type UDFType
private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with Logging {
type EvaluatedType = Any
type UDFType = UDF

def nullable = true

lazy val function = createFunction[UDFType]()

override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
}

private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
extends HiveUdf with HiveInspectors {

type UDFType = UDF
@transient
lazy val function = funcWrapper.createFunction[UDFType]()

@transient
protected lazy val method =
@@ -131,6 +116,8 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
.convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
returnInspector)
}

override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}

// Adapter from Catalyst ExpressionResult to Hive DeferredObject
@@ -144,16 +131,23 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
override def get(): AnyRef = wrap(func(), oi)
}

private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
extends HiveUdf with HiveInspectors {
private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression with HiveInspectors with Logging {
type UDFType = GenericUDF
type EvaluatedType = Any

def nullable = true

@transient
lazy val function = funcWrapper.createFunction[UDFType]()

@transient
protected lazy val argumentInspectors = children.map(toInspector)

@transient
protected lazy val returnInspector =
protected lazy val returnInspector = {
function.initializeAndFoldConstants(argumentInspectors.toArray)
}

@transient
protected lazy val isUDFDeterministic = {
@@ -183,18 +177,19 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
}
unwrap(function.evaluate(deferedObjects), returnInspector)
}

override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}

private[hive] case class HiveGenericUdaf(
functionClassName: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression]) extends AggregateExpression
with HiveInspectors
with HiveFunctionFactory {
with HiveInspectors {

type UDFType = AbstractGenericUDAFResolver

@transient
protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()

@transient
protected lazy val objectInspector = {
@@ -209,22 +204,22 @@ private[hive] case class HiveGenericUdaf(

def nullable: Boolean = true

override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"

def newInstance() = new HiveUdafFunction(functionClassName, children, this)
def newInstance() = new HiveUdafFunction(funcWrapper, children, this)
}

/** It is used as a wrapper for Hive functions that use the UDAF interface. */
private[hive] case class HiveUdaf(
functionClassName: String,
funcWrapper: HiveFunctionWrapper,
children: Seq[Expression]) extends AggregateExpression
with HiveInspectors
with HiveFunctionFactory {
with HiveInspectors {

type UDFType = UDAF

@transient
protected lazy val resolver: AbstractGenericUDAFResolver = new GenericUDAFBridge(createFunction())
protected lazy val resolver: AbstractGenericUDAFResolver =
new GenericUDAFBridge(funcWrapper.createFunction())

@transient
protected lazy val objectInspector = {
@@ -239,10 +234,10 @@ private[hive] case class HiveUdaf(

def nullable: Boolean = true

override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"

def newInstance() =
new HiveUdafFunction(functionClassName, children, this, true)
new HiveUdafFunction(funcWrapper, children, this, true)
}

/**
@@ -257,13 +252,13 @@ private[hive] case class HiveUdaf(
* user defined aggregations, which have clean semantics even in a partitioned execution.
*/
private[hive] case class HiveGenericUdtf(
functionClassName: String,
funcWrapper: HiveFunctionWrapper,
aliasNames: Seq[String],
children: Seq[Expression])
extends Generator with HiveInspectors with HiveFunctionFactory {
extends Generator with HiveInspectors {

@transient
protected lazy val function: GenericUDTF = createFunction()
protected lazy val function: GenericUDTF = funcWrapper.createFunction()

@transient
protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
@@ -320,25 +315,24 @@ private[hive] case class HiveGenericUdtf(
}
}

override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
override def toString = s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
}

private[hive] case class HiveUdafFunction(
functionClassName: String,
funcWrapper: HiveFunctionWrapper,
exprs: Seq[Expression],
base: AggregateExpression,
isUDAFBridgeRequired: Boolean = false)
extends AggregateFunction
with HiveInspectors
with HiveFunctionFactory {
with HiveInspectors {

def this() = this(null, null, null)

private val resolver =
if (isUDAFBridgeRequired) {
new GenericUDAFBridge(createFunction[UDAF]())
new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
} else {
createFunction[AbstractGenericUDAFResolver]()
funcWrapper.createFunction[AbstractGenericUDAFResolver]()
}

private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
@@ -361,3 +355,4 @@ private[hive] case class HiveUdafFunction(
function.iterate(buffer, inputs)
}
}

@@ -60,6 +60,13 @@ class HiveUdfSuite extends QueryTest {
| getStruct(1).f5 FROM src LIMIT 1
""".stripMargin).first() === Row(1, 2, 3, 4, 5))
}

test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
checkAnswer(
sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
8
)
}

test("hive struct udf") {
sql(
@@ -43,6 +43,17 @@ import scala.language.implicitConversions

import org.apache.spark.sql.catalyst.types.DecimalType

class HiveFunctionWrapper(var functionClassName: String) extends java.io.Serializable {
// for Serialization
def this() = this(null)

import org.apache.spark.util.Utils._
def createFunction[UDFType <: AnyRef](): UDFType = {
getContextOrSparkClassLoader
.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
}
}

/**
* A compatibility layer for interacting with Hive version 0.12.0.
*/
107 changes: 107 additions & 0 deletions sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.hive

import java.util.{ArrayList => JArrayList}
import java.util.Properties

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.InputFormat
@@ -42,6 +43,112 @@ import org.apache.spark.sql.catalyst.types.decimal.Decimal
import scala.collection.JavaConversions._
import scala.language.implicitConversions


/**
* This class provides UDF creation and also UDF instance serialization and
* de-serialization across process boundaries.
*
* A detailed discussion can be found at https://github.com/apache/spark/pull/3640
*
* @param functionClassName UDF class name
*/
class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable {
// for Serialization
def this() = this(null)

import java.io.{OutputStream, InputStream}
import com.esotericsoftware.kryo.Kryo
import org.apache.spark.util.Utils._
import org.apache.hadoop.hive.ql.exec.Utilities
import org.apache.hadoop.hive.ql.exec.UDF

@transient
private val methodDeSerialize = {
val method = classOf[Utilities].getDeclaredMethod(
"deserializeObjectByKryo",
classOf[Kryo],
classOf[InputStream],
classOf[Class[_]])
method.setAccessible(true)

method
}

@transient
private val methodSerialize = {
val method = classOf[Utilities].getDeclaredMethod(
"serializeObjectByKryo",
classOf[Kryo],
classOf[Object],
classOf[OutputStream])
method.setAccessible(true)

method
}

def deserializePlan[UDFType](is: java.io.InputStream, clazz: Class[_]): UDFType = {
methodDeSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), is, clazz)
.asInstanceOf[UDFType]
}

def serializePlan(function: AnyRef, out: java.io.OutputStream): Unit = {
methodSerialize.invoke(null, Utilities.runtimeSerializationKryo.get(), function, out)
}

private var instance: AnyRef = null

def writeExternal(out: java.io.ObjectOutput) {
// output the function name
out.writeUTF(functionClassName)

// Write a flag indicating whether instance is null
out.writeBoolean(instance != null)
if (instance != null) {
// Some UDFs are serializable, but others are not;
// Hive Utilities can handle both cases
val baos = new java.io.ByteArrayOutputStream()
serializePlan(instance, baos)
val functionInBytes = baos.toByteArray

// output the function bytes
out.writeInt(functionInBytes.length)
out.write(functionInBytes, 0, functionInBytes.length)
}
}

def readExternal(in: java.io.ObjectInput) {
// read the function name
functionClassName = in.readUTF()

if (in.readBoolean()) {
// if the instance is not null
// read the function in bytes
val functionInBytesLength = in.readInt()
val functionInBytes = new Array[Byte](functionInBytesLength)
in.read(functionInBytes, 0, functionInBytesLength)

// deserialize the function object via Hive Utilities
instance = deserializePlan[AnyRef](new java.io.ByteArrayInputStream(functionInBytes),
getContextOrSparkClassLoader.loadClass(functionClassName))
}
}

def createFunction[UDFType <: AnyRef](): UDFType = {
if (instance != null) {
instance.asInstanceOf[UDFType]
} else {
val func = getContextOrSparkClassLoader
.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
if (!func.isInstanceOf[UDF]) {
// We cache the function if it's not a simple UDF,
// as we always have to create a new instance for simple UDFs
instance = func
}
func
}
}
}

/**
* A compatibility layer for interacting with Hive version 0.13.1.
*/
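
Editorial aside (not part of the diff): once the Hive 0.13.1 HiveFunctionWrapper above has created and cached a non-simple UDF on the driver, ordinary Java serialization of the wrapper is enough to ship that instance, with whatever state initialization has set on it, to executors, because writeExternal/readExternal carry it as Kryo-encoded bytes. In the REPL-style usage sketch below, com.example.udf.MyGenericUDF is a placeholder for any GenericUDF implementation actually present on the classpath.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Driver side: wrap a (hypothetical) GenericUDF by class name and instantiate it once.
val wrapper = new HiveFunctionWrapper("com.example.udf.MyGenericUDF")
val driverSideUdf = wrapper.createFunction[AnyRef]()  // cached inside the wrapper, since it's not a simple UDF

// Simulate the driver -> executor hop with plain Java serialization.
val bytes = {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(wrapper)  // writeExternal() Kryo-encodes the cached instance
  oos.close()
  bos.toByteArray
}

// Executor side: the wrapper is restored and hands back the Kryo-restored instance.
val restored = new ObjectInputStream(new ByteArrayInputStream(bytes))
  .readObject().asInstanceOf[HiveFunctionWrapper]
val executorSideUdf = restored.createFunction[AnyRef]()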
Expand Down
