[SPARK-25560][SQL] Allow FunctionInjection in SparkExtensions
This allows an implementer of Spark session extensions to use a new
method, `injectFunction`, which adds a function to the default Spark
session catalog.

## What changes were proposed in this pull request?

Adds a new method to `SparkSessionExtensions`:

    def injectFunction(functionDescription: FunctionDescription)

where `FunctionDescription` is a new type alias:

    type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)

The functions are registered in `BaseSessionStateBuilder` when the function
registry does not have a parent registry to be cloned from.
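
For illustration, a minimal extension built against this hook might look like the sketch below. The `my_upper` name and the reuse of the built-in `Upper` expression are hypothetical, not part of this patch:

    import org.apache.spark.sql.SparkSessionExtensions
    import org.apache.spark.sql.catalyst.FunctionIdentifier
    import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Upper}

    // Hypothetical extension registering "my_upper" in every session's catalog.
    class MyFunctionExtensions extends (SparkSessionExtensions => Unit) {
      def apply(extensions: SparkSessionExtensions): Unit = {
        extensions.injectFunction((
          FunctionIdentifier("my_upper"),
          new ExpressionInfo(classOf[Upper].getCanonicalName, "my_upper"),
          (args: Seq[Expression]) => Upper(args.head)))
      }
    }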

## How was this patch tested?

New unit tests for the extension point are added in `SparkSessionExtensionSuite`.

Closes apache#22576 from RussellSpitzer/SPARK-25560.

Authored-by: Russell Spitzer <Russell.Spitzer@gmail.com>
Signed-off-by: Herman van Hovell <hvanhovell@databricks.com>
RussellSpitzer authored and jackylee-ch committed Feb 18, 2019
1 parent 3656e9f commit f7c19f5
Showing 3 changed files with 46 additions and 3 deletions.
SparkSessionExtensions.scala
@@ -20,6 +20,10 @@ package org.apache.spark.sql
 import scala.collection.mutable
 
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
 import org.apache.spark.sql.catalyst.parser.ParserInterface
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
@@ -68,6 +72,7 @@ class SparkSessionExtensions {
   type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
   type StrategyBuilder = SparkSession => Strategy
   type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
+  type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
 
   private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
 
@@ -171,4 +176,21 @@ class SparkSessionExtensions {
   def injectParser(builder: ParserBuilder): Unit = {
     parserBuilders += builder
   }
+
+  private[this] val injectedFunctions = mutable.Buffer.empty[FunctionDescription]
+
+  private[sql] def registerFunctions(functionRegistry: FunctionRegistry) = {
+    for ((name, expressionInfo, function) <- injectedFunctions) {
+      functionRegistry.registerFunction(name, expressionInfo, function)
+    }
+    functionRegistry
+  }
+
+  /**
+   * Injects a custom function into the [[org.apache.spark.sql.catalyst.analysis.FunctionRegistry]]
+   * at runtime for all sessions.
+   */
+  def injectFunction(functionDescription: FunctionDescription): Unit = {
+    injectedFunctions += functionDescription
+  }
 }
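
Assuming the hypothetical `MyFunctionExtensions` class sketched in the commit message above, the new hook composes with the existing `withExtensions` method on the session builder:

    val spark = SparkSession.builder()
      .master("local[1]")
      .withExtensions(new MyFunctionExtensions)
      .getOrCreate()

    // The injected function now resolves like any built-in.
    spark.sql("SELECT my_upper('spark')").show()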
BaseSessionStateBuilder.scala
@@ -95,7 +95,8 @@ abstract class BaseSessionStateBuilder(
    * This either gets cloned from a pre-existing version or cloned from the built-in registry.
    */
   protected lazy val functionRegistry: FunctionRegistry = {
-    parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone()
+    parentState.map(_.functionRegistry.clone())
+      .getOrElse(extensions.registerFunctions(FunctionRegistry.builtin.clone()))
   }
 
   /**
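
Note the two paths above: a state cloned from a parent (e.g. via the internal `cloneSession()`) inherits a registry that already carries the injected functions, while a fresh state re-registers them on a clone of `FunctionRegistry.builtin`. Either way the functions survive into new sessions, as the sketch below (continuing the hypothetical names above) illustrates; `sessionState` is package-private, so this assertion only compiles inside `org.apache.spark.sql`, as in the test suite:

    // newSession() builds a fresh state, so the extension's functions are
    // registered again on a clone of the builtin registry.
    val child = spark.newSession()
    assert(child.sessionState.functionRegistry
      .lookupFunction(FunctionIdentifier("my_upper")).isDefined)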
SparkSessionExtensionSuite.scala
@@ -18,12 +18,12 @@ package org.apache.spark.sql
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal}
 import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
 
 /**
  * Test cases for the [[SparkSessionExtensions]].
@@ -90,6 +90,16 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     }
   }
 
+  test("inject function") {
+    val extensions = create { extensions =>
+      extensions.injectFunction(MyExtensions.myFunction)
+    }
+    withSession(extensions) { session =>
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+    }
+  }
+
   test("use custom class for extensions") {
     val session = SparkSession.builder()
       .master("local[1]")
@@ -98,6 +108,8 @@
     try {
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
     } finally {
       stop(session)
     }
@@ -136,9 +148,17 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface {
     delegate.parseDataType(sqlText)
 }
 
+object MyExtensions {
+
+  val myFunction = (FunctionIdentifier("myFunction"),
+    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage"),
+    (myArgs: Seq[Expression]) => Literal(5, IntegerType))
+}
+
 class MyExtensions extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectPlannerStrategy(MySparkStrategy)
     e.injectResolutionRule(MyRule)
+    e.injectFunction(MyExtensions.myFunction)
   }
 }

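As an aside, the class-based variant exercised by the last test can also be activated through the `spark.sql.extensions` configuration instead of the builder API; a sketch, assuming `MyExtensions` is on the driver classpath:

    val session = SparkSession.builder()
      .master("local[1]")
      .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName)
      .getOrCreate()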