diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a233161713c3c..cd7aeb7cd4ac9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2206,11 +2206,19 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") // Propagate the column indexes for TABLE arguments to the PythonUDTF instance. + val f: FunctionTableSubqueryArgumentExpression = tableArgs.head._1 val tvfWithTableColumnIndexes = tvf match { case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _) - if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty => - val partitionColumnIndexes = - PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes) + if f.extraProjectedPartitioningExpressions.nonEmpty => + val partitionColumnIndexes = if (f.selectedInputExpressions.isEmpty) { + PythonUDTFPartitionColumnIndexes(f.partitioningExpressionIndexes) + } else { + // If the UDTF specified 'select' expression(s), we added a projection to compute + // them plus the 'partitionBy' expression(s) afterwards. + PythonUDTFPartitionColumnIndexes( + (0 until f.extraProjectedPartitioningExpressions.length) + .map(_ + f.selectedInputExpressions.length)) + } g.copy(generator = pyudtf.copy( pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes))) case _ => tvf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala index 94465ccff796e..bfd3bc8051dff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala @@ -172,9 +172,12 @@ case class FunctionTableSubqueryArgumentExpression( } } - private lazy val extraProjectedPartitioningExpressions: Seq[Alias] = { + lazy val extraProjectedPartitioningExpressions: Seq[Alias] = { partitionByExpressions.filter { e => - !subqueryOutputs.contains(e) + !subqueryOutputs.contains(e) || + // Skip deduplicating the 'partitionBy' expression(s) against the attributes of the input + // table if the UDTF also specified 'select' expression(s). + selectedInputExpressions.nonEmpty }.zipWithIndex.map { case (expr, index) => Alias(expr, s"partition_by_$index")() } diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out index 74ea9261462d6..4b53f1c6f19c4 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out @@ -904,6 +904,26 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +SELECT * FROM UDTFPartitionByIndexingBug( + TABLE( + SELECT + 5 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + + UNION ALL + + SELECT + 4 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + ) +) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] + + -- !query DROP VIEW t1 -- !query analysis diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql index c83481f10dca6..a437b1f93b604 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql @@ -143,6 +143,22 @@ SELECT * FROM UDTFWithSinglePartition(1, invalid_arg_name => 2); SELECT * FROM UDTFWithSinglePartition(1, initial_count => 2); SELECT * FROM UDTFWithSinglePartition(initial_count => 1, initial_count => 2); SELECT * FROM UDTFInvalidPartitionByOrderByParseError(TABLE(t2)); +-- Exercise the UDTF partitioning bug. +SELECT * FROM UDTFPartitionByIndexingBug( + TABLE( + SELECT + 5 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + + UNION ALL + + SELECT + 4 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + ) +); -- cleanup DROP VIEW t1; diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out index 78ad8b7c02cd5..f99c6c30c07e2 100644 --- a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out @@ -1069,6 +1069,32 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException } +-- !query +SELECT * FROM UDTFPartitionByIndexingBug( + TABLE( + SELECT + 5 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + + UNION ALL + + SELECT + 4 AS unused_col, + 'hi' AS partition_col, + 1.0 AS double_col + ) +) +-- !query schema +struct +-- !query output +NULL 1.0 +NULL 1.0 +NULL 1.0 +NULL 1.0 +NULL 1.0 + + -- !query DROP VIEW t1 -- !query schema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index c1ca48162d207..957be07607b66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -660,6 +660,49 @@ object IntegratedUDFTestUtils extends SQLHelper { orderBy = "OrderingColumn(\"input\")", select = "SelectedColumn(\"partition_col\")") + object UDTFPartitionByIndexingBug extends TestUDTF { + val pythonScript: String = + s""" + |from pyspark.sql.functions import ( + | AnalyzeArgument, + | AnalyzeResult, + | PartitioningColumn, + | SelectedColumn, + | udtf + |) + |from pyspark.sql.types import ( + | DoubleType, + | StringType, + | StructType, + |) + |class $name: + | @staticmethod + | def analyze(observed: AnalyzeArgument) -> AnalyzeResult: + | out_schema = StructType() + | out_schema.add("partition_col", StringType()) + | out_schema.add("double_col", DoubleType()) + | + | return AnalyzeResult( + | schema=out_schema, + | partitionBy=[PartitioningColumn("partition_col")], + | select=[ + | SelectedColumn("partition_col"), + | SelectedColumn("double_col"), + | ], + | ) + | + | def eval(self, *args, **kwargs): + | pass + | + | def terminate(self): + | for _ in range(5): + | yield { + | "partition_col": None, + | "double_col": 1.0, + | } + |""".stripMargin + } + object UDTFInvalidPartitionByOrderByParseError extends TestPythonUDTFPartitionByOrderByBase( partitionBy = "PartitioningColumn(\"unparsable\")", @@ -1216,6 +1259,7 @@ object IntegratedUDFTestUtils extends SQLHelper { UDTFPartitionByOrderBySelectExpr, UDTFPartitionByOrderBySelectComplexExpr, UDTFPartitionByOrderBySelectExprOnlyPartitionColumn, + UDTFPartitionByIndexingBug, InvalidAnalyzeMethodReturnsNonStructTypeSchema, InvalidAnalyzeMethodWithSinglePartitionNoInputTable, InvalidAnalyzeMethodWithPartitionByNoInputTable,