Skip to content

Commit

Permalink
[SPARK-44834][PYTHON][SQL][TESTS] Add SQL query tests for Python UDTFs
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR adds a new sql query test suite for running Python UDTFs in SQL. You can trigger the test using
```
 SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "sql/testOnly *SQLQueryTestSuite -- -z udtf/udtf.sql"
```

### Why are the changes needed?

To add more test cases for Python UDTFs.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added new golden file tests.

Closes #42517 from allisonwang-db/spark-44834-udtf-sql-test.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
  • Loading branch information
allisonwang-db authored and ueshin committed Aug 17, 2023
1 parent 047b224 commit be04ac1
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2)
-- !query analysis
CreateViewCommand `t1`, VALUES (0, 1), (1, 2) t(c1, c2), false, true, LocalTempView, true
+- SubqueryAlias t
+- LocalRelation [c1#x, c2#x]


-- !query
SELECT * FROM udtf(1, 2)
-- !query analysis
Project [x#x, y#x]
+- Generate TestUDTF(1, 2)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM udtf(-1, 0)
-- !query analysis
Project [x#x, y#x]
+- Generate TestUDTF(-1, 0)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM udtf(0, -1)
-- !query analysis
Project [x#x, y#x]
+- Generate TestUDTF(0, -1)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM udtf(0, 0)
-- !query analysis
Project [x#x, y#x]
+- Generate TestUDTF(0, 0)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT a, b FROM udtf(1, 2) t(a, b)
-- !query analysis
Project [a#x, b#x]
+- SubqueryAlias t
+- Project [x#x AS a#x, y#x AS b#x]
+- Generate TestUDTF(1, 2)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM t1, LATERAL udtf(c1, c2)
-- !query analysis
Project [c1#x, c2#x, x#x, y#x]
+- LateralJoin lateral-subquery#x [c1#x && c2#x], Inner
: +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x]
: +- OneRowRelation
+- SubqueryAlias t1
+- View (`t1`, [c1#x,c2#x])
+- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x]
+- SubqueryAlias t
+- LocalRelation [c1#x, c2#x]


-- !query
SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2)
-- !query analysis
Project [c1#x, c2#x, x#x, y#x]
+- LateralJoin lateral-subquery#x [c1#x && c2#x], LeftOuter
: +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x]
: +- OneRowRelation
+- SubqueryAlias t1
+- View (`t1`, [c1#x,c2#x])
+- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x]
+- SubqueryAlias t
+- LocalRelation [c1#x, c2#x]


-- !query
SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2)
-- !query analysis
Project [c1#x, c2#x, x#x, y#x]
+- LateralJoin lateral-subquery#x [c1#x && c2#x], Inner
: +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x]
: +- OneRowRelation
+- SubqueryAlias t
+- Project [x#x AS c1#x, y#x AS c2#x]
+- Generate TestUDTF(1, 2)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1)
-- !query analysis
[Analyzer test output redacted due to nondeterminism]
18 changes: 18 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2);

-- test basic udtf
SELECT * FROM udtf(1, 2);
SELECT * FROM udtf(-1, 0);
SELECT * FROM udtf(0, -1);
SELECT * FROM udtf(0, 0);

-- test column alias
SELECT a, b FROM udtf(1, 2) t(a, b);

-- test lateral join
SELECT * FROM t1, LATERAL udtf(c1, c2);
SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2);
SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2);

-- test non-deterministic input
SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1);
85 changes: 85 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2)
-- !query schema
struct<>
-- !query output



-- !query
SELECT * FROM udtf(1, 2)
-- !query schema
struct<x:int,y:int>
-- !query output
1 -1
2 1


-- !query
SELECT * FROM udtf(-1, 0)
-- !query schema
struct<x:int,y:int>
-- !query output



-- !query
SELECT * FROM udtf(0, -1)
-- !query schema
struct<x:int,y:int>
-- !query output



-- !query
SELECT * FROM udtf(0, 0)
-- !query schema
struct<x:int,y:int>
-- !query output
0 0


-- !query
SELECT a, b FROM udtf(1, 2) t(a, b)
-- !query schema
struct<a:int,b:int>
-- !query output
1 -1
2 1


-- !query
SELECT * FROM t1, LATERAL udtf(c1, c2)
-- !query schema
struct<c1:int,c2:int,x:int,y:int>
-- !query output
1 2 1 -1
1 2 2 1


-- !query
SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2)
-- !query schema
struct<c1:int,c2:int,x:int,y:int>
-- !query output
1 2 1 -1
1 2 2 1


-- !query
SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2)
-- !query schema
struct<c1:int,c2:int,x:int,y:int>
-- !query output
2 1 1 -1
2 1 2 1


-- !query
SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1)
-- !query schema
struct<x:int,y:int>
-- !query output
1 0
1 0
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@ object IntegratedUDFTestUtils extends SQLHelper {
val prettyName: String
}

sealed trait TestUDTF {
def apply(session: SparkSession, exprs: Column*): DataFrame

val prettyName: String
}

class PythonUDFWithoutId(
name: String,
func: PythonFunction,
Expand Down Expand Up @@ -409,6 +415,32 @@ object IntegratedUDFTestUtils extends SQLHelper {
udfDeterministic = deterministic)
}

case class TestPythonUDTF(name: String) extends TestUDTF {
private val pythonScript: String =
"""
|class TestUDTF:
| def eval(self, a: int, b: int):
| if a > 0 and b > 0:
| yield a, a - b
| yield b, b - a
| elif a == 0 and b == 0:
| yield 0, 0
| else:
| ...
|""".stripMargin

private[IntegratedUDFTestUtils] lazy val udtf = createUserDefinedPythonTableFunction(
name = "TestUDTF",
pythonScript = pythonScript,
returnType = StructType.fromDDL("x int, y int")
)

def apply(session: SparkSession, exprs: Column*): DataFrame =
udtf.apply(session, exprs: _*)

val prettyName: String = "Regular Python UDTF"
}

/**
* A Scalar Pandas UDF that takes one column, casts into string, executes the
* Python native function, and casts back to the type of input column.
Expand Down Expand Up @@ -588,4 +620,12 @@ object IntegratedUDFTestUtils extends SQLHelper {
case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf)
case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]")
}

/**
* Register UDTFs used in the test cases.
*/
def registerTestUDTF(testUDTF: TestUDTF, session: SparkSession): Unit = testUDTF match {
case udtf: TestPythonUDTF => session.udtf.registerPython(udtf.name, udtf.udtf)
case other => throw new RuntimeException(s"Unknown UDTF class [${other.getClass}]")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
val udf: TestUDF
}

protected trait UDTFTest {
val udtf: TestUDTF
}

/** A regular test case. */
protected case class RegularTestCase(
name: String, inputFile: String, resultFile: String) extends TestCase {
Expand Down Expand Up @@ -237,6 +241,16 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
UDFAnalyzerTestCase(newName, inputFile, newResultFile, udf)
}

protected case class UDTFTestCase(
name: String,
inputFile: String,
resultFile: String,
udtf: TestUDTF) extends TestCase with UDTFTest {

override def asAnalyzerTest(newName: String, newResultFile: String): TestCase =
UDTFAnalyzerTestCase(newName, inputFile, newResultFile, udtf)
}

/** A UDAF test case. */
protected case class UDAFTestCase(
name: String,
Expand Down Expand Up @@ -285,6 +299,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
protected case class UDFAnalyzerTestCase(
name: String, inputFile: String, resultFile: String, udf: TestUDF)
extends AnalyzerTest with UDFTest
protected case class UDTFAnalyzerTestCase(
name: String, inputFile: String, resultFile: String, udtf: TestUDTF)
extends AnalyzerTest with UDTFTest
protected case class UDAFAnalyzerTestCase(
name: String, inputFile: String, resultFile: String, udf: TestUDF)
extends AnalyzerTest with UDFTest
Expand Down Expand Up @@ -488,6 +505,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
testCase match {
case udfTestCase: UDFTest =>
registerTestUDF(udfTestCase.udf, localSparkSession)
case udtfTestCase: UDTFTest =>
registerTestUDTF(udtfTestCase.udtf, localSparkSession)
case _ =>
}

Expand Down Expand Up @@ -574,6 +593,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
shouldTestPandasUDFs =>
s"${testCase.name}${System.lineSeparator()}" +
s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}"
case udtfTestCase: UDTFTest
if udtfTestCase.udtf.isInstanceOf[TestPythonUDTF] && shouldTestPythonUDFs =>
s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}"
case _ =>
s"${testCase.name}${System.lineSeparator()}"
}
Expand Down Expand Up @@ -620,6 +642,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
UDAFTestCase(
s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf)
}
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udtf")) {
Seq(TestPythonUDTF("udtf")).map { udtf =>
UDTFTestCase(
s"$testCaseName - ${udtf.prettyName}", absPath, resultFile, udtf
)
}
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) {
PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}ansi")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
Seq.empty
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udaf")) {
Seq.empty
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udtf")) {
Seq.empty
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) {
PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}ansi")) {
Expand Down

0 comments on commit be04ac1

Please sign in to comment.