[SPARK-26293][SQL] Cast exception when having python udf in subquery
## What changes were proposed in this pull request?

This is a regression introduced by #22104 in Spark 2.4.0.

When there is a Python UDF in a subquery, we hit an exception:
```
Caused by: java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.AttributeReference cannot be cast to org.apache.spark.sql.catalyst.expressions.PythonUDF
	at scala.collection.immutable.Stream.map(Stream.scala:414)
	at org.apache.spark.sql.execution.python.EvalPythonExec.$anonfun$doExecute$2(EvalPythonExec.scala:98)
	at org.apache.spark.rdd.RDD.$anonfun$mapPartitions$2(RDD.scala:815)
...
```
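
For reference, here is a minimal standalone reproduction sketch, adapted from the `test_udf_in_subquery` test added below; the local `SparkSession` setup is an assumption for illustration and is not part of the original report:

```
# A minimal reproduction sketch (assumes a local SparkSession, which is not
# part of the original report). On Spark 2.4.0 the final query fails with the
# ClassCastException above; with this fix it returns [Row(i=0)].
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf

spark = SparkSession.builder.master("local[1]").getOrCreate()

# An identity Python UDF, so the view's plan contains a PythonUDF expression.
identity = udf(lambda x: x, "long")
spark.range(1).filter(identity("id") >= 0).createTempView("v")

# The IN-subquery is optimized once by OptimizeSubqueries and again after it
# is rewritten into a join, so ExtractPythonUDFs used to be applied twice.
result = spark.sql(
    "select i from values(0L) as data(i) where i in (select id from v)")
print(result.collect())
```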

#22104 turned `ExtractPythonUDFs` from a physical rule into an optimizer rule. However, there is a difference between the two: a physical rule always runs exactly once, while an optimizer rule may be applied to a query tree more than once, even if the rule is located in a batch that only runs once.

For a subquery, the `OptimizeSubqueries` rule runs the entire optimizer on the query plan inside the subquery. Later on, the subquery is rewritten into a join, and the optimizer rules are applied to it again.

Unfortunately, the `ExtractPythonUDFs` rule is not idempotent. When it is applied twice to the query plan inside a subquery, it produces a malformed plan: the second application extracts the Python UDFs out of the Python evaluation plans (`BatchEvalPython`/`ArrowEvalPython`) created by the first application, replacing them there with attribute references, which is what later fails the cast to `PythonUDF` at execution time.

This PR proposes two changes, to be doubly safe (both implemented in the `ExtractPythonUDFs` diff below):
1. `ExtractPythonUDFs` should skip Python eval plans (`BatchEvalPython` and `ArrowEvalPython`), to make the rule idempotent.
2. `ExtractPythonUDFs` should skip subqueries.

## How was this patch tested?

A new test, `test_udf_in_subquery`.

Closes #23248 from cloud-fan/python.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan committed Dec 11, 2018
1 parent cbe9230 commit 7d5f6e8
Showing 4 changed files with 46 additions and 40 deletions.
52 changes: 19 additions & 33 deletions python/pyspark/sql/tests/test_udf.py
@@ -23,7 +23,7 @@

from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
from pyspark.sql.functions import UserDefinedFunction
from pyspark.sql.functions import UserDefinedFunction, udf
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -102,7 +102,6 @@ def test_udf_registration_return_type_not_none(self):

def test_nondeterministic_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
self.assertEqual(udf_random_col.deterministic, False)
@@ -113,7 +112,6 @@ def test_nondeterministic_udf(self):

def test_nondeterministic_udf2(self):
import random
from pyspark.sql.functions import udf
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
self.assertEqual(random_udf.deterministic, False)
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
@@ -132,7 +130,6 @@ def test_nondeterministic_udf2(self):

def test_nondeterministic_udf3(self):
# regression test for SPARK-23233
from pyspark.sql.functions import udf
f = udf(lambda x: x)
# Here we cache the JVM UDF instance.
self.spark.range(1).select(f("id"))
@@ -144,7 +141,7 @@ def test_nondeterministic_udf3(self):
self.assertFalse(deterministic)

def test_nondeterministic_udf_in_aggregate(self):
from pyspark.sql.functions import udf, sum
from pyspark.sql.functions import sum
import random
udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
df = self.spark.range(10)
@@ -181,7 +178,6 @@ def test_multiple_udfs(self):
self.assertEqual(tuple(row), (6, 5))

def test_udf_in_filter_on_top_of_outer_join(self):
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(a=1)])
df = left.join(right, on='a', how='left_outer')
@@ -190,7 +186,6 @@ def test_udf_in_filter_on_top_of_outer_join(self):

def test_udf_in_filter_on_top_of_join(self):
# regression test for SPARK-18589
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -199,7 +194,6 @@ def test_udf_in_filter_on_top_of_join(self):

def test_udf_in_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -211,7 +205,7 @@ def test_udf_in_join_condition(self):

def test_udf_in_left_outer_join_condition(self):
# regression test for SPARK-26147
from pyspark.sql.functions import udf, col
from pyspark.sql.functions import col
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a: str(a), StringType())
@@ -223,7 +217,6 @@ def test_udf_in_left_outer_join_condition(self):

def test_udf_in_left_semi_join_condition(self):
# regression test for SPARK-25314
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -236,7 +229,6 @@ def test_udf_in_left_semi_join_condition(self):
def test_udf_and_common_filter_in_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -247,7 +239,6 @@ def test_udf_and_common_filter_in_join_condition(self):
def test_udf_and_common_filter_in_left_semi_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -258,7 +249,6 @@ def test_udf_and_common_filter_in_left_semi_join_condition(self):
def test_udf_not_supported_in_join_condition(self):
# regression test for SPARK-25314
# test python udf is not supported in join type besides left_semi and inner join.
from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -301,7 +291,7 @@ def test_broadcast_in_udf(self):

def test_udf_with_filter_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql.functions import udf, col
from pyspark.sql.functions import col
from pyspark.sql.types import BooleanType

my_filter = udf(lambda a: a < 2, BooleanType())
@@ -310,7 +300,7 @@ def test_udf_with_filter_function(self):

def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql.functions import udf, col, sum
from pyspark.sql.functions import col, sum
from pyspark.sql.types import BooleanType

my_filter = udf(lambda a: a == 1, BooleanType())
@@ -326,7 +316,7 @@ def test_udf_with_aggregate_function(self):
self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])

def test_udf_in_generate(self):
from pyspark.sql.functions import udf, explode
from pyspark.sql.functions import explode
df = self.spark.range(5)
f = udf(lambda x: list(range(x)), ArrayType(LongType()))
row = df.select(explode(f(*df))).groupBy().sum().first()
@@ -353,7 +343,6 @@ def test_udf_in_generate(self):
self.assertEqual(res[3][1], 1)

def test_udf_with_order_by_and_limit(self):
from pyspark.sql.functions import udf
my_copy = udf(lambda x: x, IntegerType())
df = self.spark.range(10).orderBy("id")
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
@@ -394,14 +383,14 @@ def test_non_existed_udaf(self):
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

def test_udf_with_input_file_name(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.functions import input_file_name
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
self.assertTrue(row[0].find("people1.json") != -1)

def test_udf_with_input_file_name_for_hadooprdd(self):
from pyspark.sql.functions import udf, input_file_name
from pyspark.sql.functions import input_file_name

def filename(path):
return path
@@ -427,9 +416,6 @@ def test_udf_defers_judf_initialization(self):
# This is separate of UDFInitializationTests
# to avoid context initialization
# when udf is called

from pyspark.sql.functions import UserDefinedFunction

f = UserDefinedFunction(lambda x: x, StringType())

self.assertIsNone(
@@ -445,8 +431,6 @@ def test_udf_defers_judf_initialization(self):
)

def test_udf_with_string_return_type(self):
from pyspark.sql.functions import UserDefinedFunction

add_one = UserDefinedFunction(lambda x: x + 1, "integer")
make_pair = UserDefinedFunction(lambda x: (-x, x), "struct<x:integer,y:integer>")
make_array = UserDefinedFunction(
@@ -460,13 +444,11 @@ def test_udf_with_string_return_type(self):
self.assertTupleEqual(expected, actual)

def test_udf_shouldnt_accept_noncallable_object(self):
from pyspark.sql.functions import UserDefinedFunction

non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())

def test_udf_with_decorator(self):
from pyspark.sql.functions import lit, udf
from pyspark.sql.functions import lit
from pyspark.sql.types import IntegerType, DoubleType

@udf(IntegerType())
@@ -523,7 +505,6 @@ def as_double(x):
)

def test_udf_wrapper(self):
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def f(x):
@@ -569,7 +550,7 @@ def test_nonparam_udf_with_aggregate(self):
# SPARK-24721
@unittest.skipIf(not test_compiled, test_not_compiled_message)
def test_datasource_with_udf(self):
from pyspark.sql.functions import udf, lit, col
from pyspark.sql.functions import lit, col

path = tempfile.mkdtemp()
shutil.rmtree(path)
@@ -609,8 +590,6 @@ def test_datasource_with_udf(self):

# SPARK-25591
def test_same_accumulator_in_udfs(self):
from pyspark.sql.functions import udf

data_schema = StructType([StructField("a", IntegerType(), True),
StructField("b", IntegerType(), True)])
data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
@@ -632,6 +611,15 @@ def second_udf(x):
data.collect()
self.assertEqual(test_accum.value, 101)

# SPARK-26293
def test_udf_in_subquery(self):
f = udf(lambda x: x, "long")
with self.tempView("v"):
self.spark.range(1).filter(f("id") >= 0).createTempView("v")
sql = self.spark.sql
result = sql("select i from values(0L) as data(i) where i in (select id from v)")
self.assertEqual(result.collect(), [Row(i=0)])


class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
Expand All @@ -642,8 +630,6 @@ def tearDown(self):
SparkContext._active_spark_context.stop()

def test_udf_init_shouldnt_initialize_context(self):
from pyspark.sql.functions import UserDefinedFunction

UserDefinedFunction(lambda x: x, StringType())

self.assertIsNone(
@@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
/**
* A logical plan that evaluates a [[PythonUDF]].
*/
case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode
case class ArrowEvalPython(
udfs: Seq[PythonUDF],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
}

/**
* A physical plan that evaluates a [[PythonUDF]].
@@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* A logical plan that evaluates a [[PythonUDF]]
*/
case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
extends UnaryNode
case class BatchEvalPython(
udfs: Seq[PythonUDF],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
}

/**
* A physical plan that evaluates a [[PythonUDF]]
@@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule


@@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
expressions.flatMap(collectEvaluableUDFs)
}

def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case plan: LogicalPlan => extract(plan)
def apply(plan: LogicalPlan): LogicalPlan = plan match {
// SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
// eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
case _: Subquery => plan

case _ => plan transformUp {
// A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
// `ArrowEvalPython` in the input plan. However if we hit them, we must skip them, as we can't
// extract Python UDFs from them.
case p: BatchEvalPython => p
case p: ArrowEvalPython => p

case plan: LogicalPlan => extract(plan)
}
}

/**
