diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py index 3c016e04adf2e..94a12bfb3f656 100644 --- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py @@ -203,6 +203,18 @@ def test_case_insensitive_grouping_column(self): ).applyInPandas(lambda r, l: r + l, "column long, value long").first() self.assertEqual(row.asDict(), Row(column=2, value=2).asDict()) + def test_self_join(self): + # SPARK-34319: self-join with FlatMapCoGroupsInPandas + df = self.spark.createDataFrame([(1, 1)], ("column", "value")) + + row = df.groupby("ColUmn").cogroup( + df.groupby("COLUMN") + ).applyInPandas(lambda r, l: r + l, "column long, value long") + + row = row.join(row).first() + + self.assertEqual(row.asDict(), Row(column=2, value=2).asDict()) + @staticmethod def _test_with_key(left, right, isLeft): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 99f6062a0d243..53696a368f9a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1405,6 +1405,10 @@ class Analyzer(override val catalogManager: CatalogManager) if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance())))) + case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _) + if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty => + Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance())))) + case oldVersion: Generate if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty => val newOutput = oldVersion.generatorOutput.map(_.newInstance()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f66871ee75ecc..196c2e58ea9a5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -631,6 +631,29 @@ class AnalysisSuite extends AnalysisTest with Matchers { Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) } + test("SPARK-34319: analysis fails on self-join with FlatMapCoGroupsInPandas") { + val pythonUdf = PythonUDF("pyUDF", null, + StructType(Seq(StructField("a", LongType))), + Seq.empty, + PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF, + true) + val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes + val project1 = Project(Seq(UnresolvedAttribute("a")), testRelation) + val project2 = Project(Seq(UnresolvedAttribute("a")), testRelation2) + val flatMapGroupsInPandas = FlatMapCoGroupsInPandas( + Seq(UnresolvedAttribute("a")), + Seq(UnresolvedAttribute("a")), + pythonUdf, + output, + project1, + project2) + val left = SubqueryAlias("temp0", flatMapGroupsInPandas) + val right = SubqueryAlias("temp1", flatMapGroupsInPandas) + val join = Join(left, right, Inner, None, JoinHint.NONE) + assertAnalysisSuccess( + Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join)) + } + test("SPARK-24488 Generator with multiple aliases") { assertAnalysisSuccess( listRelation.select(Explode($"list").as("first_alias").as("second_alias")))