Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44834][PYTHON][SQL][TESTS] Add SQL query tests for Python UDTFs #42517

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2)
-- !query analysis
CreateViewCommand `t1`, VALUES (0, 1), (1, 2) t(c1, c2), false, true, LocalTempView, true
+- SubqueryAlias t
+- LocalRelation [c1#x, c2#x]


-- !query
SELECT * FROM udtf(1, 2)
-- !query analysis
Project [x#x, y#x]
+- Generate TestUDTF(1, 2)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM udtf(-1, 0)
-- !query analysis
Project [x#x, y#x]
+- Generate TestUDTF(-1, 0)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM udtf(0, -1)
-- !query analysis
Project [x#x, y#x]
+- Generate TestUDTF(0, -1)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM udtf(0, 0)
-- !query analysis
Project [x#x, y#x]
+- Generate TestUDTF(0, 0)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT a, b FROM udtf(1, 2) t(a, b)
-- !query analysis
Project [a#x, b#x]
+- SubqueryAlias t
+- Project [x#x AS a#x, y#x AS b#x]
+- Generate TestUDTF(1, 2)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM t1, LATERAL udtf(c1, c2)
-- !query analysis
Project [c1#x, c2#x, x#x, y#x]
+- LateralJoin lateral-subquery#x [c1#x && c2#x], Inner
: +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x]
: +- OneRowRelation
+- SubqueryAlias t1
+- View (`t1`, [c1#x,c2#x])
+- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x]
+- SubqueryAlias t
+- LocalRelation [c1#x, c2#x]


-- !query
SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2)
-- !query analysis
Project [c1#x, c2#x, x#x, y#x]
+- LateralJoin lateral-subquery#x [c1#x && c2#x], LeftOuter
: +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x]
: +- OneRowRelation
+- SubqueryAlias t1
+- View (`t1`, [c1#x,c2#x])
+- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x]
+- SubqueryAlias t
+- LocalRelation [c1#x, c2#x]


-- !query
SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2)
-- !query analysis
Project [c1#x, c2#x, x#x, y#x]
+- LateralJoin lateral-subquery#x [c1#x && c2#x], Inner
: +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x]
: +- OneRowRelation
+- SubqueryAlias t
+- Project [x#x AS c1#x, y#x AS c2#x]
+- Generate TestUDTF(1, 2)#x, false, [x#x, y#x]
+- OneRowRelation


-- !query
SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1)
-- !query analysis
[Analyzer test output redacted due to nondeterminism]
18 changes: 18 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
-- Two-row fixture view used as the outer side of the lateral-join queries below.
CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2);

-- Basic UDTF calls, one per branch of the Python `eval` in TestPythonUDTF
-- (two positive arguments, a negative argument in either position, and both arguments zero).
SELECT * FROM udtf(1, 2);
SELECT * FROM udtf(-1, 0);
SELECT * FROM udtf(0, -1);
SELECT * FROM udtf(0, 0);

-- The UDTF output columns (x, y) can be renamed with a table alias column list.
SELECT a, b FROM udtf(1, 2) t(a, b);

-- Lateral joins where the UDTF arguments reference columns of the preceding relation
-- (a view, a view with LEFT JOIN, and the output of another UDTF call).
SELECT * FROM t1, LATERAL udtf(c1, c2);
SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2);
SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2);

-- Non-deterministic input expression (rand with a fixed seed) passed as a UDTF argument.
SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1);
85 changes: 85 additions & 0 deletions sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
-- Automatically generated by SQLQueryTestSuite
-- !query
CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2)
-- !query schema
struct<>
-- !query output



-- !query
SELECT * FROM udtf(1, 2)
-- !query schema
struct<x:int,y:int>
-- !query output
1 -1
2 1


-- !query
SELECT * FROM udtf(-1, 0)
-- !query schema
struct<x:int,y:int>
-- !query output



-- !query
SELECT * FROM udtf(0, -1)
-- !query schema
struct<x:int,y:int>
-- !query output



-- !query
SELECT * FROM udtf(0, 0)
-- !query schema
struct<x:int,y:int>
-- !query output
0 0


-- !query
SELECT a, b FROM udtf(1, 2) t(a, b)
-- !query schema
struct<a:int,b:int>
-- !query output
1 -1
2 1


-- !query
SELECT * FROM t1, LATERAL udtf(c1, c2)
-- !query schema
struct<c1:int,c2:int,x:int,y:int>
-- !query output
1 2 1 -1
1 2 2 1


-- !query
SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2)
-- !query schema
struct<c1:int,c2:int,x:int,y:int>
-- !query output
1 2 1 -1
1 2 2 1


-- !query
SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2)
-- !query schema
struct<c1:int,c2:int,x:int,y:int>
-- !query output
2 1 1 -1
2 1 2 1


-- !query
SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1)
-- !query schema
struct<x:int,y:int>
-- !query output
1 0
1 0
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@ object IntegratedUDFTestUtils extends SQLHelper {
val prettyName: String
}

/**
 * Common interface of the user-defined table functions (UDTFs) that the integrated
 * tests can register and invoke.
 */
sealed trait TestUDTF {
  /** Human-readable label of this UDTF flavour; used when composing test case names. */
  val prettyName: String

  /**
   * Invokes the table function.
   *
   * @param session the session in which the function is evaluated
   * @param exprs the argument columns passed to the function
   * @return the DataFrame produced by the function
   */
  def apply(session: SparkSession, exprs: Column*): DataFrame
}

class PythonUDFWithoutId(
name: String,
func: PythonFunction,
Expand Down Expand Up @@ -409,6 +415,32 @@ object IntegratedUDFTestUtils extends SQLHelper {
udfDeterministic = deterministic)
}

/**
 * A Python UDTF taking two integer arguments `(a, b)` and producing rows of schema
 * `x int, y int`:
 *  - both arguments positive: two rows, `(a, a - b)` and `(b, b - a)`;
 *  - both arguments zero: a single row `(0, 0)`;
 *  - anything else: no rows.
 *
 * @param name the name under which the function is registered for SQL queries
 */
case class TestPythonUDTF(name: String) extends TestUDTF {
  // The margin-stripped text is Python source, so the nesting below is significant:
  // every body must be indented deeper than the statement that introduces it.
  private val pythonScript: String =
    """
      |class TestUDTF:
      |    def eval(self, a: int, b: int):
      |        if a > 0 and b > 0:
      |            yield a, a - b
      |            yield b, b - a
      |        elif a == 0 and b == 0:
      |            yield 0, 0
      |        else:
      |            ...
      |""".stripMargin

  // NOTE(review): "TestUDTF" presumably has to match the Python class name declared in
  // `pythonScript` -- confirm against createUserDefinedPythonTableFunction. It is distinct
  // from the constructor parameter `name`, which is the SQL-visible function name.
  private[IntegratedUDFTestUtils] lazy val udtf = createUserDefinedPythonTableFunction(
    name = "TestUDTF",
    pythonScript = pythonScript,
    returnType = StructType.fromDDL("x int, y int")
  )

  /** Applies the Python table function to the given argument columns. */
  def apply(session: SparkSession, exprs: Column*): DataFrame =
    udtf.apply(session, exprs: _*)

  val prettyName: String = "Regular Python UDTF"
}

/**
* A Scalar Pandas UDF that takes one column, casts into string, executes the
* Python native function, and casts back to the type of input column.
Expand Down Expand Up @@ -588,4 +620,12 @@ object IntegratedUDFTestUtils extends SQLHelper {
case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf)
case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]")
}

/**
 * Registers the given test UDTF in `session` under the UDTF's own name so that
 * SQL test queries can call it.
 *
 * @param testUDTF the table function to register
 * @param session the session whose UDTF registry receives the function
 * @throws RuntimeException if `testUDTF` is of a kind this helper does not know how to register
 */
def registerTestUDTF(testUDTF: TestUDTF, session: SparkSession): Unit = {
  testUDTF match {
    case pythonUdtf: TestPythonUDTF =>
      session.udtf.registerPython(pythonUdtf.name, pythonUdtf.udtf)
    case unknown =>
      throw new RuntimeException(s"Unknown UDTF class [${unknown.getClass}]")
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
val udf: TestUDF
}

/** Mixed into test cases that need a test UDTF registered before their queries run. */
protected trait UDTFTest {
  /** The table function this test case depends on. */
  val udtf: TestUDTF
}

/** A regular test case. */
protected case class RegularTestCase(
name: String, inputFile: String, resultFile: String) extends TestCase {
Expand Down Expand Up @@ -237,6 +241,16 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
UDFAnalyzerTestCase(newName, inputFile, newResultFile, udf)
}

/** A UDTF test case. */
protected case class UDTFTestCase(
    name: String,
    inputFile: String,
    resultFile: String,
    udtf: TestUDTF)
  extends TestCase
  with UDTFTest {

  /** Derives the analyzer-test variant: same input file and UDTF, new name and result file. */
  override def asAnalyzerTest(newName: String, newResultFile: String): TestCase = {
    UDTFAnalyzerTestCase(
      name = newName,
      inputFile = inputFile,
      resultFile = newResultFile,
      udtf = udtf)
  }
}

/** A UDAF test case. */
protected case class UDAFTestCase(
name: String,
Expand Down Expand Up @@ -285,6 +299,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
/** An analyzer test case that carries a test UDF. */
protected case class UDFAnalyzerTestCase(
    name: String,
    inputFile: String,
    resultFile: String,
    udf: TestUDF)
  extends AnalyzerTest
  with UDFTest
/** An analyzer test case that carries a test UDTF. */
protected case class UDTFAnalyzerTestCase(
    name: String,
    inputFile: String,
    resultFile: String,
    udtf: TestUDTF)
  extends AnalyzerTest
  with UDTFTest
/** An analyzer test case for the UDAF inputs; it carries its function as a test UDF. */
protected case class UDAFAnalyzerTestCase(
    name: String,
    inputFile: String,
    resultFile: String,
    udf: TestUDF)
  extends AnalyzerTest
  with UDFTest
Expand Down Expand Up @@ -488,6 +505,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
testCase match {
case udfTestCase: UDFTest =>
registerTestUDF(udfTestCase.udf, localSparkSession)
case udtfTestCase: UDTFTest =>
registerTestUDTF(udtfTestCase.udtf, localSparkSession)
case _ =>
}

Expand Down Expand Up @@ -574,6 +593,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
shouldTestPandasUDFs =>
s"${testCase.name}${System.lineSeparator()}" +
s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}"
case udtfTestCase: UDTFTest
if udtfTestCase.udtf.isInstanceOf[TestPythonUDTF] && shouldTestPythonUDFs =>
s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}"
case _ =>
s"${testCase.name}${System.lineSeparator()}"
}
Expand Down Expand Up @@ -620,6 +642,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
UDAFTestCase(
s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf)
}
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udtf")) {
Seq(TestPythonUDTF("udtf")).map { udtf =>
UDTFTestCase(
s"$testCaseName - ${udtf.prettyName}", absPath, resultFile, udtf
)
}
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) {
PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}ansi")) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,8 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
Seq.empty
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udaf")) {
Seq.empty
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udtf")) {
Seq.empty
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) {
PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}ansi")) {
Expand Down