Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Ngone51 committed Feb 2, 2021
1 parent 15445a8 commit 0ab2b7a
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
12 changes: 12 additions & 0 deletions python/pyspark/sql/tests/test_pandas_cogrouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,18 @@ def test_case_insensitive_grouping_column(self):
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

def test_self_join(self):
# SPARK-34319: self-join with FlatMapCoGroupsInPandas
df = self.spark.createDataFrame([(1, 1)], ("column", "value"))

row = df.groupby("ColUmn").cogroup(
df.groupby("COLUMN")
).applyInPandas(lambda r, l: r + l, "column long, value long")

row = row.join(row).first()

self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

@staticmethod
def _test_with_key(left, right, isLeft):

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1405,6 +1405,10 @@ class Analyzer(override val catalogManager: CatalogManager)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))

case oldVersion @ FlatMapCoGroupsInPandas(_, _, _, output, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
Seq((oldVersion, oldVersion.copy(output = output.map(_.newInstance()))))

case oldVersion: Generate
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,29 @@ class AnalysisSuite extends AnalysisTest with Matchers {
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
}

test("SPARK-34319: analysis fails on self-join with FlatMapCoGroupsInPandas") {
val pythonUdf = PythonUDF("pyUDF", null,
StructType(Seq(StructField("a", LongType))),
Seq.empty,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
true)
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
val project1 = Project(Seq(UnresolvedAttribute("a")), testRelation)
val project2 = Project(Seq(UnresolvedAttribute("a")), testRelation2)
val flatMapGroupsInPandas = FlatMapCoGroupsInPandas(
Seq(UnresolvedAttribute("a")),
Seq(UnresolvedAttribute("a")),
pythonUdf,
output,
project1,
project2)
val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
val join = Join(left, right, Inner, None, JoinHint.NONE)
assertAnalysisSuccess(
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
}

test("SPARK-24488 Generator with multiple aliases") {
assertAnalysisSuccess(
listRelation.select(Explode($"list").as("first_alias").as("second_alias")))
Expand Down

0 comments on commit 0ab2b7a

Please sign in to comment.