Skip to content

Commit

Permalink
[SPARK-42051][SQL] Codegen Support for HiveGenericUDF
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

As a subtask of SPARK-42050, this PR adds Codegen Support for `HiveGenericUDF`

### Why are the changes needed?

improve codegen coverage and performance

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

new UT added

Closes apache#39555 from yaooqinn/SPARK-42051.

Authored-by: Kent Yao <yao@apache.org>
Signed-off-by: Kent Yao <yao@apache.org>
  • Loading branch information
yaooqinn committed Feb 1, 2023
1 parent b0ac061 commit 34fb408
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 10 deletions.
62 changes: 53 additions & 9 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.hive.HiveShim._
import org.apache.spark.sql.types._

Expand Down Expand Up @@ -120,19 +121,18 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp
extends DeferredObject with HiveInspectors {

private val wrapper = wrapperFor(oi, dataType)
private var func: () => Any = _
def set(func: () => Any): Unit = {
private var func: Any = _
def set(func: Any): Unit = {
this.func = func
}
override def prepare(i: Int): Unit = {}
override def get(): AnyRef = wrapper(func()).asInstanceOf[AnyRef]
override def get(): AnyRef = wrapper(func).asInstanceOf[AnyRef]
}

private[hive] case class HiveGenericUDF(
name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
extends Expression
with HiveInspectors
with CodegenFallback
with Logging
with UserDefinedExpression {

Expand All @@ -154,18 +154,20 @@ private[hive] case class HiveGenericUDF(
function.initializeAndFoldConstants(argumentInspectors.toArray)
}

// Visible for codegen
@transient
private lazy val unwrapper = unwrapperFor(returnInspector)
lazy val unwrapper: Any => Any = unwrapperFor(returnInspector)

@transient
private lazy val isUDFDeterministic = {
val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
udfType != null && udfType.deterministic() && !udfType.stateful()
}

// Visible for codegen
@transient
private lazy val deferredObjects = argumentInspectors.zip(children).map { case (inspect, child) =>
new DeferredObjectAdapter(inspect, child.dataType)
lazy val deferredObjects: Array[DeferredObject] = argumentInspectors.zip(children).map {
case (inspect, child) => new DeferredObjectAdapter(inspect, child.dataType)
}.toArray[DeferredObject]

override lazy val dataType: DataType = inspectorToDataType(returnInspector)
Expand All @@ -178,7 +180,7 @@ private[hive] case class HiveGenericUDF(
while (i < length) {
val idx = i
deferredObjects(i).asInstanceOf[DeferredObjectAdapter]
.set(() => children(idx).eval(input))
.set(children(idx).eval(input))
i += 1
}
unwrapper(function.evaluate(deferredObjects))
Expand All @@ -192,6 +194,48 @@ private[hive] case class HiveGenericUDF(

override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Expression =
copy(children = newChildren)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val refTerm = ctx.addReferenceObj("this", this)
val childrenEvals = children.map(_.genCode(ctx))

val setDeferredObjects = childrenEvals.zipWithIndex.map {
case (eval, i) =>
val deferredObjectAdapterClz = classOf[DeferredObjectAdapter].getCanonicalName
s"""
|if (${eval.isNull}) {
| (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(null);
|} else {
| (($deferredObjectAdapterClz) $refTerm.deferredObjects()[$i]).set(${eval.value});
|}
|""".stripMargin
}

val resultType = CodeGenerator.boxedType(dataType)
val resultTerm = ctx.freshName("result")
ev.copy(code =
code"""
|${childrenEvals.map(_.code).mkString("\n")}
|${setDeferredObjects.mkString("\n")}
|$resultType $resultTerm = null;
|boolean ${ev.isNull} = false;
|try {
| $resultTerm = ($resultType) $refTerm.unwrapper().apply(
| $refTerm.function().evaluate($refTerm.deferredObjects()));
| ${ev.isNull} = $resultTerm == null;
|} catch (Throwable e) {
| throw QueryExecutionErrors.failedExecuteUserDefinedFunctionError(
| "${funcWrapper.functionClassName}",
| "${children.map(_.dataType.catalogString).mkString(", ")}",
| "${dataType.catalogString}",
| e);
|}
|${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
|""".stripMargin
)
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectIns
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.io.{LongWritable, Writable}

import org.apache.spark.{SparkFiles, TestUtils}
import org.apache.spark.{SparkException, SparkFiles, TestUtils}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.functions.max
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -711,6 +712,37 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
}
}
}

test("SPARK-42051: HiveGenericUDF Codegen Support") {
withUserDefinedFunction("CodeGenHiveGenericUDF" -> false) {
sql(s"CREATE FUNCTION CodeGenHiveGenericUDF AS '${classOf[GenericUDFMaskHash].getName}'")
withTable("HiveGenericUDFTable") {
sql(s"create table HiveGenericUDFTable as select 'Spark SQL' as v")
val df = sql("SELECT CodeGenHiveGenericUDF(v) from HiveGenericUDFTable")
val plan = df.queryExecution.executedPlan
assert(plan.isInstanceOf[WholeStageCodegenExec])
checkAnswer(df, Seq(Row("14ab8df5135825bc9f5ff7c30609f02f")))
}
}
}

test("SPARK-42051: HiveGenericUDF Codegen Support w/ execution failure") {
withUserDefinedFunction("CodeGenHiveGenericUDF" -> false) {
sql(s"CREATE FUNCTION CodeGenHiveGenericUDF AS '${classOf[GenericUDFAssertTrue].getName}'")
withTable("HiveGenericUDFTable") {
sql(s"create table HiveGenericUDFTable as select false as v")
val df = sql("SELECT CodeGenHiveGenericUDF(v) from HiveGenericUDFTable")
val e = intercept[SparkException](df.collect()).getCause.asInstanceOf[SparkException]
checkError(
e,
"FAILED_EXECUTE_UDF",
parameters = Map(
"functionName" -> s"${classOf[GenericUDFAssertTrue].getName}",
"signature" -> "boolean",
"result" -> "void"))
}
}
}
}

class TestPair(x: Int, y: Int) extends Writable with Serializable {
Expand Down

0 comments on commit 34fb408

Please sign in to comment.