SPARK-34319: give FlatMapCoGroupsInPandas and MapInPandas fresh output
attribute instances when their output conflicts with the other side of a
join, so that a self-join over either operator passes analysis.

Analyzer.scala gains one case per operator: when the operator's outputSet
intersects the conflicting attributes, it is paired with a copy whose
output attributes are new instances. AnalysisSuite.scala and the two
PySpark test files add self-join regression tests for both operators.

diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
index 3c016e04adf2e..94a12bfb3f656 100644
--- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -203,6 +203,18 @@ def test_case_insensitive_grouping_column(self):
         ).applyInPandas(lambda r, l: r + l, "column long, value long").first()
         self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
 
+    def test_self_join(self):
+        # SPARK-34319: self-join with FlatMapCoGroupsInPandas
+        df = self.spark.createDataFrame([(1, 1)], ("column", "value"))
+
+        cogrouped = df.groupby("ColUmn").cogroup(
+            df.groupby("COLUMN")
+        ).applyInPandas(lambda r, l: r + l, "column long, value long")
+
+        row = cogrouped.join(cogrouped).first()
+
+        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
+
     @staticmethod
     def _test_with_key(left, right, isLeft):
 
diff --git a/python/pyspark/sql/tests/test_pandas_map.py b/python/pyspark/sql/tests/test_pandas_map.py
index d53face702201..e8f92de417dda 100644
--- a/python/pyspark/sql/tests/test_pandas_map.py
+++ b/python/pyspark/sql/tests/test_pandas_map.py
@@ -112,6 +112,14 @@ def func(iterator):
         expected = df.collect()
         self.assertEqual(actual, expected)
 
+    def test_self_join(self):
+        # SPARK-34319: self-join with MapInPandas
+        df1 = self.spark.range(10)
+        df2 = df1.mapInPandas(lambda iterator: iterator, 'id long')
+        actual = df2.join(df2).collect()
+        expected = df1.join(df1).collect()
+        self.assertEqual(sorted(actual), sorted(expected))
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_map import *  # noqa: F401
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 99f6062a0d243..a69468429e54b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1405,6 +1405,14 @@ class Analyzer(override val catalogManager: CatalogManager)
             if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
           Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
 
+        case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+          Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+
+        case oldVersion @ MapInPandas(_, output, _)
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+          Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))
+
         case oldVersion: Generate
             if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index f66871ee75ecc..014db4863ac76 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -631,6 +631,48 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
   }
 
+  test("SPARK-34319: analysis fails on self-join with FlatMapCoGroupsInPandas") {
+    val pythonUdf = PythonUDF("pyUDF", null,
+      StructType(Seq(StructField("a", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
+      true)
+    val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
+    val project1 = Project(Seq(UnresolvedAttribute("a")), testRelation)
+    val project2 = Project(Seq(UnresolvedAttribute("a")), testRelation2)
+    val flatMapCoGroupsInPandas = FlatMapCoGroupsInPandas(
+      Seq(UnresolvedAttribute("a")),
+      Seq(UnresolvedAttribute("a")),
+      pythonUdf,
+      output,
+      project1,
+      project2)
+    val left = SubqueryAlias("temp0", flatMapCoGroupsInPandas)
+    val right = SubqueryAlias("temp1", flatMapCoGroupsInPandas)
+    val join = Join(left, right, Inner, None, JoinHint.NONE)
+    assertAnalysisSuccess(
+      Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
+  }
+
+  test("SPARK-34319: analysis fails on self-join with MapInPandas") {
+    val pythonUdf = PythonUDF("pyUDF", null,
+      StructType(Seq(StructField("a", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_MAP_PANDAS_ITER_UDF,
+      true)
+    val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
+    val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
+    val mapInPandas = MapInPandas(
+      pythonUdf,
+      output,
+      project)
+    val left = SubqueryAlias("temp0", mapInPandas)
+    val right = SubqueryAlias("temp1", mapInPandas)
+    val join = Join(left, right, Inner, None, JoinHint.NONE)
+    assertAnalysisSuccess(
+      Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
+  }
+
   test("SPARK-24488 Generator with multiple aliases") {
     assertAnalysisSuccess(
       listRelation.select(Explode($"list").as("first_alias").as("second_alias")))