diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out new file mode 100644 index 0000000000000..acf96794378e1 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/udtf/udtf.sql.out @@ -0,0 +1,96 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2) +-- !query analysis +CreateViewCommand `t1`, VALUES (0, 1), (1, 2) t(c1, c2), false, true, LocalTempView, true + +- SubqueryAlias t + +- LocalRelation [c1#x, c2#x] + + +-- !query +SELECT * FROM udtf(1, 2) +-- !query analysis +Project [x#x, y#x] ++- Generate TestUDTF(1, 2)#x, false, [x#x, y#x] + +- OneRowRelation + + +-- !query +SELECT * FROM udtf(-1, 0) +-- !query analysis +Project [x#x, y#x] ++- Generate TestUDTF(-1, 0)#x, false, [x#x, y#x] + +- OneRowRelation + + +-- !query +SELECT * FROM udtf(0, -1) +-- !query analysis +Project [x#x, y#x] ++- Generate TestUDTF(0, -1)#x, false, [x#x, y#x] + +- OneRowRelation + + +-- !query +SELECT * FROM udtf(0, 0) +-- !query analysis +Project [x#x, y#x] ++- Generate TestUDTF(0, 0)#x, false, [x#x, y#x] + +- OneRowRelation + + +-- !query +SELECT a, b FROM udtf(1, 2) t(a, b) +-- !query analysis +Project [a#x, b#x] ++- SubqueryAlias t + +- Project [x#x AS a#x, y#x AS b#x] + +- Generate TestUDTF(1, 2)#x, false, [x#x, y#x] + +- OneRowRelation + + +-- !query +SELECT * FROM t1, LATERAL udtf(c1, c2) +-- !query analysis +Project [c1#x, c2#x, x#x, y#x] ++- LateralJoin lateral-subquery#x [c1#x && c2#x], Inner + : +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x] + : +- OneRowRelation + +- SubqueryAlias t1 + +- View (`t1`, [c1#x,c2#x]) + +- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x] + +- SubqueryAlias t + +- LocalRelation [c1#x, c2#x] + + +-- !query +SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2) +-- !query analysis +Project [c1#x, c2#x, x#x, y#x] ++- LateralJoin lateral-subquery#x [c1#x && c2#x], LeftOuter + : +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x] + : +- OneRowRelation + +- SubqueryAlias t1 + +- View (`t1`, [c1#x,c2#x]) + +- Project [cast(c1#x as int) AS c1#x, cast(c2#x as int) AS c2#x] + +- SubqueryAlias t + +- LocalRelation [c1#x, c2#x] + + +-- !query +SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2) +-- !query analysis +Project [c1#x, c2#x, x#x, y#x] ++- LateralJoin lateral-subquery#x [c1#x && c2#x], Inner + : +- Generate TestUDTF(outer(c1#x), outer(c2#x))#x, false, [x#x, y#x] + : +- OneRowRelation + +- SubqueryAlias t + +- Project [x#x AS c1#x, y#x AS c2#x] + +- Generate TestUDTF(1, 2)#x, false, [x#x, y#x] + +- OneRowRelation + + +-- !query +SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1) +-- !query analysis +[Analyzer test output redacted due to nondeterminism] diff --git a/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql new file mode 100644 index 0000000000000..66044604d64c0 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udtf/udtf.sql @@ -0,0 +1,18 @@ +CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2); + +-- test basic udtf +SELECT * FROM udtf(1, 2); +SELECT * FROM udtf(-1, 0); +SELECT * FROM udtf(0, -1); +SELECT * FROM udtf(0, 0); + +-- test column alias +SELECT a, b FROM udtf(1, 2) t(a, b); + +-- test lateral join +SELECT * FROM t1, LATERAL udtf(c1, c2); +SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2); +SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2); + +-- test non-deterministic input +SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1); diff --git a/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out new file mode 100644 index 0000000000000..4f91ed3b70e58 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udtf/udtf.sql.out @@ -0,0 +1,85 @@ +-- Automatically generated by SQLQueryTestSuite +-- !query +CREATE OR REPLACE TEMPORARY VIEW t1 AS VALUES (0, 1), (1, 2) t(c1, c2) +-- !query schema +struct<> +-- !query output + + + +-- !query +SELECT * FROM udtf(1, 2) +-- !query schema +struct +-- !query output +1 -1 +2 1 + + +-- !query +SELECT * FROM udtf(-1, 0) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT * FROM udtf(0, -1) +-- !query schema +struct +-- !query output + + + +-- !query +SELECT * FROM udtf(0, 0) +-- !query schema +struct +-- !query output +0 0 + + +-- !query +SELECT a, b FROM udtf(1, 2) t(a, b) +-- !query schema +struct +-- !query output +1 -1 +2 1 + + +-- !query +SELECT * FROM t1, LATERAL udtf(c1, c2) +-- !query schema +struct +-- !query output +1 2 1 -1 +1 2 2 1 + + +-- !query +SELECT * FROM t1 LEFT JOIN LATERAL udtf(c1, c2) +-- !query schema +struct +-- !query output +1 2 1 -1 +1 2 2 1 + + +-- !query +SELECT * FROM udtf(1, 2) t(c1, c2), LATERAL udtf(c1, c2) +-- !query schema +struct +-- !query output +2 1 1 -1 +2 1 2 1 + + +-- !query +SELECT * FROM udtf(cast(rand(0) AS int) + 1, 1) +-- !query schema +struct +-- !query output +1 0 +1 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala index 10a821c1d1ed7..5883b4e4e3609 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala @@ -319,6 +319,12 @@ object IntegratedUDFTestUtils extends SQLHelper { val prettyName: String } + sealed trait TestUDTF { + def apply(session: SparkSession, exprs: Column*): DataFrame + + val prettyName: String + } + class PythonUDFWithoutId( name: String, func: PythonFunction, @@ -409,6 +415,32 @@ object IntegratedUDFTestUtils extends SQLHelper { udfDeterministic = deterministic) } + case class TestPythonUDTF(name: String) extends TestUDTF { + private val pythonScript: String = + """ + |class TestUDTF: + | def eval(self, a: int, b: int): + | if a > 0 and b > 0: + | yield a, a - b + | yield b, b - a + | elif a == 0 and b == 0: + | yield 0, 0 + | else: + | ... + |""".stripMargin + + private[IntegratedUDFTestUtils] lazy val udtf = createUserDefinedPythonTableFunction( + name = "TestUDTF", + pythonScript = pythonScript, + returnType = StructType.fromDDL("x int, y int") + ) + + def apply(session: SparkSession, exprs: Column*): DataFrame = + udtf.apply(session, exprs: _*) + + val prettyName: String = "Regular Python UDTF" + } + /** * A Scalar Pandas UDF that takes one column, casts into string, executes the * Python native function, and casts back to the type of input column. @@ -588,4 +620,12 @@ object IntegratedUDFTestUtils extends SQLHelper { case udf: TestScalaUDF => session.udf.register(udf.name, udf.udf) case other => throw new RuntimeException(s"Unknown UDF class [${other.getClass}]") } + + /** + * Register UDTFs used in the test cases. + */ + def registerTestUDTF(testUDTF: TestUDTF, session: SparkSession): Unit = testUDTF match { + case udtf: TestPythonUDTF => session.udtf.registerPython(udtf.name, udtf.udtf) + case other => throw new RuntimeException(s"Unknown UDTF class [${other.getClass}]") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index b354642eeded6..71af1fd69c347 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -202,6 +202,10 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper val udf: TestUDF } + protected trait UDTFTest { + val udtf: TestUDTF + } + /** A regular test case. */ protected case class RegularTestCase( name: String, inputFile: String, resultFile: String) extends TestCase { @@ -237,6 +241,16 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper UDFAnalyzerTestCase(newName, inputFile, newResultFile, udf) } + protected case class UDTFTestCase( + name: String, + inputFile: String, + resultFile: String, + udtf: TestUDTF) extends TestCase with UDTFTest { + + override def asAnalyzerTest(newName: String, newResultFile: String): TestCase = + UDTFAnalyzerTestCase(newName, inputFile, newResultFile, udtf) + } + /** A UDAF test case. */ protected case class UDAFTestCase( name: String, @@ -285,6 +299,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper protected case class UDFAnalyzerTestCase( name: String, inputFile: String, resultFile: String, udf: TestUDF) extends AnalyzerTest with UDFTest + protected case class UDTFAnalyzerTestCase( + name: String, inputFile: String, resultFile: String, udtf: TestUDTF) + extends AnalyzerTest with UDTFTest protected case class UDAFAnalyzerTestCase( name: String, inputFile: String, resultFile: String, udf: TestUDF) extends AnalyzerTest with UDFTest @@ -488,6 +505,8 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper testCase match { case udfTestCase: UDFTest => registerTestUDF(udfTestCase.udf, localSparkSession) + case udtfTestCase: UDTFTest => + registerTestUDTF(udtfTestCase.udtf, localSparkSession) case _ => } @@ -574,6 +593,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper shouldTestPandasUDFs => s"${testCase.name}${System.lineSeparator()}" + s"Python: $pythonVer Pandas: $pandasVer PyArrow: $pyarrowVer${System.lineSeparator()}" + case udtfTestCase: UDTFTest + if udtfTestCase.udtf.isInstanceOf[TestPythonUDTF] && shouldTestPythonUDFs => + s"${testCase.name}${System.lineSeparator()}Python: $pythonVer${System.lineSeparator()}" case _ => s"${testCase.name}${System.lineSeparator()}" } @@ -620,6 +642,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper UDAFTestCase( s"$testCaseName - ${udf.prettyName}", absPath, resultFile, udf) } + } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udtf")) { + Seq(TestPythonUDTF("udtf")).map { udtf => + UDTFTestCase( + s"$testCaseName - ${udtf.prettyName}", absPath, resultFile, udtf + ) + } } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) { PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}ansi")) { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala index 91c75581b33e2..adbc42ab24587 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerQueryTestSuite.scala @@ -259,6 +259,8 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ Seq.empty } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udaf")) { Seq.empty + } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}udtf")) { + Seq.empty } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}postgreSQL")) { PgSQLTestCase(testCaseName, absPath, resultFile) :: Nil } else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}ansi")) {