[SPARK-48566][PYTHON] Fix bug where partition indices are incorrect when UDTF analyze() uses both select and partitionColumns

### What changes were proposed in this pull request?

This PR fixes a bug that resulted in an internal error when a Python UDTF's "analyze" method used both the "select" and "partitionBy" options.

Specifically, this logic in `Analyzer.scala` was wrong because it did not update the usage of `partitioningExpressionIndexes` to take the "select" expressions into account when they were introduced in apache#45007:

```scala
val tvfWithTableColumnIndexes = tvf match {
  case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _)
      if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty =>
    //////////////////////////////////////////////////////////////////////////////
    // The bug is here: the 'partitioningExpressionIndexes' are not valid
    // if the UDTF "select" expressions are non-empty, since that prompts
    // us to add a new projection (of a possibly different number of
    // expressions) to evaluate them.
    //////////////////////////////////////////////////////////////////////////////
    val partitionColumnIndexes =
      PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes)
    g.copy(generator = pyudtf.copy(
      pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes)))
  case _ => tvf
}
```
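To make the index mismatch concrete, here is a deliberately simplified sketch (illustrative only, not Spark source, using the column names from the repro below): `partitioningExpressionIndexes` refers to positions in the original TABLE argument, but once the "select" projection narrows the input, interpreting those stale indices against the projected row picks the wrong column.

```scala
// Minimal illustration of the stale-index problem (not Spark source).
object StaleIndexSketch extends App {
  // The TABLE argument's original output: 'partition_col' is at index 1.
  val inputColumns = Seq("unused_col", "partition_col", "double_col")
  val partitioningExpressionIndexes = Seq(1)

  // The UDTF's 'select' introduces a projection with fewer columns.
  val projectedColumns = Seq("partition_col", "double_col")

  // Interpreting the old index against the projected row picks the wrong column.
  val picked = projectedColumns(partitioningExpressionIndexes.head)
  println(picked) // prints "double_col" instead of "partition_col"
}
```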

To reproduce:

```python
from pyspark.sql.functions import (
    AnalyzeArgument,
    AnalyzeResult,
    PartitioningColumn,
    SelectedColumn,
    udtf
)

from pyspark.sql.types import (
    DoubleType,
    StringType,
    StructType,
)

@udtf
class TestTvf:
    @staticmethod
    def analyze(observed: AnalyzeArgument) -> AnalyzeResult:
        out_schema = StructType()
        out_schema.add("partition_col", StringType())
        out_schema.add("double_col", DoubleType())

        return AnalyzeResult(
            schema=out_schema,
            partitionBy=[PartitioningColumn("partition_col")],
            select=[
                SelectedColumn("partition_col"),
                SelectedColumn("double_col"),
            ],
        )

    def eval(self, *args, **kwargs):
        pass

    def terminate(self):
        for _ in range(10):
            yield {
                "partition_col": None,
                "double_col": 1.0,
            }

spark.udtf.register("serialize_test", TestTvf)

# Fails
(
    spark
    .sql(
        """
        SELECT * FROM serialize_test(
            TABLE(
                SELECT
                    5 AS unused_col,
                    'hi' AS partition_col,
                    1.0 AS double_col

                UNION ALL

                SELECT
                    4 AS unused_col,
                    'hi' AS partition_col,
                    1.0 AS double_col
            )
        )
        """
    )
    .toPandas()
)
```

### Why are the changes needed?

The above query returned internal errors before, but works now.
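With the fix, the golden-file run shown below returns rows of `NULL	1.0` under the schema `struct<partition_col:string,double_col:double>` (the checked-in test uses a five-row variant of the repro's UDTF).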

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Additional golden file coverage

### Was this patch authored or co-authored using generative AI tooling?

Some light GitHub Copilot usage

Closes apache#46918 from dtenedor/fix-udtf-bug.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Takuya Ueshin <ueshin@databricks.com>
dtenedor authored and ueshin committed Jun 17, 2024
1 parent 66d8a29 commit 0864bbe
Showing 6 changed files with 122 additions and 5 deletions.
Analyzer.scala
```diff
@@ -2206,11 +2206,19 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
     val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}")
 
     // Propagate the column indexes for TABLE arguments to the PythonUDTF instance.
+    val f: FunctionTableSubqueryArgumentExpression = tableArgs.head._1
     val tvfWithTableColumnIndexes = tvf match {
       case g @ Generate(pyudtf: PythonUDTF, _, _, _, _, _)
-          if tableArgs.head._1.partitioningExpressionIndexes.nonEmpty =>
-        val partitionColumnIndexes =
-          PythonUDTFPartitionColumnIndexes(tableArgs.head._1.partitioningExpressionIndexes)
+          if f.extraProjectedPartitioningExpressions.nonEmpty =>
+        val partitionColumnIndexes = if (f.selectedInputExpressions.isEmpty) {
+          PythonUDTFPartitionColumnIndexes(f.partitioningExpressionIndexes)
+        } else {
+          // If the UDTF specified 'select' expression(s), we added a projection to compute
+          // them plus the 'partitionBy' expression(s) afterwards.
+          PythonUDTFPartitionColumnIndexes(
+            (0 until f.extraProjectedPartitioningExpressions.length)
+              .map(_ + f.selectedInputExpressions.length))
+        }
         g.copy(generator = pyudtf.copy(
           pythonUDTFPartitionColumnIndexes = Some(partitionColumnIndexes)))
       case _ => tvf
```
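As a sanity check on the new arithmetic, a small sketch (illustrative only, not Spark source) of how the corrected else-branch derives the partition-column indices when "select" expressions are present: the analyzer appends one `partition_by_N` alias per partitioning expression after the selected columns, so the indices are simply offsets past `selectedInputExpressions.length`.

```scala
// Illustrative sketch of the corrected index computation (not Spark source).
object FixedIndexSketch extends App {
  // Stand-ins for the expression lists on FunctionTableSubqueryArgumentExpression.
  val selectedInputExpressions = Seq("partition_col", "double_col") // from 'select'
  val extraProjectedPartitioningExpressions = Seq("partition_by_0") // appended aliases

  // Mirrors the new branch in Analyzer.scala: the appended partitioning
  // aliases occupy the trailing slots of the projected row.
  val partitionColumnIndexes =
    (0 until extraProjectedPartitioningExpressions.length)
      .map(_ + selectedInputExpressions.length)

  println(partitionColumnIndexes) // Vector(2)
}
```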
FunctionTableSubqueryArgumentExpression.scala
```diff
@@ -172,9 +172,12 @@ case class FunctionTableSubqueryArgumentExpression(
     }
   }
 
-  private lazy val extraProjectedPartitioningExpressions: Seq[Alias] = {
+  lazy val extraProjectedPartitioningExpressions: Seq[Alias] = {
     partitionByExpressions.filter { e =>
-      !subqueryOutputs.contains(e)
+      !subqueryOutputs.contains(e) ||
+        // Skip deduplicating the 'partitionBy' expression(s) against the attributes of the input
+        // table if the UDTF also specified 'select' expression(s).
+        selectedInputExpressions.nonEmpty
     }.zipWithIndex.map { case (expr, index) =>
       Alias(expr, s"partition_by_$index")()
     }
```
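The companion change above stops deduplicating "partitionBy" expressions against the subquery output whenever "select" expressions exist, because the select projection may have dropped or reordered the attribute a partitioning expression resolves to. A toy model of the updated filter (illustrative only, with strings standing in for catalyst expressions):

```scala
// Toy model of the updated filter (strings stand in for catalyst expressions).
object DedupSkipSketch extends App {
  val subqueryOutputs = Set("unused_col", "partition_col", "double_col")
  val selectedInputExpressions = Seq("partition_col", "double_col")
  val partitionByExpressions = Seq("partition_col")

  // Before the fix, 'partition_col' was filtered out here because the subquery
  // already outputs it; after the fix it is always re-projected when 'select'
  // expressions are present.
  val extraProjected = partitionByExpressions.filter { e =>
    !subqueryOutputs.contains(e) || selectedInputExpressions.nonEmpty
  }
  println(extraProjected) // List(partition_col)
}
```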
udtf.sql.out (analyzer results golden file)
```diff
@@ -904,6 +904,26 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
+-- !query
+SELECT * FROM UDTFPartitionByIndexingBug(
+    TABLE(
+      SELECT
+        5 AS unused_col,
+        'hi' AS partition_col,
+        1.0 AS double_col
+
+      UNION ALL
+
+      SELECT
+        4 AS unused_col,
+        'hi' AS partition_col,
+        1.0 AS double_col
+    )
+  )
+-- !query analysis
+[Analyzer test output redacted due to nondeterminism]
+
+
 -- !query
 DROP VIEW t1
 -- !query analysis
```
16 changes: 16 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
```diff
@@ -143,6 +143,22 @@ SELECT * FROM UDTFWithSinglePartition(1, invalid_arg_name => 2);
 SELECT * FROM UDTFWithSinglePartition(1, initial_count => 2);
 SELECT * FROM UDTFWithSinglePartition(initial_count => 1, initial_count => 2);
 SELECT * FROM UDTFInvalidPartitionByOrderByParseError(TABLE(t2));
+-- Exercise the UDTF partitioning bug.
+SELECT * FROM UDTFPartitionByIndexingBug(
+    TABLE(
+      SELECT
+        5 AS unused_col,
+        'hi' AS partition_col,
+        1.0 AS double_col
+
+      UNION ALL
+
+      SELECT
+        4 AS unused_col,
+        'hi' AS partition_col,
+        1.0 AS double_col
+    )
+  );
 
 -- cleanup
 DROP VIEW t1;
```
26 changes: 26 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
```diff
@@ -1069,6 +1069,32 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException
 }
 
 
+-- !query
+SELECT * FROM UDTFPartitionByIndexingBug(
+    TABLE(
+      SELECT
+        5 AS unused_col,
+        'hi' AS partition_col,
+        1.0 AS double_col
+
+      UNION ALL
+
+      SELECT
+        4 AS unused_col,
+        'hi' AS partition_col,
+        1.0 AS double_col
+    )
+  )
+-- !query schema
+struct<partition_col:string,double_col:double>
+-- !query output
+NULL	1.0
+NULL	1.0
+NULL	1.0
+NULL	1.0
+NULL	1.0
+
+
 -- !query
 DROP VIEW t1
 -- !query schema
```
IntegratedUDFTestUtils.scala
```diff
@@ -660,6 +660,49 @@ object IntegratedUDFTestUtils extends SQLHelper {
       orderBy = "OrderingColumn(\"input\")",
       select = "SelectedColumn(\"partition_col\")")
 
+  object UDTFPartitionByIndexingBug extends TestUDTF {
+    val pythonScript: String =
+      s"""
+         |from pyspark.sql.functions import (
+         |  AnalyzeArgument,
+         |  AnalyzeResult,
+         |  PartitioningColumn,
+         |  SelectedColumn,
+         |  udtf
+         |)
+         |from pyspark.sql.types import (
+         |  DoubleType,
+         |  StringType,
+         |  StructType,
+         |)
+         |class $name:
+         |  @staticmethod
+         |  def analyze(observed: AnalyzeArgument) -> AnalyzeResult:
+         |    out_schema = StructType()
+         |    out_schema.add("partition_col", StringType())
+         |    out_schema.add("double_col", DoubleType())
+         |
+         |    return AnalyzeResult(
+         |      schema=out_schema,
+         |      partitionBy=[PartitioningColumn("partition_col")],
+         |      select=[
+         |        SelectedColumn("partition_col"),
+         |        SelectedColumn("double_col"),
+         |      ],
+         |    )
+         |
+         |  def eval(self, *args, **kwargs):
+         |    pass
+         |
+         |  def terminate(self):
+         |    for _ in range(5):
+         |      yield {
+         |        "partition_col": None,
+         |        "double_col": 1.0,
+         |      }
+         |""".stripMargin
+  }
+
   object UDTFInvalidPartitionByOrderByParseError
     extends TestPythonUDTFPartitionByOrderByBase(
       partitionBy = "PartitioningColumn(\"unparsable\")",
@@ -1216,6 +1259,7 @@ object IntegratedUDFTestUtils extends SQLHelper {
     UDTFPartitionByOrderBySelectExpr,
     UDTFPartitionByOrderBySelectComplexExpr,
     UDTFPartitionByOrderBySelectExprOnlyPartitionColumn,
+    UDTFPartitionByIndexingBug,
     InvalidAnalyzeMethodReturnsNonStructTypeSchema,
     InvalidAnalyzeMethodWithSinglePartitionNoInputTable,
     InvalidAnalyzeMethodWithPartitionByNoInputTable,
```
