diff --git a/datagen/pom.xml b/datagen/pom.xml
index 20b3403d3e1..9bdf897cfd7 100644
--- a/datagen/pom.xml
+++ b/datagen/pom.xml
@@ -33,6 +33,7 @@
**/*
package
+ ${project.build.outputDirectory}/datagen-version-info.properties
diff --git a/df_udf/README.md b/df_udf/README.md
new file mode 100644
index 00000000000..0226c365a42
--- /dev/null
+++ b/df_udf/README.md
@@ -0,0 +1,90 @@
+# Scala / Java UDFs implemented using Dataframe operations
+
+User Defined Functions (UDFs) are used for a number of reasons in Apache Spark. Much of the time it is to implement
+logic that is either very difficult or impossible to implement using existing SQL/Dataframe APIs directly. But they
+are also used as a way to standardize processing logic across an organization or for code reuse.
+
+But UDFs come with some downsides. The biggest one is the lack of visibility into the processing being done. SQL is a
+language that can be highly optimized, but a UDF in most cases is a black box that the SQL optimizer cannot do
+anything with. This can result in less than ideal query planning. Additionally, accelerated execution environments,
+like the RAPIDS Accelerator for Apache Spark, have no easy way to replace UDFs with accelerated versions, which can
+result in slow performance.
+
+This plugin attempts to restore that visibility for the code reuse use case by providing a way to implement a UDF in
+terms of Dataframe commands.
+
+## Setup
+
+To do this, include `com.nvidia:df_udf_plugin` as a dependency in your project and also include it on the
+classpath for your Apache Spark environment. Then add `com.nvidia.spark.DFUDFPlugin` to the
+`spark.sql.extensions` config. Now you can implement a UDF in terms of Dataframe operations.
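+
+As a minimal sketch (the session setup shown is illustrative and assumes the plugin jar is already on the
+classpath), the extension can also be set programmatically when building a `SparkSession`:
+
+```scala
+import org.apache.spark.sql.SparkSession
+
+// Register the df_udf plugin as a SQL extension for this session.
+val spark = SparkSession.builder()
+  .appName("df_udf example")
+  .config("spark.sql.extensions", "com.nvidia.spark.DFUDFPlugin")
+  .getOrCreate()
+```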
+
+## Usage
+
+```scala
+import com.nvidia.spark.functions._
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.functions._
+
+val sum_array = df_udf((longArray: Column) =>
+ aggregate(longArray,
+ lit(0L),
+ (a, b) => coalesce(a, lit(0L)) + coalesce(b, lit(0L)),
+ a => a))
+spark.udf.register("sum_array", sum_array)
+```
+
+You can then use `sum_array` however you would have used any other UDF. This allows you to provide a drop-in
+replacement implementation of an existing UDF.
+
+```scala
+Seq(Array(1L, 2L, 3L)).toDF("data").selectExpr("sum_array(data) as result").show()
+
++------+
+|result|
++------+
+| 6|
++------+
+```
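+
+The Java UDF interfaces (`UDF0` through `UDF10`) are supported as well. As a sketch, here is the same pattern
+using the Java-style API from Scala (the `add_cols` name is just illustrative):
+
+```scala
+import com.nvidia.spark.functions._
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.api.java.UDF2
+
+// The body is still written in terms of Column operations, just behind the java.api interface.
+val add_cols = df_udf(new UDF2[Column, Column, Column] {
+  override def call(a: Column, b: Column): Column = a + b
+})
+spark.udf.register("add_cols", add_cols)
+```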
+
+## Type Checks
+
+DataFrame APIs do not provide type safety when writing the code, and that is the same here. There are no built-in type
+checks for inputs yet. Also, because of how types are resolved in Spark, there is no way to adjust the query based on
+the types passed in. Type checks are handled by the SQL planner/optimizer after the UDF has been replaced. This means
+that the final SQL will not violate any type safety, but it also means that the errors might be confusing. For example,
+if I passed in an `ARRAY<DOUBLE>` to `sum_array` instead of an `ARRAY<LONG>` I would get an error like
+
+```scala
+Seq(Array(1.0, 2.0, 3.0)).toDF("data").selectExpr("sum_array(data) as result").show()
+org.apache.spark.sql.AnalysisException: [DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE] Cannot resolve "aggregate(data, 0, lambdafunction((coalesce(namedlambdavariable(), 0) + coalesce(namedlambdavariable(), 0)), namedlambdavariable(), namedlambdavariable()), lambdafunction(namedlambdavariable(), namedlambdavariable()))" due to data type mismatch: Parameter 3 requires the "BIGINT" type, however "lambdafunction((coalesce(namedlambdavariable(), 0) + coalesce(namedlambdavariable(), 0)), namedlambdavariable(), namedlambdavariable())" has the type "DOUBLE".; line 1 pos 0;
+Project [aggregate(data#46, 0, lambdafunction((cast(coalesce(lambda x_9#49L, 0) as double) + coalesce(lambda y_10#50, cast(0 as double))), lambda x_9#49L, lambda y_10#50, false), lambdafunction(lambda x_11#51L, lambda x_11#51L, false)) AS result#48L]
++- Project [value#43 AS data#46]
+ +- LocalRelation [value#43]
+
+ at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.dataTypeMismatch(package.scala:73)
+ at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5(CheckAnalysis.scala:269)
+ at org.apache.spark.sql.catalyst.analysis.CheckAnalysis.$anonfun$checkAnalysis0$5$adapted(CheckAnalysis.scala:256)
+```
+
+This is not as simple to understand as the error you would get from a normal UDF.
+
+```scala
+val sum_array = udf((a: Array[Long]) => a.sum)
+
+spark.udf.register("sum_array", sum_array)
+
+Seq(Array(1.0, 2.0, 3.0)).toDF("data").selectExpr("sum_array(data) as result").show()
+org.apache.spark.sql.AnalysisException: [CANNOT_UP_CAST_DATATYPE] Cannot up cast array element from "DOUBLE" to "BIGINT".
+ The type path of the target object is:
+- array element class: "long"
+- root class: "[J"
+You can either add an explicit cast to the input data or choose a higher precision type of the field in the target object
+at org.apache.spark.sql.errors.QueryCompilationErrors$.upCastFailureError(QueryCompilationErrors.scala:285)
+at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$.org$apache$spark$sql$catalyst$analysis$Analyzer$ResolveUpCast$$fail(Analyzer.scala:3646)
+at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$57$$anonfun$applyOrElse$234.applyOrElse(Analyzer.scala:3677)
+at org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveUpCast$$anonfun$apply$57$$anonfun$applyOrElse$234.applyOrElse(Analyzer.scala:3654)
+```
+
+We hope to add optional type checks in the future.
\ No newline at end of file
diff --git a/df_udf/pom.xml b/df_udf/pom.xml
new file mode 100644
index 00000000000..39f33880f34
--- /dev/null
+++ b/df_udf/pom.xml
@@ -0,0 +1,88 @@
+
+
+
+ 4.0.0
+
+ com.nvidia
+ rapids-4-spark-shim-deps-parent_2.12
+ 24.12.0-SNAPSHOT
+ ../shim-deps/pom.xml
+
+ df_udf_plugin_2.12
+ UDFs implemented in SQL/Dataframe
+ UDFs for Apache Spark implemented in SQL/Dataframe
+ 24.12.0-SNAPSHOT
+
+
+ df_udf
+
+ **/*
+ package
+ ${project.build.outputDirectory}/df_udf-version-info.properties
+
+
+
+
+ org.scala-lang
+ scala-library
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ test
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.test.version}
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
+ true
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+
+
+ org.scalatest
+ scalatest-maven-plugin
+
+
+ org.apache.rat
+ apache-rat-plugin
+
+
+
+
+
+
+ ${project.build.directory}/extra-resources
+
+
+
+
diff --git a/df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala b/df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala
new file mode 100644
index 00000000000..7e1c0451c8a
--- /dev/null
+++ b/df_udf/src/main/scala/com/nvidia/spark/DFUDFPlugin.scala
@@ -0,0 +1,31 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark
+
+import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
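+/**
+ * SQL extension entry point for df_udf. Add `com.nvidia.spark.DFUDFPlugin` to
+ * `spark.sql.extensions` to enable it. It injects a resolution rule that replaces df_udf
+ * placeholders with the Column-based implementations they wrap.
+ */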
+class DFUDFPlugin extends (SparkSessionExtensions => Unit) {
+ override def apply(extensions: SparkSessionExtensions): Unit = {
+ extensions.injectResolutionRule(logicalPlanRules)
+ }
+
+ def logicalPlanRules(sparkSession: SparkSession): Rule[LogicalPlan] = {
+ org.apache.spark.sql.nvidia.LogicalPlanRules()
+ }
+}
\ No newline at end of file
diff --git a/df_udf/src/main/scala/com/nvidia/spark/functions.scala b/df_udf/src/main/scala/com/nvidia/spark/functions.scala
new file mode 100644
index 00000000000..8c8eef3f825
--- /dev/null
+++ b/df_udf/src/main/scala/com/nvidia/spark/functions.scala
@@ -0,0 +1,232 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.api.java.{UDF0, UDF1, UDF10, UDF2, UDF3, UDF4, UDF5, UDF6, UDF7, UDF8, UDF9}
+import org.apache.spark.sql.expressions.UserDefinedFunction
+import org.apache.spark.sql.functions.udf
+import org.apache.spark.sql.nvidia._
+import org.apache.spark.sql.types.LongType
+
+// scalastyle:off
+object functions {
+// scalastyle:on
+
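+ // Note: the LongType passed to `udf(...)` below is only a placeholder return type. The wrapped
+ // DFUDF is replaced with its Column expression during analysis (see LogicalPlanRules), so the
+ // actual result type comes from that expression rather than from this declaration.
+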
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function0[Column]): UserDefinedFunction =
+ udf(DFUDF0(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function1[Column, Column]): UserDefinedFunction =
+ udf(DFUDF1(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function2[Column, Column, Column]): UserDefinedFunction =
+ udf(DFUDF2(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function3[Column, Column, Column, Column]): UserDefinedFunction =
+ udf(DFUDF3(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function4[Column, Column, Column, Column, Column]): UserDefinedFunction =
+ udf(DFUDF4(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction =
+ udf(DFUDF5(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function6[Column, Column, Column, Column, Column, Column,
+ Column]): UserDefinedFunction =
+ udf(DFUDF6(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function7[Column, Column, Column, Column, Column, Column,
+ Column, Column]): UserDefinedFunction =
+ udf(DFUDF7(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function8[Column, Column, Column, Column, Column, Column,
+ Column, Column, Column]): UserDefinedFunction =
+ udf(DFUDF8(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function9[Column, Column, Column, Column, Column, Column,
+ Column, Column, Column, Column]): UserDefinedFunction =
+ udf(DFUDF9(f), LongType)
+
+ /**
+ * Defines a Scala closure of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to
+ * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: Function10[Column, Column, Column, Column, Column, Column,
+ Column, Column, Column, Column, Column]): UserDefinedFunction =
+ udf(DFUDF10(f), LongType)
+
+
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // Java UDF functions
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF0[Column]): UserDefinedFunction = {
+ udf(JDFUDF0(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF1[Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF1(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF2[Column, Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF2(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF3[Column, Column, Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF3(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF4[Column, Column, Column, Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF4(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF5[Column, Column, Column, Column, Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF5(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF6[Column, Column, Column, Column, Column, Column,
+ Column]): UserDefinedFunction = {
+ udf(JDFUDF6(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF7[Column, Column, Column, Column, Column, Column,
+ Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF7(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF8[Column, Column, Column, Column, Column, Column,
+ Column, Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF8(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF9[Column, Column, Column, Column, Column, Column,
+ Column, Column, Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF9(f), LongType)
+ }
+
+ /**
+ * Defines a Java UDF instance of Columns as user-defined function (UDF).
+ * By default the returned UDF is deterministic. To change it to nondeterministic, call the
+ * API `UserDefinedFunction.asNondeterministic()`.
+ */
+ def df_udf(f: UDF10[Column, Column, Column, Column, Column, Column,
+ Column, Column, Column, Column, Column]): UserDefinedFunction = {
+ udf(JDFUDF10(f), LongType)
+ }
+
+}
\ No newline at end of file
diff --git a/df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala b/df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala
new file mode 100644
index 00000000000..24a123016d6
--- /dev/null
+++ b/df_udf/src/main/scala/org/apache/spark/sql/nvidia/LogicalPlanRules.scala
@@ -0,0 +1,37 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.nvidia
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+
+case class LogicalPlanRules() extends Rule[LogicalPlan] with Logging {
+ val replacePartialFunc: PartialFunction[Expression, Expression] = {
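+ // A ScalaUDF that wraps a DFUDF is only a placeholder: swap it for the expression built by
+ // applying the df_udf body to the UDF's children converted to Columns.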
+ case f: ScalaUDF if DFUDF.getDFUDF(f.function).isDefined =>
+ DFUDF.getDFUDF(f.function).map {
+ dfudf => DFUDFShims.columnToExpr(
+ dfudf(f.children.map(DFUDFShims.exprToColumn(_)).toArray))
+ }.getOrElse{
+ throw new IllegalStateException("Inconsistent results when extracting df_udf")
+ }
+ }
+
+ override def apply(plan: LogicalPlan): LogicalPlan =
+ plan.transformExpressions(replacePartialFunc)
+}
diff --git a/df_udf/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala b/df_udf/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala
new file mode 100644
index 00000000000..79f71ba4ca0
--- /dev/null
+++ b/df_udf/src/main/scala/org/apache/spark/sql/nvidia/dataframe_udfs.scala
@@ -0,0 +1,340 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.nvidia
+
+import java.lang.invoke.SerializedLambda
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.api.java._
+import org.apache.spark.util.Utils
+
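+/**
+ * A UDF whose implementation is expressed as Column (Dataframe) operations over its inputs.
+ * Instances are never executed row by row; LogicalPlanRules replaces the wrapping ScalaUDF with
+ * the Column returned by `apply` during analysis.
+ */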
+trait DFUDF {
+ def apply(input: Array[Column]): Column
+}
+
+case class DFUDF0(f: Function0[Column])
+ extends UDF0[Any] with DFUDF {
+ override def call(): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 0)
+ f()
+ }
+}
+
+case class DFUDF1(f: Function1[Column, Column])
+ extends UDF1[Any, Any] with DFUDF {
+ override def call(t1: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 1)
+ f(input(0))
+ }
+}
+
+case class DFUDF2(f: Function2[Column, Column, Column])
+ extends UDF2[Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 2)
+ f(input(0), input(1))
+ }
+}
+
+case class DFUDF3(f: Function3[Column, Column, Column, Column])
+ extends UDF3[Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 3)
+ f(input(0), input(1), input(2))
+ }
+}
+
+case class DFUDF4(f: Function4[Column, Column, Column, Column, Column])
+ extends UDF4[Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 4)
+ f(input(0), input(1), input(2), input(3))
+ }
+}
+
+case class DFUDF5(f: Function5[Column, Column, Column, Column, Column, Column])
+ extends UDF5[Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 5)
+ f(input(0), input(1), input(2), input(3), input(4))
+ }
+}
+
+case class DFUDF6(f: Function6[Column, Column, Column, Column, Column, Column, Column])
+ extends UDF6[Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 6)
+ f(input(0), input(1), input(2), input(3), input(4), input(5))
+ }
+}
+
+case class DFUDF7(f: Function7[Column, Column, Column, Column, Column, Column, Column, Column])
+ extends UDF7[Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 7)
+ f(input(0), input(1), input(2), input(3), input(4), input(5), input(6))
+ }
+}
+
+case class DFUDF8(f: Function8[Column, Column, Column, Column, Column, Column, Column, Column,
+ Column])
+ extends UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 8)
+ f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7))
+ }
+}
+
+case class DFUDF9(f: Function9[Column, Column, Column, Column, Column, Column, Column, Column,
+ Column, Column])
+ extends UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any,
+ t9: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 9)
+ f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8))
+ }
+}
+
+case class DFUDF10(f: Function10[Column, Column, Column, Column, Column, Column, Column, Column,
+ Column, Column, Column])
+ extends UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any,
+ t9: Any, t10: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 10)
+ f(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8),
+ input(9))
+ }
+}
+
+case class JDFUDF0(f: UDF0[Column])
+ extends UDF0[Any] with DFUDF {
+ override def call(): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 0)
+ f.call()
+ }
+}
+
+case class JDFUDF1(f: UDF1[Column, Column])
+ extends UDF1[Any, Any] with DFUDF {
+ override def call(t1: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 1)
+ f.call(input(0))
+ }
+}
+
+case class JDFUDF2(f: UDF2[Column, Column, Column])
+ extends UDF2[Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 2)
+ f.call(input(0), input(1))
+ }
+}
+
+case class JDFUDF3(f: UDF3[Column, Column, Column, Column])
+ extends UDF3[Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 3)
+ f.call(input(0), input(1), input(2))
+ }
+}
+
+case class JDFUDF4(f: UDF4[Column, Column, Column, Column, Column])
+ extends UDF4[Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 4)
+ f.call(input(0), input(1), input(2), input(3))
+ }
+}
+
+case class JDFUDF5(f: UDF5[Column, Column, Column, Column, Column, Column])
+ extends UDF5[Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 5)
+ f.call(input(0), input(1), input(2), input(3), input(4))
+ }
+}
+
+case class JDFUDF6(f: UDF6[Column, Column, Column, Column, Column, Column, Column])
+ extends UDF6[Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 6)
+ f.call(input(0), input(1), input(2), input(3), input(4), input(5))
+ }
+}
+
+case class JDFUDF7(f: UDF7[Column, Column, Column, Column, Column, Column, Column, Column])
+ extends UDF7[Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 7)
+ f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6))
+ }
+}
+
+case class JDFUDF8(f: UDF8[Column, Column, Column, Column, Column, Column, Column, Column,
+ Column])
+ extends UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 8)
+ f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7))
+ }
+}
+
+case class JDFUDF9(f: UDF9[Column, Column, Column, Column, Column, Column, Column, Column,
+ Column, Column])
+ extends UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any,
+ t9: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 9)
+ f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8))
+ }
+}
+
+case class JDFUDF10(f: UDF10[Column, Column, Column, Column, Column, Column, Column, Column,
+ Column, Column, Column])
+ extends UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any] with DFUDF {
+ override def call(t1: Any, t2: Any, t3: Any, t4: Any, t5: Any, t6: Any, t7: Any, t8: Any,
+ t9: Any, t10: Any): Any = {
+ throw new IllegalStateException("TODO better error message. This should have been replaced")
+ }
+
+ override def apply(input: Array[Column]): Column = {
+ assert(input.length == 10)
+ f.call(input(0), input(1), input(2), input(3), input(4), input(5), input(6), input(7), input(8),
+ input(9))
+ }
+}
+
+object DFUDF {
+ /**
+ * Determine if the UDF function implements DFUDF and, if so, return it.
+ */
+ def getDFUDF(function: AnyRef): Option[DFUDF] = {
+ function match {
+ case f: DFUDF => Some(f)
+ case f =>
+ try {
+ // This may be a lambda that Spark's UDFRegistration wrapped around a Java UDF instance.
+ val clazz = f.getClass
+ if (Utils.getSimpleName(clazz).toLowerCase().contains("lambda")) {
+ // Try to find a `writeReplace` method, further indicating it is likely a lambda
+ // instance, and invoke it to serialize the lambda. Once serialized, captured arguments
+ // can be examined to locate the Java UDF instance.
+ // Note this relies on implementation details of Spark's UDFRegistration class.
+ val writeReplace = clazz.getDeclaredMethod("writeReplace")
+ writeReplace.setAccessible(true)
+ val serializedLambda = writeReplace.invoke(f).asInstanceOf[SerializedLambda]
+ if (serializedLambda.getCapturedArgCount == 1) {
+ serializedLambda.getCapturedArg(0) match {
+ case c: DFUDF => Some(c)
+ case _ => None
+ }
+ } else {
+ None
+ }
+ } else {
+ None
+ }
+ } catch {
+ case _: ClassCastException | _: NoSuchMethodException | _: SecurityException => None
+ }
+ }
+ }
+}
diff --git a/df_udf/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala b/df_udf/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala
new file mode 100644
index 00000000000..5b51aeeb991
--- /dev/null
+++ b/df_udf/src/main/spark320/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala
@@ -0,0 +1,50 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "320"}
+{"spark": "321"}
+{"spark": "321cdh"}
+{"spark": "322"}
+{"spark": "323"}
+{"spark": "324"}
+{"spark": "330"}
+{"spark": "330cdh"}
+{"spark": "330db"}
+{"spark": "331"}
+{"spark": "332"}
+{"spark": "332cdh"}
+{"spark": "332db"}
+{"spark": "333"}
+{"spark": "334"}
+{"spark": "340"}
+{"spark": "341"}
+{"spark": "341db"}
+{"spark": "342"}
+{"spark": "343"}
+{"spark": "350"}
+{"spark": "351"}
+{"spark": "352"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.nvidia
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.Expression
+
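+// Through Spark 3.x (the shims listed above) a Column directly wraps a Catalyst Expression,
+// so these conversions are trivial.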
+object DFUDFShims {
+ def columnToExpr(c: Column): Expression = c.expr
+ def exprToColumn(e: Expression): Column = Column(e)
+}
diff --git a/df_udf/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala b/df_udf/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala
new file mode 100644
index 00000000000..e67dfb450d8
--- /dev/null
+++ b/df_udf/src/main/spark400/scala/org/apache/spark/sql/nvidia/DFUDFShims.scala
@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*** spark-rapids-shim-json-lines
+{"spark": "400"}
+spark-rapids-shim-json-lines ***/
+package org.apache.spark.sql.nvidia
+
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
+
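+// In Spark 4.0 a Column no longer exposes its Catalyst Expression directly, so rely on the
+// implicit ExpressionUtils.column/expression conversions imported above.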
+object DFUDFShims {
+ def columnToExpr(c: Column): Expression = c
+ def exprToColumn(e: Expression): Column = e
+}
diff --git a/df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala b/df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala
new file mode 100644
index 00000000000..ae6d46aefdf
--- /dev/null
+++ b/df_udf/src/test/scala/com/nvidia/spark/functionsSuite.scala
@@ -0,0 +1,443 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.nvidia.spark
+
+import com.nvidia.spark.functions._
+
+import org.apache.spark.sql.{Column, Row}
+import org.apache.spark.sql.api.java._
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.nvidia.SparkTestBase
+import org.apache.spark.sql.types._
+
+class functionsSuite extends SparkTestBase {
+ test("basic 0 arg df_udf") {
+ val zero = df_udf(() => lit(0))
+ withSparkSession{ spark =>
+ spark.udf.register("zero", zero)
+ assertSame(Array(
+ Row(0L, 0),
+ Row(1L, 0)),
+ spark.range(2).selectExpr("id", "zero()").collect())
+ assertSame(Array(
+ Row(0L, 0),
+ Row(1L, 0)),
+ spark.range(2).select(col("id"), zero()).collect())
+ }
+ }
+
+ test("basic 1 arg df_udf") {
+ val inc = df_udf((input: Column) => input + 1)
+ withSparkSession { spark =>
+ spark.udf.register("inc", inc)
+ assertSame(Array(
+ Row(0L, 1L),
+ Row(1L, 2L)),
+ spark.range(2).selectExpr("id", "inc(id)").collect())
+ assertSame(Array(
+ Row(0L, 1L),
+ Row(1L, 2L)),
+ spark.range(2).select(col("id"), inc(col("id"))).collect())
+ }
+ }
+
+
+ test("basic 2 arg df_udf") {
+ val add = df_udf((a: Column, b:Column) => a + b)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 0L),
+ Row(1L, 2L)),
+ spark.range(2).selectExpr("id", "add(id, id)").collect())
+ assertSame(Array(
+ Row(0L, 0L),
+ Row(1L, 2L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"))).collect())
+ }
+ }
+
+ test("basic 3 arg df_udf") {
+ val add = df_udf((a: Column, b:Column, c:Column) => a + b + c)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 0L),
+ Row(1L, 3L)),
+ spark.range(2).selectExpr("id", "add(id, id, id)").collect())
+ assertSame(Array(
+ Row(0L, 0L),
+ Row(1L, 3L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), col("id"))).collect())
+ }
+ }
+
+ test("basic 4 arg df_udf") {
+ val add = df_udf((a: Column, b:Column, c:Column, d:Column) => a + b + c + d)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 1L),
+ Row(1L, 4L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id)").collect())
+ assertSame(Array(
+ Row(0L, 1L),
+ Row(1L, 4L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), col("id"))).collect())
+ }
+ }
+
+ test("basic 5 arg df_udf") {
+ val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column) =>
+ a + b + c + d + e)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 5L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1)").collect())
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 5L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1))).collect())
+ }
+ }
+
+ test("basic 6 arg df_udf") {
+ val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column, f:Column) =>
+ a + b + c + d + e + f)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 6L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id)").collect())
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 6L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"))).collect())
+ }
+ }
+
+ test("basic 7 arg df_udf") {
+ val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column,
+ f:Column, g:Column) => a + b + c + d + e + f + g)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 7L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id)").collect())
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 7L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"), col("id"))).collect())
+ }
+ }
+
+ test("basic 8 arg df_udf") {
+ val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column,
+ f:Column, g:Column, h:Column) => a + b + c + d + e + f + g + h)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 9L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2)").collect())
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 9L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"), col("id"), lit(2))).collect())
+ }
+ }
+
+ test("basic 9 arg df_udf") {
+ val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column,
+ f:Column, g:Column, h:Column, i:Column) =>
+ a + b + c + d + e + f + g + h + i)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 10L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2, id)").collect())
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 10L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"), col("id"), lit(2), col("id"))).collect())
+ }
+ }
+
+ test("basic 10 arg df_udf") {
+ val add = df_udf((a: Column, b:Column, c:Column, d:Column, e:Column,
+ f:Column, g:Column, h:Column, i:Column, j:Column) =>
+ a + b + c + d + e + f + g + h + i + j)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 11L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2, id, id)").collect())
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 11L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"), col("id"), lit(2), col("id"), col("id"))).collect())
+ }
+ }
+
+ test("nested df_udf") {
+ val add = df_udf((a: Column, b:Column) => a + b)
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 22L),
+ Row(1L, 25L)),
+ spark.range(2).selectExpr("id", "add(add(id, 12), add(add(id, id), 10))").collect())
+ }
+ }
+
+ test("complex df_udf") {
+ val extractor = df_udf((json: Column) => {
+ val schema = StructType(Seq(StructField("values", ArrayType(LongType))))
+ val extracted_json = from_json(json, schema, Map.empty[String, String])
+ aggregate(extracted_json("values"),
+ lit(0L),
+ (a, b) => coalesce(a, lit(0L)) + coalesce(b, lit(0L)),
+ a => a)
+ })
+ withSparkSession { spark =>
+ import spark.implicits._
+ spark.udf.register("extractor", extractor)
+ assertSame(Array(
+ Row(6L),
+ Row(3L)),
+ Seq("""{"values":[1,2,3]}""",
+ """{"values":[1, null, null, 2]}""").toDF("json").selectExpr("extractor(json)").collect())
+ }
+ }
+
+ test("j basic 0 arg df_udf") {
+ val zero = df_udf(new UDF0[Column] {
+ override def call(): Column = lit(0)
+ })
+ withSparkSession{ spark =>
+ spark.udf.register("zero", zero)
+ assertSame(Array(
+ Row(0L, 0),
+ Row(1L, 0)),
+ spark.range(2).selectExpr("id", "zero()").collect())
+ assertSame(Array(
+ Row(0L, 0),
+ Row(1L, 0)),
+ spark.range(2).select(col("id"), zero()).collect())
+ }
+ }
+
+ test("jbasic 1 arg df_udf") {
+ val inc = df_udf(new UDF1[Column, Column] {
+ override def call(a: Column): Column = a + 1
+ })
+ withSparkSession { spark =>
+ spark.udf.register("inc", inc)
+ assertSame(Array(
+ Row(0L, 1L),
+ Row(1L, 2L)),
+ spark.range(2).selectExpr("id", "inc(id)").collect())
+ assertSame(Array(
+ Row(0L, 1L),
+ Row(1L, 2L)),
+ spark.range(2).select(col("id"), inc(col("id"))).collect())
+ }
+ }
+
+ test("jbasic 2 arg df_udf") {
+ val add = df_udf(new UDF2[Column, Column, Column] {
+ override def call(a: Column, b:Column): Column = a + b
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 0L),
+ Row(1L, 2L)),
+ spark.range(2).selectExpr("id", "add(id, id)").collect())
+ assertSame(Array(
+ Row(0L, 0L),
+ Row(1L, 2L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"))).collect())
+ }
+ }
+
+ test("jbasic 3 arg df_udf") {
+ val add = df_udf(new UDF3[Column, Column, Column, Column] {
+ override def call(a: Column, b: Column, c: Column): Column = a + b + c
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 0L),
+ Row(1L, 3L)),
+ spark.range(2).selectExpr("id", "add(id, id, id)").collect())
+ assertSame(Array(
+ Row(0L, 0L),
+ Row(1L, 3L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), col("id"))).collect())
+ }
+ }
+
+ test("jbasic 4 arg df_udf") {
+ val add = df_udf(new UDF4[Column, Column, Column, Column, Column] {
+ override def call(a: Column, b:Column, c:Column, d:Column): Column = a + b + c + d
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 1L),
+ Row(1L, 4L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id)").collect())
+ assertSame(Array(
+ Row(0L, 1L),
+ Row(1L, 4L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1), col("id"))).collect())
+ }
+ }
+
+ test("jbasic 5 arg df_udf") {
+ val add = df_udf(new UDF5[Column, Column, Column, Column, Column, Column] {
+ override def call(a: Column, b: Column, c: Column, d: Column, e: Column): Column =
+ a + b + c + d + e
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 5L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1)").collect())
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 5L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1))).collect())
+ }
+ }
+
+ test("jbasic 6 arg df_udf") {
+ val add = df_udf(new UDF6[Column, Column, Column, Column, Column, Column, Column] {
+ override def call(a: Column, b:Column, c:Column, d:Column, e:Column, f:Column) =
+ a + b + c + d + e + f
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 6L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id)").collect())
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 6L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"))).collect())
+ }
+ }
+
+ test("jbasic 7 arg df_udf") {
+ val add = df_udf(new UDF7[Column, Column, Column, Column, Column, Column, Column,
+ Column] {
+ override def call(a: Column, b:Column, c:Column, d:Column, e:Column,
+ f:Column, g:Column): Column = a + b + c + d + e + f + g
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 7L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id)").collect())
+ assertSame(Array(
+ Row(0L, 2L),
+ Row(1L, 7L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"), col("id"))).collect())
+ }
+ }
+
+ test("jbasic 8 arg df_udf") {
+ val add = df_udf(new UDF8[Column, Column, Column, Column, Column, Column, Column,
+ Column, Column] {
+ override def call(a: Column, b: Column, c: Column, d: Column, e: Column,
+ f: Column, g: Column, h: Column): Column = a + b + c + d + e + f + g + h
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 9L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2)").collect())
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 9L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"), col("id"), lit(2))).collect())
+ }
+ }
+
+ test("jbasic 9 arg df_udf") {
+ val add = df_udf(new UDF9[Column, Column, Column, Column, Column, Column, Column,
+ Column, Column, Column] {
+ override def call(a: Column, b:Column, c:Column, d:Column, e:Column,
+ f:Column, g:Column, h:Column, i:Column): Column =
+ a + b + c + d + e + f + g + h + i
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 10L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2, id)").collect())
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 10L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"), col("id"), lit(2), col("id"))).collect())
+ }
+ }
+
+ test("jbasic 10 arg df_udf") {
+ val add = df_udf(new UDF10[Column, Column, Column, Column, Column, Column, Column,
+ Column, Column, Column, Column] {
+ override def call(a: Column, b:Column, c:Column, d:Column, e:Column,
+ f:Column, g:Column, h:Column, i:Column, j:Column): Column =
+ a + b + c + d + e + f + g + h + i + j
+ })
+ withSparkSession { spark =>
+ spark.udf.register("add", add)
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 11L)),
+ spark.range(2).selectExpr("id", "add(id, id, 1, id, 1, id, id, 2, id, id)").collect())
+ assertSame(Array(
+ Row(0L, 4L),
+ Row(1L, 11L)),
+ spark.range(2).select(col("id"), add(col("id"), col("id"), lit(1),
+ col("id"), lit(1), col("id"), col("id"), lit(2), col("id"), col("id"))).collect())
+ }
+ }
+}
\ No newline at end of file
diff --git a/df_udf/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala b/df_udf/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala
new file mode 100644
index 00000000000..2bd6697ffad
--- /dev/null
+++ b/df_udf/src/test/scala/org/apache/spark/sql/nvidia/SparkTestBase.scala
@@ -0,0 +1,175 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.nvidia
+
+import java.io.File
+import java.nio.file.Files
+import java.util.{Locale, TimeZone}
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.funsuite.AnyFunSuite
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Row, SparkSession}
+
+object SparkSessionHolder extends Logging {
+ private var spark = createSparkSession()
+ private var origConf = spark.conf.getAll
+ private var origConfKeys = origConf.keys.toSet
+
+ private def setAllConfs(confs: Array[(String, String)]): Unit = confs.foreach {
+ case (key, value) if spark.conf.get(key, null) != value =>
+ spark.conf.set(key, value)
+ case _ => // No need to modify it
+ }
+
+ private def createSparkSession(): SparkSession = {
+ SparkSession.cleanupAnyExistingSession()
+
+ TimeZone.setDefault(TimeZone.getTimeZone("UTC"))
+ Locale.setDefault(Locale.US)
+
+ val builder = SparkSession.builder()
+ .master("local[1]")
+ .config("spark.sql.extensions", "com.nvidia.spark.DFUDFPlugin")
+ .config("spark.sql.warehouse.dir", sparkWarehouseDir.getAbsolutePath)
+ .appName("dataframe udf tests")
+
+ builder.getOrCreate()
+ }
+
+ private def reinitSession(): Unit = {
+ spark = createSparkSession()
+ origConf = spark.conf.getAll
+ origConfKeys = origConf.keys.toSet
+ }
+
+ def sparkSession: SparkSession = {
+ if (SparkSession.getActiveSession.isEmpty) {
+ reinitSession()
+ }
+ spark
+ }
+
+ def resetSparkSessionConf(): Unit = {
+ if (SparkSession.getActiveSession.isEmpty) {
+ reinitSession()
+ } else {
+ setAllConfs(origConf.toArray)
+ val currentKeys = spark.conf.getAll.keys.toSet
+ val toRemove = currentKeys -- origConfKeys
+ if (toRemove.contains("spark.shuffle.manager")) {
+ // cannot unset the config so need to reinitialize
+ reinitSession()
+ } else {
+ toRemove.foreach(spark.conf.unset)
+ }
+ }
+ logDebug(s"RESET CONF TO: ${spark.conf.getAll}")
+ }
+
+ def withSparkSession[U](conf: SparkConf, f: SparkSession => U): U = {
+ resetSparkSessionConf()
+ logDebug(s"SETTING CONF: ${conf.getAll.toMap}")
+ setAllConfs(conf.getAll)
+ logDebug(s"RUN WITH CONF: ${spark.conf.getAll}\n")
+ f(spark)
+ }
+
+ private lazy val sparkWarehouseDir: File = {
+ new File(System.getProperty("java.io.tmpdir")).mkdirs()
+ val path = Files.createTempDirectory("spark-warehouse")
+ val file = new File(path.toString)
+ file.deleteOnExit()
+ file
+ }
+}
+
+/**
+ * Base trait to be able to run tests with a Spark session
+ */
+trait SparkTestBase extends AnyFunSuite with BeforeAndAfterAll {
+ def withSparkSession[U](f: SparkSession => U): U = {
+ withSparkSession(new SparkConf, f)
+ }
+
+ def withSparkSession[U](conf: SparkConf, f: SparkSession => U): U = {
+ SparkSessionHolder.withSparkSession(conf, f)
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ SparkSession.cleanupAnyExistingSession()
+ }
+
+ def assertSame(expected: Any, actual: Any, epsilon: Double = 0.0,
+ path: List[String] = List.empty): Unit = {
+ def assertDoublesAreEqualWithinPercentage(expected: Double,
+ actual: Double, path: List[String]): Unit = {
+ if (expected != actual) {
+ if (expected != 0) {
+ val v = Math.abs((expected - actual) / expected)
+ assert(v <= epsilon,
+ s"$path: ABS($expected - $actual) / ABS($actual) == $v is not <= $epsilon ")
+ } else {
+ val v = Math.abs(expected - actual)
+ assert(v <= epsilon, s"$path: ABS($expected - $actual) == $v is not <= $epsilon ")
+ }
+ }
+ }
+ (expected, actual) match {
+ case (a: Float, b: Float) if a.isNaN && b.isNaN =>
+ case (a: Double, b: Double) if a.isNaN && b.isNaN =>
+ case (null, null) =>
+ case (null, other) => fail(s"$path: expected is null, but actual is $other")
+ case (other, null) => fail(s"$path: expected is $other, but actual is null")
+ case (a: Array[_], b: Array[_]) =>
+ assert(a.length == b.length,
+ s"$path: expected (${a.toList}) and actual (${b.toList}) lengths don't match")
+ a.indices.foreach { i =>
+ assertSame(a(i), b(i), epsilon, path :+ i.toString)
+ }
+ case (a: Map[_, _], b: Map[_, _]) =>
+ throw new IllegalStateException(s"Maps are not supported yet for comparison $a vs $b")
+ case (a: Iterable[_], b: Iterable[_]) =>
+ assert(a.size == b.size,
+ s"$path: expected (${a.toList}) and actual (${b.toList}) lengths don't match")
+ var i = 0
+ a.zip(b).foreach {
+ case (l, r) =>
+ assertSame(l, r, epsilon, path :+ i.toString)
+ i += 1
+ }
+ case (a: Product, b: Product) =>
+ assertSame(a.productIterator.toSeq, b.productIterator.toSeq, epsilon, path)
+ case (a: Row, b: Row) =>
+ assertSame(a.toSeq, b.toSeq, epsilon, path)
+ // 0.0 == -0.0, turn float/double to bits before comparison, to distinguish 0.0 and -0.0.
+ case (a: Double, b: Double) if epsilon <= 0 =>
+ assert(java.lang.Double.doubleToRawLongBits(a) == java.lang.Double.doubleToRawLongBits(b),
+ s"$path: $a != $b")
+ case (a: Double, b: Double) if epsilon > 0 =>
+ assertDoublesAreEqualWithinPercentage(a, b, path)
+ case (a: Float, b: Float) if epsilon <= 0 =>
+ assert(java.lang.Float.floatToRawIntBits(a) == java.lang.Float.floatToRawIntBits(b),
+ s"$path: $a != $b")
+ case (a: Float, b: Float) if epsilon > 0 =>
+ assertDoublesAreEqualWithinPercentage(a, b, path)
+ case (a, b) =>
+ assert(a == b, s"$path: $a != $b")
+ }
+ }
+}
diff --git a/pom.xml b/pom.xml
index 7a4b7e56d85..bfb8a50946e 100644
--- a/pom.xml
+++ b/pom.xml
@@ -73,6 +73,7 @@
aggregator
datagen
+ df_udf
dist
integration_tests
shuffle-plugin
diff --git a/scala2.13/datagen/pom.xml b/scala2.13/datagen/pom.xml
index 6c01e912f94..d53ebc014c7 100644
--- a/scala2.13/datagen/pom.xml
+++ b/scala2.13/datagen/pom.xml
@@ -33,6 +33,7 @@
**/*
package
+ ${project.build.outputDirectory}/datagen-version-info.properties
diff --git a/scala2.13/df_udf/pom.xml b/scala2.13/df_udf/pom.xml
new file mode 100644
index 00000000000..04f7a6deb28
--- /dev/null
+++ b/scala2.13/df_udf/pom.xml
@@ -0,0 +1,88 @@
+
+
+
+ 4.0.0
+
+ com.nvidia
+ rapids-4-spark-shim-deps-parent_2.13
+ 24.12.0-SNAPSHOT
+ ../shim-deps/pom.xml
+
+ df_udf_plugin_2.13
+ UDFs implemented in SQL/Dataframe
+ UDFs for Apache Spark implemented in SQL/Dataframe
+ 24.12.0-SNAPSHOT
+
+
+ df_udf
+
+ **/*
+ package
+ ${project.build.outputDirectory}/df_udf-version-info.properties
+
+
+
+
+ org.scala-lang
+ scala-library
+
+
+ org.scalatest
+ scalatest_${scala.binary.version}
+ test
+
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${spark.test.version}
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
+ true
+
+
+
+ net.alchim31.maven
+ scala-maven-plugin
+
+
+ org.scalatest
+ scalatest-maven-plugin
+
+
+ org.apache.rat
+ apache-rat-plugin
+
+
+
+
+
+
+ ${project.build.directory}/extra-resources
+
+
+
+
diff --git a/scala2.13/pom.xml b/scala2.13/pom.xml
index f17a90f4633..e22f311561a 100644
--- a/scala2.13/pom.xml
+++ b/scala2.13/pom.xml
@@ -73,6 +73,7 @@
aggregator
datagen
+ df_udf
dist
integration_tests
shuffle-plugin