Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-25560][SQL] Allow FunctionInjection in SparkExtensions #22576

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@ package org.apache.spark.sql
import scala.collection.mutable

import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
Expand Down Expand Up @@ -65,6 +69,7 @@ class SparkSessionExtensions {
// Builder producing a check rule: given a session, returns a function that validates a plan.
type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
// Builder producing a planner Strategy for a given session.
type StrategyBuilder = SparkSession => Strategy
// Builder producing the parser to use, given the session and the current (delegate) parser.
type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
// Triple describing a function to add to the FunctionRegistry: identifier, usage
// metadata, and the builder that constructs the expression from its arguments.
type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)

// Resolution rules queued by the inject* methods until session-state construction.
private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]

Expand Down Expand Up @@ -168,4 +173,21 @@ class SparkSessionExtensions {
/**
 * Inject a custom parser into the [[SparkSession]]. The builder receives the session and
 * the current parser, and returns the parser to use (typically a wrapping delegate).
 */
def injectParser(builder: ParserBuilder): Unit = parserBuilders.append(builder)

// Function descriptions queued by injectFunction; applied to a registry by registerFunctions.
private[this] val injectedFunctions = mutable.Buffer.empty[FunctionDescription]

/**
 * Registers every injected function with the given registry and returns that same
 * (mutated) registry. NOTE(review): callers are expected to pass a private clone —
 * the caller in BaseSessionStateBuilder clones the builtin registry first — so the
 * in-place mutation is safe; confirm before adding new call sites.
 */
private[sql] def registerFunctions(functionRegistry: FunctionRegistry) = {
  for ((name, expressionInfo, function) <- injectedFunctions) {
    functionRegistry.registerFunction(name, expressionInfo, function)
  }
  functionRegistry
}

/**
 * Injects a custom function into the [[org.apache.spark.sql.catalyst.analysis.FunctionRegistry]]
 * at runtime for all sessions.
 *
 * @param functionDescription the identifier, expression info and builder of the function.
 */
def injectFunction(functionDescription: FunctionDescription): Unit = {
  injectedFunctions.append(functionDescription)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ abstract class BaseSessionStateBuilder(
* This either gets cloned from a pre-existing version or cloned from the built-in registry.
*/
protected lazy val functionRegistry: FunctionRegistry = {
  // With a parent session, clone its registry as-is. Without one, clone the builtin
  // registry and apply any extension-injected functions to the fresh clone.
  parentState.fold(extensions.registerFunctions(FunctionRegistry.builtin.clone())) { state =>
    state.functionRegistry.clone()
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.types.{DataType, IntegerType, StructType}

/**
* Test cases for the [[SparkSessionExtensions]].
Expand Down Expand Up @@ -90,6 +90,16 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}

test("inject function") {
  // Register the sample function through the extension point, then verify it is
  // visible in the session's function registry.
  val extensions = create { ext =>
    ext.injectFunction(MyExtensions.myFunction)
  }
  withSession(extensions) { session =>
    val registry = session.sessionState.functionRegistry
    assert(registry.lookupFunction(MyExtensions.myFunction._1).isDefined)
  }
}

test("use custom class for extensions") {
val session = SparkSession.builder()
.master("local[1]")
Expand All @@ -98,6 +108,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
try {
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
assert(session.sessionState.functionRegistry
.lookupFunction(MyExtensions.myFunction._1).isDefined)
} finally {
stop(session)
}
Expand Down Expand Up @@ -136,9 +148,17 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
delegate.parseDataType(sqlText)
}

object MyExtensions {

  // Sample FunctionDescription used by the tests: identifier, expression info, and a
  // builder that ignores its arguments and always produces the integer literal 5.
  val myFunction = (FunctionIdentifier("myFunction"),
    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage"),
    (_: Seq[Expression]) => Literal(5, IntegerType))
}

/**
 * Extension configurator used by the tests: injects a planner strategy, a resolution
 * rule, and the sample function from the companion object.
 */
class MyExtensions extends (SparkSessionExtensions => Unit) {
  override def apply(extensions: SparkSessionExtensions): Unit = {
    extensions.injectPlannerStrategy(MySparkStrategy)
    extensions.injectResolutionRule(MyRule)
    extensions.injectFunction(MyExtensions.myFunction)
  }
}