From affe80958d366f399466a9dba8e03da7f3b7b9bf Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Nov 2018 20:38:42 +0800 Subject: [PATCH] [SPARK-26147][SQL] only pull out unevaluable python udf from join condition ## What changes were proposed in this pull request? https://github.com/apache/spark/pull/22326 mistakenly assumed that all Python UDFs are unevaluable in a join condition. Only Python UDFs that refer to attributes from both join sides are unevaluable. This PR fixes this mistake. ## How was this patch tested? a new test Closes #23153 from cloud-fan/join. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- python/pyspark/sql/tests/test_udf.py | 12 ++ .../spark/sql/catalyst/optimizer/joins.scala | 22 ++-- ...PullOutPythonUDFInJoinConditionSuite.scala | 120 ++++++++++++------ 3 files changed, 106 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py index d2dfb52f54475..ed298f724d551 100644 --- a/python/pyspark/sql/tests/test_udf.py +++ b/python/pyspark/sql/tests/test_udf.py @@ -209,6 +209,18 @@ def test_udf_in_join_condition(self): with self.sql_conf({"spark.sql.crossJoin.enabled": True}): self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_in_left_outer_join_condition(self): + # regression test for SPARK-26147 + from pyspark.sql.functions import udf, col + left = self.spark.createDataFrame([Row(a=1)]) + right = self.spark.createDataFrame([Row(b=1)]) + f = udf(lambda a: str(a), StringType()) + # The join condition can't be pushed down, as it refers to attributes from both sides. + # The Python UDF only refer to attributes from one side, so it's evaluable. 
+ df = left.join(right, f("a") == col("b").cast("string"), how="left_outer") + with self.sql_conf({"spark.sql.crossJoin.enabled": True}): + self.assertEqual(df.collect(), [Row(a=1, b=1)]) + def test_udf_in_left_semi_join_condition(self): # regression test for SPARK-25314 from pyspark.sql.functions import udf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index 7149edee0173e..6ebb194d71c2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -155,19 +155,20 @@ object EliminateOuterJoin extends Rule[LogicalPlan] with PredicateHelper { } /** - * PythonUDF in join condition can not be evaluated, this rule will detect the PythonUDF - * and pull them out from join condition. For python udf accessing attributes from only one side, - * they are pushed down by operation push down rules. If not (e.g. user disables filter push - * down rules), we need to pull them out in this rule too. + * PythonUDF in join condition can't be evaluated if it refers to attributes from both join sides. + * See `ExtractPythonUDFs` for details. This rule will detect un-evaluable PythonUDF and pull them + * out from join condition. 
*/ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateHelper { - def hasPythonUDF(expression: Expression): Boolean = { - expression.collectFirst { case udf: PythonUDF => udf }.isDefined + + private def hasUnevaluablePythonUDF(expr: Expression, j: Join): Boolean = { + expr.find { e => + PythonUDF.isScalarPythonUDF(e) && !canEvaluate(e, j.left) && !canEvaluate(e, j.right) + }.isDefined } override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case j @ Join(_, _, joinType, condition) - if condition.isDefined && hasPythonUDF(condition.get) => + case j @ Join(_, _, joinType, Some(cond)) if hasUnevaluablePythonUDF(cond, j) => if (!joinType.isInstanceOf[InnerLike] && joinType != LeftSemi) { // The current strategy only support InnerLike and LeftSemi join because for other type, // it breaks SQL semantic if we run the join condition as a filter after join. If we pass @@ -179,10 +180,9 @@ object PullOutPythonUDFInJoinCondition extends Rule[LogicalPlan] with PredicateH } // If condition expression contains python udf, it will be moved out from // the new join conditions. 
- val (udf, rest) = - splitConjunctivePredicates(condition.get).partition(hasPythonUDF) + val (udf, rest) = splitConjunctivePredicates(cond).partition(hasUnevaluablePythonUDF(_, j)) val newCondition = if (rest.isEmpty) { - logWarning(s"The join condition:$condition of the join plan contains PythonUDF only," + + logWarning(s"The join condition:$cond of the join plan contains PythonUDF only," + s" it will be moved out and the join plan will be turned to cross join.") None } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala index d3867f2b6bd0e..3f1c91df7f2e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PullOutPythonUDFInJoinConditionSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.optimizer -import org.scalatest.Matchers._ - import org.apache.spark.api.python.PythonEvalType import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -28,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf._ -import org.apache.spark.sql.types.BooleanType +import org.apache.spark.sql.types.{BooleanType, IntegerType} class PullOutPythonUDFInJoinConditionSuite extends PlanTest { @@ -40,13 +38,29 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { CheckCartesianProducts) :: Nil } - val testRelationLeft = LocalRelation('a.int, 'b.int) - val testRelationRight = LocalRelation('c.int, 'd.int) + val attrA = 'a.int + val attrB = 'b.int + val attrC = 'c.int + val attrD = 'd.int + + val 
testRelationLeft = LocalRelation(attrA, attrB) + val testRelationRight = LocalRelation(attrC, attrD) + + // This join condition refers to attributes from 2 tables, but the PythonUDF inside it only + // refer to attributes from one side. + val evaluableJoinCond = { + val pythonUDF = PythonUDF("evaluable", null, + IntegerType, + Seq(attrA), + PythonEvalType.SQL_BATCHED_UDF, + udfDeterministic = true) + pythonUDF === attrC + } - // Dummy python UDF for testing. Unable to execute. - val pythonUDF = PythonUDF("pythonUDF", null, + // This join condition is a PythonUDF which refers to attributes from 2 tables. + val unevaluableJoinCond = PythonUDF("unevaluable", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) @@ -66,62 +80,76 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { } } - test("inner join condition with python udf only") { - val query = testRelationLeft.join( + test("inner join condition with python udf") { + val query1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF)) - val expected = testRelationLeft.join( + condition = Some(unevaluableJoinCond)) + val expected1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF).analyze - comparePlanWithCrossJoinEnable(query, expected) + condition = None).where(unevaluableJoinCond).analyze + comparePlanWithCrossJoinEnable(query1, expected1) + + // evaluable PythonUDF will not be touched + val query2 = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } - test("left semi join condition with python udf only") { - val query = testRelationLeft.join( + test("left semi join condition with python udf") { + val query1 = testRelationLeft.join( testRelationRight, joinType = LeftSemi, - condition = Some(pythonUDF)) - val expected = testRelationLeft.join( + condition = 
Some(unevaluableJoinCond)) + val expected1 = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF).select('a, 'b).analyze - comparePlanWithCrossJoinEnable(query, expected) + condition = None).where(unevaluableJoinCond).select('a, 'b).analyze + comparePlanWithCrossJoinEnable(query1, expected1) + + // evaluable PythonUDF will not be touched + val query2 = testRelationLeft.join( + testRelationRight, + joinType = LeftSemi, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } - test("python udf and common condition") { + test("unevaluable python udf and common condition") { val query = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF && 'a.attr === 'c.attr)) + condition = Some(unevaluableJoinCond && 'a.attr === 'c.attr)) val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some('a.attr === 'c.attr)).where(pythonUDF).analyze + condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond).analyze val optimized = Optimize.execute(query.analyze) comparePlans(optimized, expected) } - test("python udf or common condition") { + test("unevaluable python udf or common condition") { val query = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some(pythonUDF || 'a.attr === 'c.attr)) + condition = Some(unevaluableJoinCond || 'a.attr === 'c.attr)) val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = None).where(pythonUDF || 'a.attr === 'c.attr).analyze + condition = None).where(unevaluableJoinCond || 'a.attr === 'c.attr).analyze comparePlanWithCrossJoinEnable(query, expected) } - test("pull out whole complex condition with multiple python udf") { + test("pull out whole complex condition with multiple unevaluable python udf") { val pythonUDF1 = PythonUDF("pythonUDF1", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, 
udfDeterministic = true) - val condition = (pythonUDF || 'a.attr === 'c.attr) && pythonUDF1 + val condition = (unevaluableJoinCond || 'a.attr === 'c.attr) && pythonUDF1 val query = testRelationLeft.join( testRelationRight, @@ -134,13 +162,13 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { comparePlanWithCrossJoinEnable(query, expected) } - test("partial pull out complex condition with multiple python udf") { + test("partial pull out complex condition with multiple unevaluable python udf") { val pythonUDF1 = PythonUDF("pythonUDF1", null, BooleanType, - Seq.empty, + Seq(attrA, attrC), PythonEvalType.SQL_BATCHED_UDF, udfDeterministic = true) - val condition = (pythonUDF || pythonUDF1) && 'a.attr === 'c.attr + val condition = (unevaluableJoinCond || pythonUDF1) && 'a.attr === 'c.attr val query = testRelationLeft.join( testRelationRight, @@ -149,23 +177,41 @@ class PullOutPythonUDFInJoinConditionSuite extends PlanTest { val expected = testRelationLeft.join( testRelationRight, joinType = Inner, - condition = Some('a.attr === 'c.attr)).where(pythonUDF || pythonUDF1).analyze + condition = Some('a.attr === 'c.attr)).where(unevaluableJoinCond || pythonUDF1).analyze + val optimized = Optimize.execute(query.analyze) + comparePlans(optimized, expected) + } + + test("pull out unevaluable python udf when it's mixed with evaluable one") { + val query = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond && unevaluableJoinCond)) + val expected = testRelationLeft.join( + testRelationRight, + joinType = Inner, + condition = Some(evaluableJoinCond)).where(unevaluableJoinCond).analyze val optimized = Optimize.execute(query.analyze) comparePlans(optimized, expected) } test("throw an exception for not support join type") { for (joinType <- unsupportedJoinTypes) { - val thrownException = the [AnalysisException] thrownBy { + val e = intercept[AnalysisException] { val query = testRelationLeft.join( testRelationRight, 
joinType, - condition = Some(pythonUDF)) + condition = Some(unevaluableJoinCond)) Optimize.execute(query.analyze) } - assert(thrownException.message.contentEquals( + assert(e.message.contentEquals( s"Using PythonUDF in join condition of join type $joinType is not supported.")) + + val query2 = testRelationLeft.join( + testRelationRight, + joinType, + condition = Some(evaluableJoinCond)) + comparePlans(Optimize.execute(query2), query2) } } } -