From 86106fadcaed6c1a4768138b3d72e8c892b7cd7f Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Sat, 17 Nov 2018 15:36:45 +0100 Subject: [PATCH] address comments --- .../spark/sql/catalyst/optimizer/subquery.scala | 3 +-- .../scala/org/apache/spark/sql/SubquerySuite.scala | 14 ++++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index c11470a22908e..ea658ad461f7e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -76,8 +76,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { private def dedupSubqueryOnSelfJoin(values: Seq[Expression], sub: LogicalPlan): LogicalPlan = { val leftRefs = AttributeSet.fromAttributeSets(values.map(_.references)) - val rightRefs = AttributeSet(sub.output) - val duplicates = leftRefs.intersect(rightRefs) + val duplicates = leftRefs.intersect(sub.outputSet) if (duplicates.isEmpty) { sub } else { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 5a525a6d4ab28..007a063d071cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -1284,12 +1284,14 @@ class SubquerySuite extends QueryTest with SharedSQLContext { test("SPARK-26078: deduplicate fake self joins for IN subqueries") { withTempView("a", "b") { - val a = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row("a", 2), Row("b", 1))), - StructType(Seq(StructField("id", StringType), StructField("num", IntegerType)))) - val b = spark.createDataFrame(spark.sparkContext.parallelize(Seq(Row("a", 2), Row("b", 1))), - StructType(Seq(StructField("id", StringType), StructField("num", IntegerType)))) - a.createOrReplaceTempView("a") - b.createOrReplaceTempView("b") + def genTestViewWithName(name: String): Unit = { + val df = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row("a", 2), Row("b", 1))), + StructType(Seq(StructField("id", StringType), StructField("num", IntegerType)))) + df.createOrReplaceTempView(name) + } + genTestViewWithName("a") + genTestViewWithName("b") val df1 = spark.sql( """