diff --git a/python/pyspark/sql/tests/test_arrow_cogrouped_map.py b/python/pyspark/sql/tests/test_arrow_cogrouped_map.py index a90574b7f1928..27a520d2843ed 100644 --- a/python/pyspark/sql/tests/test_arrow_cogrouped_map.py +++ b/python/pyspark/sql/tests/test_arrow_cogrouped_map.py @@ -299,6 +299,16 @@ def summarize(left, right): "+---------+------------+----------+-------------+\n", ) + def test_self_join(self): + df = self.spark.createDataFrame([(1, 1)], ("k", "v")) + + def arrow_func(key, left, right): + return pa.Table.from_pydict({"x": [2], "y": [2]}) + + df2 = df.groupby("k").cogroup(df.groupby("k")).applyInArrow(arrow_func, "x long, y long") + + self.assertEqual(df2.join(df2).count(), 1) + class CogroupedMapInArrowTests(CogroupedMapInArrowTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/sql/tests/test_arrow_grouped_map.py b/python/pyspark/sql/tests/test_arrow_grouped_map.py index f9947d0788b87..213810e882fd9 100644 --- a/python/pyspark/sql/tests/test_arrow_grouped_map.py +++ b/python/pyspark/sql/tests/test_arrow_grouped_map.py @@ -255,6 +255,16 @@ def foo(_): self.assertEqual(r.a, "hi") self.assertEqual(r.b, 1) + def test_self_join(self): + df = self.spark.createDataFrame([(1, 1)], ("k", "v")) + + def arrow_func(key, table): + return pa.Table.from_pydict({"x": [2], "y": [2]}) + + df2 = df.groupby("k").applyInArrow(arrow_func, schema="x long, y long") + + self.assertEqual(df2.join(df2).count(), 1) + class GroupedMapInArrowTests(GroupedMapInArrowTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala index c1535343d7686..52be631d94d85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala @@ -132,6 +132,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] { _.output.map(_.exprId.id), newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance()))) + case f: FlatMapGroupsInArrow => + deduplicateAndRenew[FlatMapGroupsInArrow]( + existingRelations, + f, + _.output.map(_.exprId.id), + newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance()))) + case f: FlatMapCoGroupsInPandas => deduplicateAndRenew[FlatMapCoGroupsInPandas]( existingRelations, @@ -139,6 +146,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] { _.output.map(_.exprId.id), newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance()))) + case f: FlatMapCoGroupsInArrow => + deduplicateAndRenew[FlatMapCoGroupsInArrow]( + existingRelations, + f, + _.output.map(_.exprId.id), + newFlatMap => newFlatMap.copy(output = newFlatMap.output.map(_.newInstance()))) + case m: MapInPandas => deduplicateAndRenew[MapInPandas]( existingRelations,