From 442fdb8be42789d9a3fac8361f339f4e2a304fb8 Mon Sep 17 00:00:00 2001 From: panbingkun Date: Mon, 3 Jul 2023 15:30:03 +0900 Subject: [PATCH 01/13] [SPARK-43476][PYTHON][TESTS] Enable SeriesStringTests.test_string_replace for pandas 2.0.0 ### What changes were proposed in this pull request? The pr aims to enable SeriesStringTests.test_string_replace for pandas 2.0.0. ### Why are the changes needed? Improve UT coverage. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Pass GA. - Manually test: ''' (base) panbingkun:~/Developer/spark/spark-community$python/run-tests --testnames 'pyspark.pandas.tests.test_series_string SeriesStringTests.test_string_replace' Running PySpark tests. Output is in /Users/panbingkun/Developer/spark/spark-community/python/unit-tests.log Will test against the following Python executables: ['python3.9'] Will test the following Python tests: ['pyspark.pandas.tests.test_series_string SeriesStringTests.test_string_replace'] python3.9 python_implementation is CPython python3.9 version is: Python 3.9.13 Starting test(python3.9): pyspark.pandas.tests.test_series_string SeriesStringTests.test_string_replace (temp output: /Users/panbingkun/Developer/spark/spark-community/python/target/d51a913a-b400-4d1b-adb3-97765bb463bd/python3.9__pyspark.pandas.tests.test_series_string_SeriesStringTests.test_string_replace__izk1fx8o.log) Finished test(python3.9): pyspark.pandas.tests.test_series_string SeriesStringTests.test_string_replace (13s) Tests passed in 13 seconds ''' Closes #41823 from panbingkun/SPARK-43476. Authored-by: panbingkun Signed-off-by: Hyukjin Kwon --- python/pyspark/pandas/tests/test_series_string.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/python/pyspark/pandas/tests/test_series_string.py b/python/pyspark/pandas/tests/test_series_string.py index 3c2bd58da1a28..956567bc5a4ed 100644 --- a/python/pyspark/pandas/tests/test_series_string.py +++ b/python/pyspark/pandas/tests/test_series_string.py @@ -246,10 +246,6 @@ def test_string_repeat(self): with self.assertRaises(TypeError): self.check_func(lambda x: x.str.repeat(repeats=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9])) - @unittest.skipIf( - LooseVersion(pd.__version__) >= LooseVersion("2.0.0"), - "TODO(SPARK-43476): Enable SeriesStringTests.test_string_replace for pandas 2.0.0.", - ) def test_string_replace(self): self.check_func(lambda x: x.str.replace("a.", "xx", regex=True)) self.check_func(lambda x: x.str.replace("a.", "xx", regex=False)) @@ -259,10 +255,11 @@ def test_string_replace(self): def repl(m): return m.group(0)[::-1] - self.check_func(lambda x: x.str.replace(r"[a-z]+", repl)) + regex_pat = re.compile(r"[a-z]+") + self.check_func(lambda x: x.str.replace(regex_pat, repl, regex=True)) # compiled regex with flags regex_pat = re.compile(r"WHITESPACE", flags=re.IGNORECASE) - self.check_func(lambda x: x.str.replace(regex_pat, "---")) + self.check_func(lambda x: x.str.replace(regex_pat, "---", regex=True)) def test_string_rfind(self): self.check_func(lambda x: x.str.rfind("a")) From 356eada314e88a4c0a262c6aa28e76045880e38f Mon Sep 17 00:00:00 2001 From: Joe Wang Date: Mon, 3 Jul 2023 15:37:52 +0900 Subject: [PATCH 02/13] [SPARK-42828][PYTHON][SQL] More explicit Python type annotations for GroupedData ### What changes were proposed in this pull request? Be more explicit in the `Callable` type annotation for `dfapi` and `df_varargs_api` to explicitly return a `DataFrame`. ### Why are the changes needed? 
In PySpark 3.3.x, type hints now infer the return value of something like `df.groupBy(...).count()` to be `Any`, whereas it should be `DataFrame`. This breaks type checking.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
No runtime changes introduced, so just relied on CI tests.

Closes #40460 from j03wang/grouped-data-type.

Authored-by: Joe Wang
Signed-off-by: Hyukjin Kwon
---
 python/pyspark/sql/group.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 9568a971229b5..1b64e7666fd9a 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -32,7 +32,7 @@
 __all__ = ["GroupedData"]


-def dfapi(f: Callable) -> Callable:
+def dfapi(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
     def _api(self: "GroupedData") -> DataFrame:
         name = f.__name__
         jdf = getattr(self._jgd, name)()
@@ -43,7 +43,7 @@ def _api(self: "GroupedData") -> DataFrame:
     return _api


-def df_varargs_api(f: Callable) -> Callable:
+def df_varargs_api(f: Callable[..., DataFrame]) -> Callable[..., DataFrame]:
     def _api(self: "GroupedData", *cols: str) -> DataFrame:
         name = f.__name__
         jdf = getattr(self._jgd, name)(_to_seq(self.session._sc, cols))

From 45ae9c5cc67d379f5bbeadf8c56c032f2bdaaac0 Mon Sep 17 00:00:00 2001
From: narek_karapetian
Date: Mon, 3 Jul 2023 10:13:12 +0300
Subject: [PATCH 03/13] [SPARK-42169][SQL] Implement code generation for to_csv function (StructsToCsv)

### What changes were proposed in this pull request?
This PR enhances the `StructsToCsv` class with a `doGenCode` implementation instead of extending it from the `CodegenFallback` trait (a performance improvement).

### Why are the changes needed?
It will improve performance.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
An additional test case was added to the `org.apache.spark.sql.CsvFunctionsSuite` class.

Closes #39719 from NarekDW/SPARK-42169.
Authored-by: narek_karapetian Signed-off-by: Max Gekk --- .../catalyst/expressions/csvExpressions.scala | 11 ++- .../expressions/CsvExpressionsSuite.scala | 7 ++ .../benchmarks/CSVBenchmark-jdk11-results.txt | 82 ++++++++-------- .../benchmarks/CSVBenchmark-jdk17-results.txt | 82 ++++++++-------- sql/core/benchmarks/CSVBenchmark-results.txt | 94 +++++++++---------- 5 files changed, 144 insertions(+), 132 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala index e47cf493d4c16..cdab9faacd418 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.csv._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.internal.SQLConf @@ -245,8 +245,7 @@ case class StructsToCsv( options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) - extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes - with NullIntolerant { + extends UnaryExpression with TimeZoneAwareExpression with ExpectsInputTypes with NullIntolerant { override def nullable: Boolean = true def this(options: Map[String, String], child: Expression) = this(options, child, None) @@ -293,4 +292,10 @@ case class StructsToCsv( override protected def withNewChildInternal(newChild: Expression): StructsToCsv = copy(child = newChild) + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val structsToCsv = ctx.addReferenceObj("structsToCsv", this) + nullSafeCodeGen(ctx, ev, + eval => s"${ev.value} = (UTF8String) $structsToCsv.converter().apply($eval);") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index 1d174ed214523..a89cb58c3e03b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -246,4 +246,11 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P CsvToStructs(schema, Map.empty, Literal.create("1 day")), InternalRow(new CalendarInterval(0, 1, 0))) } + + test("StructsToCsv should not generate codes beyond 64KB") { + val range = Range.inclusive(1, 5000) + val struct = CreateStruct.create(range.map(Literal.apply)) + val expected = range.mkString(",") + checkEvaluation(StructsToCsv(Map.empty, struct), expected) + } } diff --git a/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt b/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt index 7b5ea10bc4e66..7fca105a8c254 100644 --- a/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt +++ b/sql/core/benchmarks/CSVBenchmark-jdk11-results.txt @@ -2,69 +2,69 
@@ Benchmark to measure CSV read/write performance ================================================================================================ -OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Parsing quoted values: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -One quoted string 38218 38618 520 0.0 764362.7 1.0X +One quoted string 43871 44151 336 0.0 877415.7 1.0X -OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Wide rows with 1000 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 1000 columns 97679 98487 1143 0.0 97678.6 1.0X -Select 100 columns 39193 39339 193 0.0 39193.1 2.5X -Select one column 32781 33041 265 0.0 32780.7 3.0X -count() 7154 7228 86 0.1 7153.5 13.7X -Select 100 columns, one bad input field 53968 54158 165 0.0 53967.9 1.8X -Select 100 columns, corrupt record field 59730 60100 484 0.0 59730.2 1.6X +Select 1000 columns 115001 115810 1382 0.0 115001.2 1.0X +Select 100 columns 45575 45646 84 0.0 45575.5 2.5X +Select one column 38701 38744 67 0.0 38700.7 3.0X +count() 8544 8556 12 0.1 8544.0 13.5X +Select 100 columns, one bad input field 67789 67841 79 0.0 67788.5 1.7X +Select 100 columns, corrupt record field 74026 74050 26 0.0 74026.4 1.6X -OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Count a dataset with 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 10 columns + count() 15305 15627 282 0.7 1530.5 1.0X -Select 1 column + count() 13688 13777 106 0.7 1368.8 1.1X -count() 3189 3214 39 3.1 318.9 4.8X +Select 10 columns + count() 16855 16980 179 0.6 1685.5 1.0X +Select 1 column + count() 11053 11075 29 0.9 1105.3 1.5X +count() 3646 3664 17 2.7 364.6 4.6X -OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Create a dataset of timestamps 1630 1641 9 6.1 163.0 1.0X -to_csv(timestamp) 11606 11665 76 0.9 1160.6 0.1X -write timestamps to files 10636 10742 121 0.9 1063.6 0.2X -Create a dataset of dates 1854 1879 25 5.4 185.4 0.9X -to_csv(date) 7522 7563 37 1.3 752.2 0.2X -write dates to files 6435 6526 85 1.6 643.5 0.3X +Create a dataset of timestamps 1864 1904 35 5.4 186.4 1.0X +to_csv(timestamp) 12050 12258 279 0.8 1205.0 0.2X +write timestamps to files 12564 12586 22 0.8 1256.4 0.1X +Create a dataset of dates 2093 2106 20 4.8 209.3 0.9X +to_csv(date) 7216 7236 33 1.4 721.6 0.3X +write dates to files 7300 7382 71 1.4 730.0 0.3X -OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 11.0.19+7 on 
Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------- -read timestamp text from files 2245 2310 57 4.5 224.5 1.0X -read timestamps from files 27283 27875 513 0.4 2728.3 0.1X -infer timestamps from files 55465 56311 859 0.2 5546.5 0.0X -read date text from files 2054 2088 38 4.9 205.4 1.1X -read date from files 15957 16190 202 0.6 1595.7 0.1X -infer date from files 33163 33319 135 0.3 3316.3 0.1X -timestamp strings 2518 2594 71 4.0 251.8 0.9X -parse timestamps from Dataset[String] 30168 30266 87 0.3 3016.8 0.1X -infer timestamps from Dataset[String] 58608 59332 728 0.2 5860.8 0.0X -date strings 2803 2847 44 3.6 280.3 0.8X -parse dates from Dataset[String] 17613 17877 421 0.6 1761.3 0.1X -from_csv(timestamp) 27736 28241 482 0.4 2773.6 0.1X -from_csv(date) 16415 16816 367 0.6 1641.5 0.1X -infer error timestamps from Dataset[String] with default format 18335 18494 138 0.5 1833.5 0.1X -infer error timestamps from Dataset[String] with user-provided format 18327 18598 422 0.5 1832.7 0.1X -infer error timestamps from Dataset[String] with legacy format 18713 18907 267 0.5 1871.3 0.1X +read timestamp text from files 2432 2458 40 4.1 243.2 1.0X +read timestamps from files 31897 31950 79 0.3 3189.7 0.1X +infer timestamps from files 65093 65196 90 0.2 6509.3 0.0X +read date text from files 2201 2211 15 4.5 220.1 1.1X +read date from files 16138 18869 NaN 0.6 1613.8 0.2X +infer date from files 33633 33742 126 0.3 3363.3 0.1X +timestamp strings 2909 2930 34 3.4 290.9 0.8X +parse timestamps from Dataset[String] 34951 34984 39 0.3 3495.1 0.1X +infer timestamps from Dataset[String] 68347 68448 92 0.1 6834.7 0.0X +date strings 3234 3256 24 3.1 323.4 0.8X +parse dates from Dataset[String] 18591 18657 96 0.5 1859.1 0.1X +from_csv(timestamp) 32386 32476 78 0.3 3238.6 0.1X +from_csv(date) 17333 17402 67 0.6 1733.3 0.1X +infer error timestamps from Dataset[String] with default format 21486 21565 68 0.5 2148.6 0.1X +infer error timestamps from Dataset[String] with user-provided format 21683 21697 16 0.5 2168.3 0.1X +infer error timestamps from Dataset[String] with legacy format 21327 21379 85 0.5 2132.7 0.1X -OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 11.0.19+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -w/o filters 19420 19520 87 0.0 194201.0 1.0X -pushdown disabled 19196 19507 409 0.0 191958.0 1.0X -w/ filters 1380 1402 19 0.1 13796.9 14.1X +w/o filters 22031 22075 46 0.0 220305.7 1.0X +pushdown disabled 21935 21958 21 0.0 219353.1 1.0X +w/ filters 1466 1481 15 0.1 14662.5 15.0X diff --git a/sql/core/benchmarks/CSVBenchmark-jdk17-results.txt b/sql/core/benchmarks/CSVBenchmark-jdk17-results.txt index 9b86f23749645..24c56a42963c3 100644 --- a/sql/core/benchmarks/CSVBenchmark-jdk17-results.txt +++ b/sql/core/benchmarks/CSVBenchmark-jdk17-results.txt @@ -2,69 +2,69 @@ Benchmark to measure CSV read/write performance ================================================================================================ -OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1037-azure +OpenJDK 
64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Parsing quoted values: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -One quoted string 41215 41413 184 0.0 824303.0 1.0X +One quoted string 45085 45217 227 0.0 901702.6 1.0X -OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Wide rows with 1000 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 1000 columns 82745 83284 859 0.0 82744.6 1.0X -Select 100 columns 31408 31505 99 0.0 31407.6 2.6X -Select one column 26527 26578 53 0.0 26526.6 3.1X -count() 5168 5214 40 0.2 5167.9 16.0X -Select 100 columns, one bad input field 50701 50802 120 0.0 50700.8 1.6X -Select 100 columns, corrupt record field 55347 55377 27 0.0 55347.2 1.5X +Select 1000 columns 84298 84785 814 0.0 84297.9 1.0X +Select 100 columns 31424 31438 14 0.0 31424.4 2.7X +Select one column 26201 26308 124 0.0 26200.9 3.2X +count() 5215 5226 11 0.2 5214.8 16.2X +Select 100 columns, one bad input field 47515 47615 98 0.0 47514.7 1.8X +Select 100 columns, corrupt record field 52608 52658 62 0.0 52607.6 1.6X -OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Count a dataset with 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 10 columns + count() 14368 14376 12 0.7 1436.8 1.0X -Select 1 column + count() 8791 8834 46 1.1 879.1 1.6X -count() 2597 2613 13 3.8 259.7 5.5X +Select 10 columns + count() 15507 15522 14 0.6 1550.7 1.0X +Select 1 column + count() 9380 9397 15 1.1 938.0 1.7X +count() 2932 2959 40 3.4 293.2 5.3X -OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Create a dataset of timestamps 1448 1475 30 6.9 144.8 1.0X -to_csv(timestamp) 9021 9033 13 1.1 902.1 0.2X -write timestamps to files 8104 8113 8 1.2 810.4 0.2X -Create a dataset of dates 1510 1527 15 6.6 151.0 1.0X -to_csv(date) 6114 6121 12 1.6 611.4 0.2X -write dates to files 5191 5196 5 1.9 519.1 0.3X +Create a dataset of timestamps 1486 1495 8 6.7 148.6 1.0X +to_csv(timestamp) 8333 8351 21 1.2 833.3 0.2X +write timestamps to files 8628 8633 7 1.2 862.8 0.2X +Create a dataset of dates 1698 1713 14 5.9 169.8 0.9X +to_csv(date) 5566 5579 15 1.8 556.6 0.3X +write dates to files 5561 5585 21 1.8 556.1 0.3X -OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative 
----------------------------------------------------------------------------------------------------------------------------------------------------- -read timestamp text from files 1891 1900 11 5.3 189.1 1.0X -read timestamps from files 25100 25122 27 0.4 2510.0 0.1X -infer timestamps from files 50501 50568 110 0.2 5050.1 0.0X -read date text from files 1813 1816 4 5.5 181.3 1.0X -read date from files 15558 15589 27 0.6 1555.8 0.1X -infer date from files 31269 31335 84 0.3 3126.9 0.1X -timestamp strings 2126 2135 10 4.7 212.6 0.9X -parse timestamps from Dataset[String] 27361 27404 46 0.4 2736.1 0.1X -infer timestamps from Dataset[String] 52775 52897 146 0.2 5277.5 0.0X -date strings 2421 2432 19 4.1 242.1 0.8X -parse dates from Dataset[String] 17745 17810 75 0.6 1774.5 0.1X -from_csv(timestamp) 25839 25938 133 0.4 2583.9 0.1X -from_csv(date) 16625 16690 60 0.6 1662.5 0.1X -infer error timestamps from Dataset[String] with default format 20289 20376 76 0.5 2028.9 0.1X -infer error timestamps from Dataset[String] with user-provided format 20245 20326 108 0.5 2024.5 0.1X -infer error timestamps from Dataset[String] with legacy format 20274 20314 36 0.5 2027.4 0.1X +read timestamp text from files 1910 1911 3 5.2 191.0 1.0X +read timestamps from files 26650 26657 7 0.4 2665.0 0.1X +infer timestamps from files 53172 53219 63 0.2 5317.2 0.0X +read date text from files 1859 1863 4 5.4 185.9 1.0X +read date from files 15246 15259 20 0.7 1524.6 0.1X +infer date from files 31002 31006 5 0.3 3100.2 0.1X +timestamp strings 2252 2257 5 4.4 225.2 0.8X +parse timestamps from Dataset[String] 28833 28871 34 0.3 2883.3 0.1X +infer timestamps from Dataset[String] 55417 55526 116 0.2 5541.7 0.0X +date strings 2561 2568 6 3.9 256.1 0.7X +parse dates from Dataset[String] 17580 17601 19 0.6 1758.0 0.1X +from_csv(timestamp) 26802 27121 280 0.4 2680.2 0.1X +from_csv(date) 16119 16126 6 0.6 1611.9 0.1X +infer error timestamps from Dataset[String] with default format 19595 19846 229 0.5 1959.5 0.1X +infer error timestamps from Dataset[String] with user-provided format 19816 19854 37 0.5 1981.6 0.1X +infer error timestamps from Dataset[String] with legacy format 19810 19849 42 0.5 1981.0 0.1X -OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1037-azure +OpenJDK 64-Bit Server VM 17.0.7+7 on Linux 5.15.0-1040-azure Intel(R) Xeon(R) Platinum 8272CL CPU @ 2.60GHz Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -w/o filters 15487 15499 13 0.0 154874.0 1.0X -pushdown disabled 15405 15411 5 0.0 154051.4 1.0X -w/ filters 1166 1174 7 0.1 11660.4 13.3X +w/o filters 16689 16693 5 0.0 166885.8 1.0X +pushdown disabled 16610 16615 5 0.0 166095.3 1.0X +w/ filters 1094 1096 2 0.1 10936.1 15.3X diff --git a/sql/core/benchmarks/CSVBenchmark-results.txt b/sql/core/benchmarks/CSVBenchmark-results.txt index eb1ec99123d23..ff67054b93d54 100644 --- a/sql/core/benchmarks/CSVBenchmark-results.txt +++ b/sql/core/benchmarks/CSVBenchmark-results.txt @@ -2,69 +2,69 @@ Benchmark to measure CSV read/write performance ================================================================================================ -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_372-b07 on Linux 5.15.0-1040-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Parsing quoted values: Best Time(ms) Avg Time(ms) 
Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -One quoted string 55478 55679 175 0.0 1109556.3 1.0X +One quoted string 43827 44673 740 0.0 876536.0 1.0X -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_372-b07 on Linux 5.15.0-1040-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Wide rows with 1000 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 1000 columns 113407 117690 NaN 0.0 113407.3 1.0X -Select 100 columns 42483 43350 918 0.0 42483.3 2.7X -Select one column 36959 37454 437 0.0 36958.5 3.1X -count() 10248 11871 1413 0.1 10248.2 11.1X -Select 100 columns, one bad input field 61143 61339 276 0.0 61143.4 1.9X -Select 100 columns, corrupt record field 65546 65662 170 0.0 65546.5 1.7X +Select 1000 columns 93035 94150 1041 0.0 93035.3 1.0X +Select 100 columns 34333 34440 185 0.0 34333.3 2.7X +Select one column 28763 28860 116 0.0 28763.1 3.2X +count() 7449 7665 300 0.1 7448.9 12.5X +Select 100 columns, one bad input field 50278 50458 175 0.0 50277.6 1.9X +Select 100 columns, corrupt record field 53481 53833 540 0.0 53480.7 1.7X -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_372-b07 on Linux 5.15.0-1040-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Count a dataset with 10 columns: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Select 10 columns + count() 12993 13063 83 0.8 1299.3 1.0X -Select 1 column + count() 11275 11448 159 0.9 1127.5 1.2X -count() 2804 2870 65 3.6 280.4 4.6X +Select 10 columns + count() 13070 13085 19 0.8 1307.0 1.0X +Select 1 column + count() 11406 11437 35 0.9 1140.6 1.1X +count() 2840 2873 30 3.5 284.0 4.6X -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_372-b07 on Linux 5.15.0-1040-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Write dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -Create a dataset of timestamps 1213 1270 50 8.2 121.3 1.0X -to_csv(timestamp) 9959 9998 45 1.0 995.9 0.1X -write timestamps to files 8851 9069 199 1.1 885.1 0.1X -Create a dataset of dates 1575 1758 283 6.3 157.5 0.8X -to_csv(date) 6708 6761 89 1.5 670.8 0.2X -write dates to files 5294 5330 38 1.9 529.4 0.2X +Create a dataset of timestamps 1150 1169 26 8.7 115.0 1.0X +to_csv(timestamp) 9488 9499 15 1.1 948.8 0.1X +write timestamps to files 9194 9205 13 1.1 919.4 0.1X +Create a dataset of dates 1497 1506 15 6.7 149.7 0.8X +to_csv(date) 6030 6041 18 1.7 603.0 0.2X +write dates to files 5722 5729 7 1.7 572.2 0.2X -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_372-b07 on Linux 5.15.0-1040-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Read dates and timestamps: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per 
Row(ns) Relative ----------------------------------------------------------------------------------------------------------------------------------------------------- -read timestamp text from files 1822 1844 26 5.5 182.2 1.0X -read timestamps from files 26595 26727 194 0.4 2659.5 0.1X -infer timestamps from files 53063 53427 450 0.2 5306.3 0.0X -read date text from files 1621 1656 34 6.2 162.1 1.1X -read date from files 13226 13452 197 0.8 1322.6 0.1X -infer date from files 26920 28034 1013 0.4 2692.0 0.1X -timestamp strings 2663 2721 77 3.8 266.3 0.7X -parse timestamps from Dataset[String] 29204 29608 352 0.3 2920.4 0.1X -infer timestamps from Dataset[String] 57302 57486 198 0.2 5730.2 0.0X -date strings 2835 2890 50 3.5 283.5 0.6X -parse dates from Dataset[String] 15775 15965 184 0.6 1577.5 0.1X -from_csv(timestamp) 27509 27967 418 0.4 2750.9 0.1X -from_csv(date) 14847 15059 325 0.7 1484.7 0.1X -infer error timestamps from Dataset[String] with default format 17424 17695 317 0.6 1742.4 0.1X -infer error timestamps from Dataset[String] with user-provided format 17585 17706 110 0.6 1758.5 0.1X -infer error timestamps from Dataset[String] with legacy format 17775 17855 69 0.6 1777.5 0.1X +read timestamp text from files 1528 1560 28 6.5 152.8 1.0X +read timestamps from files 27594 27600 8 0.4 2759.4 0.1X +infer timestamps from files 54923 54958 49 0.2 5492.3 0.0X +read date text from files 1388 1389 2 7.2 138.8 1.1X +read date from files 13358 13388 43 0.7 1335.8 0.1X +infer date from files 27254 27304 46 0.4 2725.4 0.1X +timestamp strings 2688 2698 11 3.7 268.8 0.6X +parse timestamps from Dataset[String] 30710 30731 21 0.3 3071.0 0.0X +infer timestamps from Dataset[String] 58123 58211 122 0.2 5812.3 0.0X +date strings 2804 2805 1 3.6 280.4 0.5X +parse dates from Dataset[String] 15409 15459 58 0.6 1540.9 0.1X +from_csv(timestamp) 29102 29113 17 0.3 2910.2 0.1X +from_csv(date) 15682 15687 6 0.6 1568.2 0.1X +infer error timestamps from Dataset[String] with default format 17912 17926 12 0.6 1791.2 0.1X +infer error timestamps from Dataset[String] with user-provided format 17892 17911 26 0.6 1789.2 0.1X +infer error timestamps from Dataset[String] with legacy format 17929 17935 10 0.6 1792.9 0.1X -OpenJDK 64-Bit Server VM 1.8.0_362-b09 on Linux 5.15.0-1037-azure -Intel(R) Xeon(R) Platinum 8171M CPU @ 2.60GHz +OpenJDK 64-Bit Server VM 1.8.0_372-b07 on Linux 5.15.0-1040-azure +Intel(R) Xeon(R) Platinum 8370C CPU @ 2.80GHz Filters pushdown: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------------------------------------ -w/o filters 18371 18553 205 0.0 183711.1 1.0X -pushdown disabled 18462 18770 290 0.0 184620.0 1.0X -w/ filters 1836 1871 50 0.1 18357.8 10.0X +w/o filters 17003 17018 14 0.0 170025.5 1.0X +pushdown disabled 17092 17103 10 0.0 170919.6 1.0X +w/ filters 1340 1352 13 0.1 13395.9 12.7X From 8390b03df62e7f808dc214c69e340fc1e70fb517 Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Mon, 3 Jul 2023 16:26:03 +0900 Subject: [PATCH 04/13] [SPARK-44200][SQL] Support TABLE argument parser rule for TableValuedFunction ### What changes were proposed in this pull request? Adds a new SQL syntax for `TableValuedFunction`. The syntax supports passing such relations one of two ways: 1. `SELECT ... FROM tvf_call(TABLE t)` 2. `SELECT ... FROM tvf_call(TABLE ())` In the former case, the relation argument directly refers to the name of a table in the catalog. 
In the latter case, the relation argument comprises a table subquery that may itself refer to one or more tables in its own FROM clause. For example, for the given user defined table values function: ```py udtf(returnType="a: int") class TestUDTF: def eval(self, row: Row): if row[0] > 5: yield row[0], spark.udtf.register("test_udtf", TestUDTF) spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)") ``` , the following SQLs should work: ```py >>> spark.sql("SELECT * FROM test_udtf(TABLE v)").collect() [Row(a=6), Row(a=7)] ``` or ```py >>> spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id + 1 FROM v))").collect() [Row(a=6), Row(a=7), Row(a=8)] ``` ### Why are the changes needed? To support `TABLE` argument parser rule for TableValuedFunction. ### Does this PR introduce _any_ user-facing change? Yes, new syntax for SQL. ### How was this patch tested? Added the related tests. Closes #41750 from ueshin/issues/SPARK-44200/table_argument. Authored-by: Takuya UESHIN Signed-off-by: Hyukjin Kwon --- .../main/resources/error/error-classes.json | 5 + python/pyspark/sql/tests/test_udtf.py | 194 +++++++++++++++++- .../sql/catalyst/parser/SqlBaseParser.g4 | 23 ++- .../sql/catalyst/analysis/Analyzer.scala | 30 ++- ...ctionTableSubqueryArgumentExpression.scala | 65 ++++++ .../sql/catalyst/parser/AstBuilder.scala | 37 +++- .../plans/logical/basicLogicalOperators.scala | 5 + .../sql/catalyst/trees/TreePatterns.scala | 1 + .../sql/errors/QueryCompilationErrors.scala | 7 + .../apache/spark/sql/internal/SQLConf.scala | 10 + .../sql/catalyst/parser/PlanParserSuite.scala | 38 ++++ .../spark/sql/catalyst/plans/PlanTest.scala | 2 + .../sql/errors/QueryParsingErrorsSuite.scala | 1 - 13 files changed, 411 insertions(+), 7 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 027d09eae1045..753701cf581c2 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -2198,6 +2198,11 @@ ], "sqlState" : "42P01" }, + "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS" : { + "message" : [ + "There are too many table arguments for table-valued function. It allows one table argument, but got: . If you want to allow it, please set \"spark.sql.allowMultipleTableArguments.enabled\" to \"true\"" + ] + }, "TASK_WRITE_FAILED" : { "message" : [ "Task failed while writing rows to ." 
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index ccf271ceec24c..43ab07950429d 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -27,7 +27,7 @@ from pyspark.testing.sqlutils import ReusedSQLTestCase -class UDTFTestsMixin(ReusedSQLTestCase): +class UDTFTestsMixin: def test_simple_udtf(self): class TestUDTF: def eval(self): @@ -397,6 +397,198 @@ def test_udtf(a: int): with self.assertRaisesRegex(TypeError, err_msg): udtf(test_udtf, returnType="a: int") + def test_udtf_with_table_argument_query(self): + class TestUDTF: + def eval(self, row: Row): + if row["id"] > 5: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + self.assertEqual( + self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT id FROM range(0, 8)))").collect(), + [Row(a=6), Row(a=7)], + ) + + def test_udtf_with_int_and_table_argument_query(self): + class TestUDTF: + def eval(self, i: int, row: Row): + if row["id"] > i: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + self.assertEqual( + self.spark.sql( + "SELECT * FROM test_udtf(5, TABLE (SELECT id FROM range(0, 8)))" + ).collect(), + [Row(a=6), Row(a=7)], + ) + + def test_udtf_with_table_argument_identifier(self): + class TestUDTF: + def eval(self, row: Row): + if row["id"] > 5: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + + with self.tempView("v"): + self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)") + self.assertEqual( + self.spark.sql("SELECT * FROM test_udtf(TABLE v)").collect(), + [Row(a=6), Row(a=7)], + ) + + def test_udtf_with_int_and_table_argument_identifier(self): + class TestUDTF: + def eval(self, i: int, row: Row): + if row["id"] > i: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + + with self.tempView("v"): + self.spark.sql("CREATE OR REPLACE TEMPORARY VIEW v as SELECT id FROM range(0, 8)") + self.assertEqual( + self.spark.sql("SELECT * FROM test_udtf(5, TABLE v)").collect(), + [Row(a=6), Row(a=7)], + ) + + def test_udtf_with_table_argument_unknown_identifier(self): + class TestUDTF: + def eval(self, row: Row): + if row["id"] > 5: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + + with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"): + self.spark.sql("SELECT * FROM test_udtf(TABLE v)").collect() + + def test_udtf_with_table_argument_malformed_query(self): + class TestUDTF: + def eval(self, row: Row): + if row["id"] > 5: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + + with self.assertRaisesRegex(AnalysisException, "TABLE_OR_VIEW_NOT_FOUND"): + self.spark.sql("SELECT * FROM test_udtf(TABLE (SELECT * FROM v))").collect() + + def test_udtf_with_table_argument_cte_inside(self): + class TestUDTF: + def eval(self, row: Row): + if row["id"] > 5: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + self.assertEqual( + self.spark.sql( + """ + SELECT * FROM test_udtf(TABLE ( + WITH t AS ( + SELECT id FROM range(0, 8) + ) + SELECT * FROM t + )) + """ + ).collect(), + [Row(a=6), Row(a=7)], + ) + + def test_udtf_with_table_argument_cte_outside(self): + class TestUDTF: + def eval(self, 
row: Row): + if row["id"] > 5: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + self.assertEqual( + self.spark.sql( + """ + WITH t AS ( + SELECT id FROM range(0, 8) + ) + SELECT * FROM test_udtf(TABLE (SELECT id FROM t)) + """ + ).collect(), + [Row(a=6), Row(a=7)], + ) + + self.assertEqual( + self.spark.sql( + """ + WITH t AS ( + SELECT id FROM range(0, 8) + ) + SELECT * FROM test_udtf(TABLE t) + """ + ).collect(), + [Row(a=6), Row(a=7)], + ) + + # TODO(SPARK-44233): Fix the subquery resolution. + @unittest.skip("Fails to resolve the subquery.") + def test_udtf_with_table_argument_lateral_join(self): + class TestUDTF: + def eval(self, row: Row): + if row["id"] > 5: + yield row["id"], + + func = udtf(TestUDTF, returnType="a: int") + self.spark.udtf.register("test_udtf", func) + self.assertEqual( + self.spark.sql( + """ + SELECT * FROM + range(0, 8) AS t, + LATERAL test_udtf(TABLE t) + """ + ).collect(), + [Row(a=6), Row(a=7)], + ) + + def test_udtf_with_table_argument_multiple(self): + class TestUDTF: + def eval(self, a: Row, b: Row): + yield a[0], b[0] + + func = udtf(TestUDTF, returnType="a: int, b: int") + self.spark.udtf.register("test_udtf", func) + + query = """ + SELECT * FROM test_udtf( + TABLE (SELECT id FROM range(0, 2)), + TABLE (SELECT id FROM range(0, 3))) + """ + + with self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": False}): + with self.assertRaisesRegex( + AnalysisException, "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS" + ): + self.spark.sql(query).collect() + + with self.sql_conf({"spark.sql.tvf.allowMultipleTableArguments.enabled": True}): + self.assertEqual( + self.spark.sql(query).collect(), + [ + Row(a=0, b=0), + Row(a=1, b=0), + Row(a=0, b=1), + Row(a=1, b=1), + Row(a=0, b=2), + Row(a=1, b=2), + ], + ) + class UDTFTests(UDTFTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 index ab6c0d0861f89..0390785ab5d82 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 @@ -788,8 +788,29 @@ inlineTable : VALUES expression (COMMA expression)* tableAlias ; +functionTableSubqueryArgument + : TABLE identifierReference + | TABLE LEFT_PAREN query RIGHT_PAREN + ; + +functionTableNamedArgumentExpression + : key=identifier FAT_ARROW table=functionTableSubqueryArgument + ; + +functionTableReferenceArgument + : functionTableSubqueryArgument + | functionTableNamedArgumentExpression + ; + +functionTableArgument + : functionArgument + | functionTableReferenceArgument + ; + functionTable - : funcName=functionName LEFT_PAREN (functionArgument (COMMA functionArgument)*)? RIGHT_PAREN tableAlias + : funcName=functionName LEFT_PAREN + (functionTableArgument (COMMA functionTableArgument)*)? 
+ RIGHT_PAREN tableAlias ; tableAlias diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 47c266e7d18af..94d341ed1d71a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -2058,7 +2058,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => withPosition(u) { try { - resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { + val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name) if (CatalogV2Util.isSessionCatalog(catalog)) { v1SessionCatalog.resolvePersistentTableFunction( @@ -2068,6 +2068,30 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor catalog, "table-valued functions") } } + + val tableArgs = mutable.ArrayBuffer.empty[LogicalPlan] + val tvf = resolvedFunc.transformAllExpressionsWithPruning( + _.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION), ruleId) { + case t: FunctionTableSubqueryArgumentExpression => + val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") + tableArgs.append(SubqueryAlias(alias, t.evaluable)) + UnresolvedAttribute(Seq(alias, "c")) + } + if (tableArgs.nonEmpty) { + if (!conf.tvfAllowMultipleTableArguments && tableArgs.size > 1) { + throw QueryCompilationErrors.tableValuedFunctionTooManyTableArgumentsError( + tableArgs.size) + } + val alias = SubqueryAlias.generateSubqueryName(s"_${tableArgs.size}") + Project( + Seq(UnresolvedStar(Some(Seq(alias)))), + LateralJoin( + tableArgs.reduceLeft(Join(_, _, Inner, None, JoinHint.NONE)), + LateralSubquery(SubqueryAlias(alias, tvf)), Inner, None) + ) + } else { + tvf + } } catch { case _: NoSuchFunctionException => u.failAnalysis( @@ -2416,6 +2440,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor InSubquery(values, expr.asInstanceOf[ListQuery]) case s @ LateralSubquery(sub, _, exprId, _, _) if !sub.resolved => resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId)) + case a @ FunctionTableSubqueryArgumentExpression(sub, _, exprId) if !sub.resolved => + resolveSubQuery(a, outer)(FunctionTableSubqueryArgumentExpression(_, _, exprId)) } } @@ -2436,6 +2462,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor resolveSubQueries(r, r) case j: Join if j.childrenResolved && j.duplicateResolved => resolveSubQueries(j, j) + case tvf: UnresolvedTableValuedFunction => + resolveSubQueries(tvf, tvf) case s: SupportsSubquery if s.childrenResolved => resolveSubQueries(s, s) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala new file mode 100644 index 0000000000000..6d50273125162 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/FunctionTableSubqueryArgumentExpression.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.trees.TreePattern.{FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION, TreePattern} +import org.apache.spark.sql.types.DataType + +/** + * This is the parsed representation of a relation argument for a TableValuedFunction call. + * The syntax supports passing such relations one of two ways: + * + * 1. SELECT ... FROM tvf_call(TABLE t) + * 2. SELECT ... FROM tvf_call(TABLE ()) + * + * In the former case, the relation argument directly refers to the name of a + * table in the catalog. In the latter case, the relation argument comprises + * a table subquery that may itself refer to one or more tables in its own + * FROM clause. + */ +case class FunctionTableSubqueryArgumentExpression( + plan: LogicalPlan, + outerAttrs: Seq[Expression] = Seq.empty, + exprId: ExprId = NamedExpression.newExprId) + extends SubqueryExpression(plan, outerAttrs, exprId, Seq.empty, None) with Unevaluable { + + override def dataType: DataType = plan.schema + override def nullable: Boolean = false + override def withNewPlan(plan: LogicalPlan): FunctionTableSubqueryArgumentExpression = + copy(plan = plan) + override def hint: Option[HintInfo] = None + override def withNewHint(hint: Option[HintInfo]): FunctionTableSubqueryArgumentExpression = + copy() + override def toString: String = s"table-argument#${exprId.id} $conditionString" + override lazy val canonicalized: Expression = { + FunctionTableSubqueryArgumentExpression( + plan.canonicalized, + outerAttrs.map(_.canonicalized), + ExprId(0)) + } + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): FunctionTableSubqueryArgumentExpression = + copy(outerAttrs = newChildren) + + final override def nodePatternsInternal: Seq[TreePattern] = + Seq(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION) + + lazy val evaluable: LogicalPlan = Project(Seq(Alias(CreateStruct(plan.output), "c")()), plan) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 9a395924c451c..488b4e467351f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1551,6 +1551,33 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit RelationTimeTravel(plan, timestamp, version) } + /** + * Create a relation argument for a table-valued function argument. 
+ */ + override def visitFunctionTableSubqueryArgument( + ctx: FunctionTableSubqueryArgumentContext): Expression = withOrigin(ctx) { + val p = Option(ctx.identifierReference).map { r => + createUnresolvedRelation(r) + }.getOrElse { + plan(ctx.query) + } + FunctionTableSubqueryArgumentExpression(p) + } + + private def extractFunctionTableNamedArgument( + expr: FunctionTableReferenceArgumentContext, funcName: String) : Expression = { + Option(expr.functionTableNamedArgumentExpression).map { n => + if (conf.getConf(SQLConf.ALLOW_NAMED_FUNCTION_ARGUMENTS)) { + NamedArgumentExpression( + n.key.getText, visitFunctionTableSubqueryArgument(n.functionTableSubqueryArgument)) + } else { + throw QueryCompilationErrors.namedArgumentsNotEnabledError(funcName, n.key.getText) + } + }.getOrElse { + visitFunctionTableSubqueryArgument(expr.functionTableSubqueryArgument) + } + } + /** * Create a table-valued function call with arguments, e.g. range(1000) */ @@ -1569,8 +1596,12 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit if (ident.length > 1) { throw QueryParsingErrors.invalidTableValuedFunctionNameError(ident, ctx) } - val args = func.functionArgument.asScala.map { e => - extractNamedArgument(e, func.functionName.getText) + val funcName = func.functionName.getText + val args = func.functionTableArgument.asScala.map { e => + Option(e.functionArgument).map(extractNamedArgument(_, funcName)) + .getOrElse { + extractFunctionTableNamedArgument(e.functionTableReferenceArgument, funcName) + } }.toSeq val tvf = UnresolvedTableValuedFunction(ident, args) @@ -1634,7 +1665,7 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit // normal subquery names, so that parent operators can only access the columns in subquery by // unqualified names. Users can still use this special qualifier to access columns if they // know it, but that's not recommended. - SubqueryAlias("__auto_generated_subquery_name", relation) + SubqueryAlias(SubqueryAlias.generateSubqueryName(), relation) } else { mayApplyAliasPlan(ctx.tableAlias, relation) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index e23966775e9f6..c5ac030484134 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1707,7 +1707,12 @@ object SubqueryAlias { child: LogicalPlan): SubqueryAlias = { SubqueryAlias(AliasIdentifier(multipartIdentifier.last, multipartIdentifier.init), child) } + + def generateSubqueryName(suffix: String = ""): String = { + s"__auto_generated_subquery_name$suffix" + } } + /** * Sample the dataset. 
* diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala index 11d5cf54df4b4..b806ebbed52d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala @@ -46,6 +46,7 @@ object TreePattern extends Enumeration { val EXISTS_SUBQUERY = Value val EXPRESSION_WITH_RANDOM_SEED: Value = Value val EXTRACT_VALUE: Value = Value + val FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION: Value = Value val GENERATE: Value = Value val GENERATOR: Value = Value val HIGH_ORDER_FUNCTION: Value = Value diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index e02708105d2fe..48223cb34e1a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -1907,6 +1907,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase { "ability" -> ability)) } + def tableValuedFunctionTooManyTableArgumentsError(num: Int): Throwable = { + new AnalysisException( + errorClass = "TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS", + messageParameters = Map("num" -> num.toString) + ) + } + def identifierTooManyNamePartsError(originalIdentifier: String): Throwable = { new AnalysisException( errorClass = "IDENTIFIER_TOO_MANY_NAME_PARTS", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 270508139e49b..ecff6bef8aecc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2753,6 +2753,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val TVF_ALLOW_MULTIPLE_TABLE_ARGUMENTS_ENABLED = + buildConf("spark.sql.tvf.allowMultipleTableArguments.enabled") + .doc("When true, allows multiple table arguments for table-valued functions, " + + "receiving the cartesian product of all the rows of these tables.") + .version("3.5.0") + .booleanConf + .createWithDefault(false) + val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION = buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition") .internal() @@ -4926,6 +4934,8 @@ class SQLConf extends Serializable with Logging { def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME) + def tvfAllowMultipleTableArguments: Boolean = getConf(TVF_ALLOW_MULTIPLE_TABLE_ARGUMENTS_ENABLED) + def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 228a287e14f49..4bad3ced70586 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -1441,6 +1441,44 @@ class PlanParserSuite extends AnalysisTest { NamedArgumentExpression("group", Literal("abc")) :: Nil).select(star())) } + test("table valued function 
with table arguments") { + assertEqual( + "select * from my_tvf(table v1, table (select 1))", + UnresolvedTableValuedFunction("my_tvf", + FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1"))) :: + FunctionTableSubqueryArgumentExpression( + Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation())) :: Nil).select(star())) + + // All named arguments + assertEqual( + "select * from my_tvf(arg1 => table v1, arg2 => table (select 1))", + UnresolvedTableValuedFunction("my_tvf", + NamedArgumentExpression("arg1", + FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1")))) :: + NamedArgumentExpression("arg2", + FunctionTableSubqueryArgumentExpression( + Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation()))) :: Nil).select(star())) + + // Unnamed and named arguments + assertEqual( + "select * from my_tvf(2, table v1, arg1 => table (select 1))", + UnresolvedTableValuedFunction("my_tvf", + Literal(2) :: + FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1"))) :: + NamedArgumentExpression("arg1", + FunctionTableSubqueryArgumentExpression( + Project(Seq(UnresolvedAlias(Literal(1))), OneRowRelation()))) :: Nil).select(star())) + + // Mixed arguments + assertEqual( + "select * from my_tvf(arg1 => table v1, 2, arg2 => true)", + UnresolvedTableValuedFunction("my_tvf", + NamedArgumentExpression("arg1", + FunctionTableSubqueryArgumentExpression(UnresolvedRelation(Seq("v1")))) :: + Literal(2) :: + NamedArgumentExpression("arg2", Literal(true)) :: Nil).select(star())) + } + test("SPARK-32106: TRANSFORM plan") { // verify schema less assertEqual( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 911ddfeb13b4b..ebf48c5f863d2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -96,6 +96,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s udf.copy(resultId = ExprId(0)) case udaf: PythonUDAF => udaf.copy(resultId = ExprId(0)) + case a: FunctionTableSubqueryArgumentExpression => + a.copy(plan = normalizeExprIds(a.plan), exprId = ExprId(0)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala index a7d5046245df9..2731760f7ef05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/errors/QueryParsingErrorsSuite.scala @@ -401,7 +401,6 @@ class QueryParsingErrorsSuite extends QueryTest with SharedSparkSession with SQL checkParseSyntaxError("select * from my_tvf(arg1 => )", "')'") checkParseSyntaxError("select * from my_tvf(arg1 => , 42)", "','") checkParseSyntaxError("select * from my_tvf(my_tvf.arg1 => 'value1')", "'=>'") - checkParseSyntaxError("select * from my_tvf(arg1 => table t1)", "'t1'", hint = ": extra input 't1'") } test("PARSE_SYNTAX_ERROR: extraneous input") { From f7e8da3fb6dba44ab339a095c46cc8872d83b741 Mon Sep 17 00:00:00 2001 From: Jia Fan Date: Mon, 3 Jul 2023 19:40:57 +0900 Subject: [PATCH 05/13] [SPARK-44200][SQL][FOLLOWUP] Add `TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS` error into doc ### What changes were proposed in this pull request? 
This is a follow-up PR for https://github.com/apache/spark/pull/41750. After #41813 added a test that keeps the documentation in sync with `error-classes.json`, the `TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS` error class (added in https://github.com/apache/spark/pull/41750) also needs to be added to the documentation.

### Why are the changes needed?
To keep the error classes and the documentation in sync.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing tests.

Closes #41827 from Hisoka-X/SPARK-44200_follow_up_error_json_doc.

Authored-by: Jia Fan
Signed-off-by: Hyukjin Kwon
---
 docs/sql-error-conditions.md | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 47836e9cc05f7..35af4db69aba8 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -1568,6 +1568,12 @@ If you did not qualify the name with a schema, verify the current_schema() outpu
 To tolerate the error on drop use DROP VIEW IF EXISTS or DROP TABLE IF EXISTS.

+### TABLE_VALUED_FUNCTION_TOO_MANY_TABLE_ARGUMENTS
+
+SQLSTATE: none assigned
+
+There are too many table arguments for table-valued function. It allows one table argument, but got: `<num>`. If you want to allow it, please set "spark.sql.allowMultipleTableArguments.enabled" to "true"
+
 ### TASK_WRITE_FAILED

 SQLSTATE: none assigned

From d9ce141c50183202dc34a3cd3f5a67060ed1a596 Mon Sep 17 00:00:00 2001
From: Vihang Karajgaonkar
Date: Mon, 3 Jul 2023 16:06:43 -0700
Subject: [PATCH 06/13] [SPARK-44199] CacheManager refreshes the fileIndex unnecessarily

### What changes were proposed in this pull request?
The `CacheManager.refreshFileIndexIfNecessary` logic checks whether any of the fileIndex root paths starts with the input path. This is problematic if the input path and the root path share a prefix but the root path is not a subdirectory of the input path. In such cases, the CacheManager can unnecessarily refresh the fileIndex, which can fail the query if it does not have access to that rootPath for the SparkSession.

### Why are the changes needed?
Fixes a bug where queries on cached DataFrame APIs can fail if the cached path shares a prefix with a different path.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit test

Closes #41749 from vihangk1/master_cachemanager.

Lead-authored-by: Vihang Karajgaonkar
Co-authored-by: Vihang Karajgaonkar
Signed-off-by: Wenchen Fan
---
 .../apache/spark/sql/catalog/Catalog.scala    |  5 ++-
 .../spark/sql/execution/CacheManager.scala    | 24 +++++++++--
 .../apache/spark/sql/CacheManagerSuite.scala  | 40 +++++++++++++++++++
 3 files changed, 63 insertions(+), 6 deletions(-)
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CacheManagerSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 93ff3059f6264..13b199948e0f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -632,8 +632,9 @@ abstract class Catalog {
   /**
    * Invalidates and refreshes all the cached data (and the associated metadata) for any `Dataset`
-   * that contains the given data source path. Path matching is by prefix, i.e. "/" would invalidate
-   * everything that is cached.
+   * that contains the given data source path. Path matching is by checking for sub-directories,
+   * i.e.
"/" would invalidate everything that is cached and "/test/parent" would invalidate + * everything that is a subdirectory of "/test/parent". * * @since 2.0.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index b1153d7a1e86c..2afb82cdbc78d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -359,21 +359,37 @@ class CacheManager extends Logging with AdaptiveSparkPlanHelper { } /** - * Refresh the given [[FileIndex]] if any of its root paths starts with `qualifiedPath`. + * Refresh the given [[FileIndex]] if any of its root paths is a subdirectory + * of the `qualifiedPath`. * @return whether the [[FileIndex]] is refreshed. */ private def refreshFileIndexIfNecessary( fileIndex: FileIndex, fs: FileSystem, qualifiedPath: Path): Boolean = { - val prefixToInvalidate = qualifiedPath.toString val needToRefresh = fileIndex.rootPaths - .map(_.makeQualified(fs.getUri, fs.getWorkingDirectory).toString) - .exists(_.startsWith(prefixToInvalidate)) + .map(_.makeQualified(fs.getUri, fs.getWorkingDirectory)) + .exists(isSubDir(qualifiedPath, _)) if (needToRefresh) fileIndex.refresh() needToRefresh } + /** + * Checks if the given child path is a sub-directory of the given parent path. + * @param qualifiedPathChild: + * Fully qualified child path + * @param qualifiedPathParent: + * Fully qualified parent path. + * @return + * True if the child path is a sub-directory of the given parent path. Otherwise, false. + */ + def isSubDir(qualifiedPathParent: Path, qualifiedPathChild: Path): Boolean = { + Iterator + .iterate(qualifiedPathChild)(_.getParent) + .takeWhile(_ != null) + .exists(_.equals(qualifiedPathParent)) + } + /** * If CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING is enabled, just return original session. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CacheManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CacheManagerSuite.scala new file mode 100644 index 0000000000000..fb8e82dbf90d6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CacheManagerSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.test.SharedSparkSession + +class CacheManagerSuite extends SparkFunSuite with SharedSparkSession { + + test("SPARK-44199: isSubDirectory tests") { + val cacheManager = spark.sharedState.cacheManager + val testCases = Map[(String, String), Boolean]( + ("s3://bucket/a/b", "s3://bucket/a/b/c") -> true, + ("s3://bucket/a/b/c", "s3://bucket/a/b/c") -> true, + ("s3://bucket/a/b/c", "s3://bucket/a/b") -> false, + ("s3://bucket/a/z/c", "s3://bucket/a/b/c") -> false, + ("s3://bucket/a/b/c", "abfs://bucket/a/b/c") -> false) + testCases.foreach { test => + val result = cacheManager.isSubDir(new Path(test._1._1), new Path(test._1._2)) + assert(result == test._2) + } + } +} From 9cb557d12ab37b3f7580e526d89eef8f2ef94bc6 Mon Sep 17 00:00:00 2001 From: Xinrong Meng Date: Tue, 4 Jul 2023 08:14:04 +0900 Subject: [PATCH 07/13] [SPARK-44150][PYTHON][CONNECT] Explicit Arrow casting for mismatched return type in Arrow Python UDF ### What changes were proposed in this pull request? Explicit Arrow casting for the mismatched return type of Arrow Python UDF. ### Why are the changes needed? A more standardized and coherent type coercion. Please refer to https://github.com/apache/spark/pull/41706 for a comprehensive comparison between type coercion rules of Arrow and Pickle(used by the default Python UDF) separately. See more at [[Design] Type-coercion in Arrow Python UDFs](https://docs.google.com/document/d/e/2PACX-1vTEGElOZfhl9NfgbBw4CTrlm-8F_xQCAKNOXouz-7mg5vYobS7lCGUsGkDZxPY0wV5YkgoZmkYlxccU/pub). ### Does this PR introduce _any_ user-facing change? Yes. FROM ```py >>> df = spark.createDataFrame(['1', '2'], schema='string') df.select(pandas_udf(lambda x: x, 'int')('value')).show() >>> df.select(pandas_udf(lambda x: x, 'int')('value')).show() ... org.apache.spark.api.python.PythonException: Traceback (most recent call last): ... pyarrow.lib.ArrowInvalid: Could not convert '1' with type str: tried to convert to int32 ``` TO ```py >>> df = spark.createDataFrame(['1', '2'], schema='string') >>> df.select(pandas_udf(lambda x: x, 'int')('value')).show() +---------------+ |(value)| +---------------+ | 1| | 2| +---------------+ ``` ### How was this patch tested? Unit tests. Closes #41800 from xinrong-meng/snd_type_coersion. Authored-by: Xinrong Meng Signed-off-by: Hyukjin Kwon --- python/pyspark/sql/pandas/serializers.py | 35 ++++++++++++++--- .../sql/tests/test_arrow_python_udf.py | 39 +++++++++++++++++++ python/pyspark/worker.py | 3 ++ 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/pandas/serializers.py b/python/pyspark/sql/pandas/serializers.py index 307fcc33752b0..4c095249957c0 100644 --- a/python/pyspark/sql/pandas/serializers.py +++ b/python/pyspark/sql/pandas/serializers.py @@ -190,7 +190,7 @@ def arrow_to_pandas(self, arrow_column, struct_in_pandas="dict", ndarray_as_list ) return converter(s) - def _create_array(self, series, arrow_type, spark_type=None): + def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): """ Create an Arrow Array from the given pandas.Series and optional type. 
@@ -202,6 +202,9 @@ def _create_array(self, series, arrow_type, spark_type=None): If None, pyarrow's inferred type will be used spark_type : DataType, optional If None, spark type converted from arrow_type will be used + arrow_cast: bool, optional + Whether to apply Arrow casting when the user-specified return type mismatches the + actual return values. Returns ------- @@ -226,7 +229,17 @@ def _create_array(self, series, arrow_type, spark_type=None): else: mask = series.isnull() try: - return pa.Array.from_pandas(series, mask=mask, type=arrow_type, safe=self._safecheck) + try: + return pa.Array.from_pandas( + series, mask=mask, type=arrow_type, safe=self._safecheck + ) + except pa.lib.ArrowInvalid: + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=self._safecheck + ) + else: + raise except TypeError as e: error_msg = ( "Exception thrown when converting pandas.Series (%s) " @@ -319,12 +332,14 @@ def __init__( df_for_struct=False, struct_in_pandas="dict", ndarray_as_list=False, + arrow_cast=False, ): super(ArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck) self._assign_cols_by_name = assign_cols_by_name self._df_for_struct = df_for_struct self._struct_in_pandas = struct_in_pandas self._ndarray_as_list = ndarray_as_list + self._arrow_cast = arrow_cast def arrow_to_pandas(self, arrow_column): import pyarrow.types as types @@ -386,7 +401,13 @@ def _create_batch(self, series): # Assign result columns by schema name if user labeled with strings elif self._assign_cols_by_name and any(isinstance(name, str) for name in s.columns): arrs_names = [ - (self._create_array(s[field.name], field.type), field.name) for field in t + ( + self._create_array( + s[field.name], field.type, arrow_cast=self._arrow_cast + ), + field.name, + ) + for field in t ] # Assign result columns by position else: @@ -394,7 +415,11 @@ def _create_batch(self, series): # the selected series has name '1', so we rename it to field.name # as the name is used by _create_array to provide a meaningful error message ( - self._create_array(s[s.columns[i]].rename(field.name), field.type), + self._create_array( + s[s.columns[i]].rename(field.name), + field.type, + arrow_cast=self._arrow_cast, + ), field.name, ) for i, field in enumerate(t) @@ -403,7 +428,7 @@ def _create_batch(self, series): struct_arrs, struct_names = zip(*arrs_names) arrs.append(pa.StructArray.from_arrays(struct_arrs, struct_names)) else: - arrs.append(self._create_array(s, t)) + arrs.append(self._create_array(s, t, arrow_cast=self._arrow_cast)) return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) diff --git a/python/pyspark/sql/tests/test_arrow_python_udf.py b/python/pyspark/sql/tests/test_arrow_python_udf.py index 0accb0f3cc110..264ea0b901f43 100644 --- a/python/pyspark/sql/tests/test_arrow_python_udf.py +++ b/python/pyspark/sql/tests/test_arrow_python_udf.py @@ -17,6 +17,8 @@ import unittest +from pyspark.errors import PythonException +from pyspark.sql import Row from pyspark.sql.functions import udf from pyspark.sql.tests.test_udf import BaseUDFTestsMixin from pyspark.testing.sqlutils import ( @@ -141,6 +143,43 @@ def test_nested_array_input(self): "[[1, 2], [3, 4]]", ) + def test_type_coercion_string_to_numeric(self): + df_int_value = self.spark.createDataFrame(["1", "2"], schema="string") + df_floating_value = self.spark.createDataFrame(["1.1", "2.2"], schema="string") + + int_ddl_types = ["tinyint", "smallint", "int", "bigint"] + floating_ddl_types = ["double", "float"] 
+ + for ddl_type in int_ddl_types: + # df_int_value + res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEquals(res.collect(), [Row(res=1), Row(res=2)]) + self.assertEquals(res.dtypes[0][1], ddl_type) + + floating_results = [ + [Row(res=1.1), Row(res=2.2)], + [Row(res=1.100000023841858), Row(res=2.200000047683716)], + ] + for ddl_type, floating_res in zip(floating_ddl_types, floating_results): + # df_int_value + res = df_int_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEquals(res.collect(), [Row(res=1.0), Row(res=2.0)]) + self.assertEquals(res.dtypes[0][1], ddl_type) + # df_floating_value + res = df_floating_value.select(udf(lambda x: x, ddl_type)("value").alias("res")) + self.assertEquals(res.collect(), floating_res) + self.assertEquals(res.dtypes[0][1], ddl_type) + + # invalid + with self.assertRaises(PythonException): + df_floating_value.select(udf(lambda x: x, "int")("value").alias("res")).collect() + + with self.assertRaises(PythonException): + df_int_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect() + + with self.assertRaises(PythonException): + df_floating_value.select(udf(lambda x: x, "decimal")("value").alias("res")).collect() + class PythonUDFArrowTests(PythonUDFArrowTestsMixin, ReusedSQLTestCase): @classmethod diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index b24600b0c1b0f..1d28e6add2eb7 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -589,6 +589,8 @@ def read_udfs(pickleSer, infile, eval_type): "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict" ) ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion + arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF ser = ArrowStreamPandasUDFSerializer( timezone, safecheck, @@ -596,6 +598,7 @@ def read_udfs(pickleSer, infile, eval_type): df_for_struct, struct_in_pandas, ndarray_as_list, + arrow_cast, ) else: ser = BatchedSerializer(CPickleSerializer(), 100) From 4cdb6d487ed18891bb7f63f9fb20f33cbbcc26c2 Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Mon, 3 Jul 2023 20:04:17 -0700 Subject: [PATCH 08/13] [SPARK-44266][SQL] Move Util.truncatedString to sql/api ### What changes were proposed in this pull request? Move Util.truncatedString to sql/api. ### Why are the changes needed? Make StructType depends less on Catalyst so towards simpler DataType interface. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test Closes #41811 from amaliujia/move_out_truncatedString. 
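The helper being relocated is small; the sketch below restates its truncation semantics as a standalone snippet (logging omitted, and the object and field names here are only for the sketch, not from the patch):

```scala
// Standalone restatement of the truncatedString logic moved into SparkStringUtils.
object TruncatedStringSketch {
  def truncatedString[T](
      seq: Seq[T], start: String, sep: String, end: String, maxFields: Int): String = {
    if (seq.length > maxFields) {
      // Keep maxFields - 1 elements and summarise the rest.
      val numFields = math.max(0, maxFields - 1)
      seq.take(numFields).mkString(
        start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end)
    } else {
      seq.mkString(start, sep, end)
    }
  }

  def main(args: Array[String]): Unit = {
    // Prints: struct<a:int,b:int,... 2 more fields>
    println(truncatedString(Seq("a:int", "b:int", "c:int", "d:int"), "struct<", ",", ">", 3))
    // Prints: struct<a:int,b:int>
    println(truncatedString(Seq("a:int", "b:int"), "struct<", ",", ">", 3))
  }
}
```

With `maxFields` driven by `spark.sql.debug.maxToStringFields`, a wide schema renders as a bounded prefix plus a "... N more fields" summary, which is what `StructType.simpleString` relies on after the move.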
Authored-by: Rui Wang Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/util/StringUtils.scala | 39 +++++++++++++++++++ .../spark/sql/catalyst/util/package.scala | 20 +--------- .../apache/spark/sql/types/StructType.scala | 5 +-- 3 files changed, 43 insertions(+), 21 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 10ac988da2efe..384453e3b5379 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -16,8 +16,11 @@ */ package org.apache.spark.sql.catalyst.util +import java.util.concurrent.atomic.AtomicBoolean + import scala.collection.mutable.ArrayBuffer +import org.apache.spark.internal.Logging import org.apache.spark.unsafe.array.ByteArrayUtils /** @@ -63,3 +66,39 @@ class StringConcat(val maxLength: Int = ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH) result.toString } } + +object SparkStringUtils extends Logging { + /** Whether we have warned about plan string truncation yet. */ + private val truncationWarningPrinted = new AtomicBoolean(false) + + /** + * Format a sequence with semantics similar to calling .mkString(). Any elements beyond + * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. + * + * @return the trimmed and formatted string. + */ + def truncatedString[T]( + seq: Seq[T], + start: String, + sep: String, + end: String, + maxFields: Int): String = { + if (seq.length > maxFields) { + if (truncationWarningPrinted.compareAndSet(false, true)) { + logWarning( + "Truncated the string representation of a plan since it was too large. This " + + s"behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.") + } + val numFields = math.max(0, maxFields - 1) + seq.take(numFields).mkString( + start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) + } else { + seq.mkString(start, sep, end) + } + } + + /** Shorthand for calling truncatedString() without start or end strings. */ + def truncatedString[T](seq: Seq[T], sep: String, maxFields: Int): String = { + truncatedString(seq, "", sep, "", maxFields) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index c7c226f01dbc3..0555d8d5fa451 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -20,13 +20,11 @@ package org.apache.spark.sql.catalyst import java.io._ import java.nio.charset.Charset import java.nio.charset.StandardCharsets.UTF_8 -import java.util.concurrent.atomic.AtomicBoolean import com.google.common.io.ByteStreams import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{MetadataBuilder, NumericType, StringType, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -155,9 +153,6 @@ package object util extends Logging { builder.toString() } - /** Whether we have warned about plan string truncation yet. */ - private val truncationWarningPrinted = new AtomicBoolean(false) - /** * Format a sequence with semantics similar to calling .mkString(). 
Any elements beyond * maxNumToStringFields will be dropped and replaced by a "... N more fields" placeholder. @@ -170,23 +165,12 @@ package object util extends Logging { sep: String, end: String, maxFields: Int): String = { - if (seq.length > maxFields) { - if (truncationWarningPrinted.compareAndSet(false, true)) { - logWarning( - "Truncated the string representation of a plan since it was too large. This " + - s"behavior can be adjusted by setting '${SQLConf.MAX_TO_STRING_FIELDS.key}'.") - } - val numFields = math.max(0, maxFields - 1) - seq.take(numFields).mkString( - start, sep, sep + "... " + (seq.length - numFields) + " more fields" + end) - } else { - seq.mkString(start, sep, end) - } + SparkStringUtils.truncatedString(seq, start, sep, end, maxFields) } /** Shorthand for calling truncatedString() without start or end strings. */ def truncatedString[T](seq: Seq[T], sep: String, maxFields: Int): String = { - truncatedString(seq, "", sep, "", maxFields) + SparkStringUtils.truncatedString(seq, "", sep, "", maxFields) } val METADATA_COL_ATTR_KEY = "__metadata_col" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index dad8252e5ca5b..5857aaa95305b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -29,9 +29,8 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser} import org.apache.spark.sql.catalyst.trees.Origin import org.apache.spark.sql.catalyst.types.DataTypeUtils +import org.apache.spark.sql.catalyst.util.{SparkStringUtils, StringConcat} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.catalyst.util.StringConcat -import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.collection.Utils @@ -423,7 +422,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru override def simpleString: String = { val fieldTypes = fields.view.map(field => s"${field.name}:${field.dataType.simpleString}").toSeq - truncatedString( + SparkStringUtils.truncatedString( fieldTypes, "struct<", ",", ">", SQLConf.get.maxToStringFields) From 4c2ee76d2afa63a7a7c0334fd8f4763d3d87ddbb Mon Sep 17 00:00:00 2001 From: Rui Wang Date: Mon, 3 Jul 2023 20:05:00 -0700 Subject: [PATCH 09/13] [SPARK-44274][CONNECT] Move out util functions used by ArtifactManager to common/utils ### What changes were proposed in this pull request? Move out util functions used by ArtifactManager to `common/utils`. More specific, move `resolveURI` and `awaitResult` to `common/utils`. ### Why are the changes needed? So that Spark Connect Scala client does not need to depend on Spark. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing test Closes #41825 from amaliujia/SPARK-44273. 
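Since the Scala client's `addArtifact(path: String)` now goes through `SparkFileUtils.resolveURI`, a self-contained sketch of that resolution logic may help when reading the diff below; the object name, sample inputs and printed forms are illustrative assumptions, and the exact `file:` rendering depends on the local filesystem.

```scala
// Sketch only: mirrors the resolveURI logic that moves into SparkFileUtils, so the
// expected behaviour can be eyeballed without a Spark build on the classpath.
import java.io.File
import java.net.{URI, URISyntaxException}

object ResolveUriSketch {
  def resolveURI(path: String): URI = {
    try {
      val uri = new URI(path)
      if (uri.getScheme != null) return uri   // already a full URI: pass through
      if (uri.getFragment != null) {          // keep YARN-style #fragment suffixes
        val abs = new File(uri.getPath).getAbsoluteFile.toURI
        return new URI(abs.getScheme, abs.getHost, abs.getPath, uri.getFragment)
      }
    } catch {
      case _: URISyntaxException => // fall through and treat it as a plain local path
    }
    new File(path).getCanonicalFile.toURI
  }

  def main(args: Array[String]): Unit = {
    println(resolveURI("hdfs:///jars/app.jar")) // unchanged: hdfs:///jars/app.jar
    println(resolveURI("/local/jars/app.jar"))  // file:/local/jars/app.jar
    println(resolveURI("jars/app.jar"))         // file:/<current dir>/jars/app.jar
  }
}
```

The same logic previously lived in `Utils.resolveURI`, which now simply delegates, so behaviour is unchanged for existing callers.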
Authored-by: Rui Wang Signed-off-by: Wenchen Fan --- .../spark/util/SparkFatalException.scala | 0 .../apache/spark/util/SparkFileUtils.scala | 47 +++++++++++++++ .../apache/spark/util/SparkThreadUtils.scala | 60 +++++++++++++++++++ .../sql/connect/client/ArtifactManager.scala | 6 +- .../org/apache/spark/util/ThreadUtils.scala | 15 +---- .../scala/org/apache/spark/util/Utils.scala | 17 +----- 6 files changed, 112 insertions(+), 33 deletions(-) rename {core => common/utils}/src/main/scala/org/apache/spark/util/SparkFatalException.scala (100%) create mode 100644 common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala create mode 100644 common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala diff --git a/core/src/main/scala/org/apache/spark/util/SparkFatalException.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkFatalException.scala similarity index 100% rename from core/src/main/scala/org/apache/spark/util/SparkFatalException.scala rename to common/utils/src/main/scala/org/apache/spark/util/SparkFatalException.scala diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala new file mode 100644 index 0000000000000..63d1ab4799ab2 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkFileUtils.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.io.File +import java.net.{URI, URISyntaxException} + +private[spark] object SparkFileUtils { + /** + * Return a well-formed URI for the file described by a user input string. + * + * If the supplied path does not contain a scheme, or is a relative path, it will be + * converted into an absolute path with a file:// scheme. 
+ */ + def resolveURI(path: String): URI = { + try { + val uri = new URI(path) + if (uri.getScheme() != null) { + return uri + } + // make sure to handle if the path has a fragment (applies to yarn + // distributed cache) + if (uri.getFragment() != null) { + val absoluteURI = new File(uri.getPath()).getAbsoluteFile().toURI() + return new URI(absoluteURI.getScheme(), absoluteURI.getHost(), absoluteURI.getPath(), + uri.getFragment()) + } + } catch { + case e: URISyntaxException => + } + new File(path).getCanonicalFile().toURI() + } +} diff --git a/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala new file mode 100644 index 0000000000000..ec14688a00625 --- /dev/null +++ b/common/utils/src/main/scala/org/apache/spark/util/SparkThreadUtils.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.util + +import java.util.concurrent.TimeoutException + +import scala.concurrent.Awaitable +import scala.concurrent.duration.Duration +import scala.util.control.NonFatal + +import org.apache.spark.SparkException + +private[spark] object SparkThreadUtils { + // scalastyle:off awaitresult + /** + * Preferred alternative to `Await.result()`. + * + * This method wraps and re-throws any exceptions thrown by the underlying `Await` call, ensuring + * that this thread's stack trace appears in logs. + * + * In addition, it calls `Awaitable.result` directly to avoid using `ForkJoinPool`'s + * `BlockingContext`. Codes running in the user's thread may be in a thread of Scala ForkJoinPool. + * As concurrent executions in ForkJoinPool may see some [[ThreadLocal]] value unexpectedly, this + * method basically prevents ForkJoinPool from running other tasks in the current waiting thread. + * In general, we should use this method because many places in Spark use [[ThreadLocal]] and it's + * hard to debug when [[ThreadLocal]]s leak to other tasks. + */ + @throws(classOf[SparkException]) + def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { + try { + // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. + // See SPARK-13747. + val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] + awaitable.result(atMost)(awaitPermission) + } catch { + case e: SparkFatalException => + throw e.throwable + // TimeoutException and RpcAbortException is thrown in the current thread, so not need to warp + // the exception. 
+ case NonFatal(t) + if !t.isInstanceOf[TimeoutException] => + throw new SparkException("Exception thrown in awaitResult: ", t) + } + } + // scalastyle:on awaitresult +} diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala index 6d0d16df946eb..0ed1670f990ea 100644 --- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala +++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala @@ -38,7 +38,7 @@ import org.apache.commons.codec.digest.DigestUtils.sha256Hex import org.apache.spark.connect.proto import org.apache.spark.connect.proto.AddArtifactsResponse import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary -import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.util.{SparkFileUtils, SparkThreadUtils} /** * The Artifact Manager is responsible for handling and transferring artifacts from the local @@ -71,7 +71,7 @@ class ArtifactManager( * Currently only local files with extensions .jar and .class are supported. */ def addArtifact(path: String): Unit = { - addArtifact(Utils.resolveURI(path)) + addArtifact(SparkFileUtils.resolveURI(path)) } private def parseArtifacts(uri: URI): Seq[Artifact] = { @@ -201,7 +201,7 @@ class ArtifactManager( writeBatch() } stream.onCompleted() - ThreadUtils.awaitResult(promise.future, Duration.Inf) + SparkThreadUtils.awaitResult(promise.future, Duration.Inf) // TODO(SPARK-42658): Handle responses containing CRC failures. } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 303493ef91aef..16d7de56c39eb 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -307,20 +307,7 @@ private[spark] object ThreadUtils { */ @throws(classOf[SparkException]) def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { - try { - // `awaitPermission` is not actually used anywhere so it's safe to pass in null here. - // See SPARK-13747. - val awaitPermission = null.asInstanceOf[scala.concurrent.CanAwait] - awaitable.result(atMost)(awaitPermission) - } catch { - case e: SparkFatalException => - throw e.throwable - // TimeoutException and RpcAbortException is thrown in the current thread, so not need to warp - // the exception. - case NonFatal(t) - if !t.isInstanceOf[TimeoutException] => - throw new SparkException("Exception thrown in awaitResult: ", t) - } + SparkThreadUtils.awaitResult(awaitable, atMost) } // scalastyle:on awaitresult diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 60895c791b5a6..b5c0ee1bab8bc 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2085,22 +2085,7 @@ private[spark] object Utils extends Logging with SparkClassUtils { * converted into an absolute path with a file:// scheme. 
*/ def resolveURI(path: String): URI = { - try { - val uri = new URI(path) - if (uri.getScheme() != null) { - return uri - } - // make sure to handle if the path has a fragment (applies to yarn - // distributed cache) - if (uri.getFragment() != null) { - val absoluteURI = new File(uri.getPath()).getAbsoluteFile().toURI() - return new URI(absoluteURI.getScheme(), absoluteURI.getHost(), absoluteURI.getPath(), - uri.getFragment()) - } - } catch { - case e: URISyntaxException => - } - new File(path).getCanonicalFile().toURI() + SparkFileUtils.resolveURI(path) } /** Resolve a comma-separated list of paths. */ From 4fcecfe17f2d54e14ac204bbdd97104828bbf2af Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Tue, 4 Jul 2023 12:47:25 +0900 Subject: [PATCH 10/13] [SPARK-44194][PYTHON][CORE] Add JobTag APIs to PySpark SparkContext ### What changes were proposed in this pull request? This PR proposes to add: - `SparkContext.setInterruptOnCancel(interruptOnCancel: Boolean): Unit` - `SparkContext.addJobTag(tag: String): Unit` - `SparkContext.removeJobTag(tag: String): Unit` - `SparkContext.getJobTags(): Set[String]` - `SparkContext.clearJobTags(): Unit` - `SparkContext.cancelJobsWithTag(tag: String): Unit` into PySpark. See also SPARK-43952. ### Why are the changes needed? For PySpark users, and feature parity. ### Does this PR introduce _any_ user-facing change? Yes, it adds new API in PySpark. ### How was this patch tested? Unittests were added. Closes #41841 from HyukjinKwon/SPARK-44194. Authored-by: Hyukjin Kwon Signed-off-by: Hyukjin Kwon --- python/docs/source/reference/pyspark.rst | 6 + python/pyspark/context.py | 177 ++++++++++++++++++++++- python/pyspark/tests/test_pin_thread.py | 35 +++-- 3 files changed, 207 insertions(+), 11 deletions(-) diff --git a/python/docs/source/reference/pyspark.rst b/python/docs/source/reference/pyspark.rst index ec3df07163921..9a6fbb651716f 100644 --- a/python/docs/source/reference/pyspark.rst +++ b/python/docs/source/reference/pyspark.rst @@ -55,6 +55,7 @@ Spark Context APIs SparkContext.accumulator SparkContext.addArchive SparkContext.addFile + SparkContext.addJobTag SparkContext.addPyFile SparkContext.applicationId SparkContext.binaryFiles @@ -62,12 +63,15 @@ Spark Context APIs SparkContext.broadcast SparkContext.cancelAllJobs SparkContext.cancelJobGroup + SparkContext.cancelJobsWithTag + SparkContext.clearJobTags SparkContext.defaultMinPartitions SparkContext.defaultParallelism SparkContext.dump_profiles SparkContext.emptyRDD SparkContext.getCheckpointDir SparkContext.getConf + SparkContext.getJobTags SparkContext.getLocalProperty SparkContext.getOrCreate SparkContext.hadoopFile @@ -80,9 +84,11 @@ Spark Context APIs SparkContext.pickleFile SparkContext.range SparkContext.resources + SparkContext.removeJobTag SparkContext.runJob SparkContext.sequenceFile SparkContext.setCheckpointDir + SparkContext.setInterruptOnCancel SparkContext.setJobDescription SparkContext.setJobGroup SparkContext.setLocalProperty diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 51a4db67e8cdc..4867ce2ae2925 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -40,6 +40,7 @@ Type, TYPE_CHECKING, TypeVar, + Set, ) from py4j.java_collections import JavaMap @@ -2164,6 +2165,160 @@ def setJobGroup(self, groupId: str, description: str, interruptOnCancel: bool = """ self._jsc.setJobGroup(groupId, description, interruptOnCancel) + def setInterruptOnCancel(self, interruptOnCancel: bool) -> None: + """ + Set the behavior of job cancellation from jobs 
started in this thread. + + .. versionadded:: 3.5.0 + + Parameters + ---------- + interruptOnCancel : bool + If true, then job cancellation will result in ``Thread.interrupt()`` + being called on the job's executor threads. This is useful to help ensure that + the tasks are actually stopped in a timely manner, but is off by default due to + HDFS-1208, where HDFS may respond to ``Thread.interrupt()`` by marking nodes as dead. + + See Also + -------- + :meth:`SparkContext.addJobTag` + :meth:`SparkContext.removeJobTag` + :meth:`SparkContext.cancelAllJobs` + :meth:`SparkContext.cancelJobGroup` + :meth:`SparkContext.cancelJobsWithTag` + """ + self._jsc.setInterruptOnCancel(interruptOnCancel) + + def addJobTag(self, tag: str) -> None: + """ + Add a tag to be assigned to all the jobs started by this thread. + + Parameters + ---------- + tag : str + The tag to be added. Cannot contain ',' (comma) character. + + See Also + -------- + :meth:`SparkContext.removeJobTag` + :meth:`SparkContext.getJobTags` + :meth:`SparkContext.clearJobTags` + :meth:`SparkContext.cancelJobsWithTag` + :meth:`SparkContext.setInterruptOnCancel` + + Examples + -------- + >>> import threading + >>> from time import sleep + >>> from pyspark import InheritableThread + >>> sc.setInterruptOnCancel(interruptOnCancel=True) + >>> result = "Not Set" + >>> lock = threading.Lock() + >>> def map_func(x): + ... sleep(100) + ... raise RuntimeError("Task should have been cancelled") + ... + >>> def start_job(x): + ... global result + ... try: + ... sc.addJobTag("job_to_cancel") + ... result = sc.parallelize(range(x)).map(map_func).collect() + ... except Exception as e: + ... result = "Cancelled" + ... lock.release() + ... + >>> def stop_job(): + ... sleep(5) + ... sc.cancelJobsWithTag("job_to_cancel") + ... + >>> suppress = lock.acquire() + >>> suppress = InheritableThread(target=start_job, args=(10,)).start() + >>> suppress = InheritableThread(target=stop_job).start() + >>> suppress = lock.acquire() + >>> print(result) + Cancelled + >>> sc.clearJobTags() + """ + self._jsc.addJobTag(tag) + + def removeJobTag(self, tag: str) -> None: + """ + Remove a tag previously added to be assigned to all the jobs started by this thread. + Noop if such a tag was not added earlier. + + Parameters + ---------- + tag : str + The tag to be removed. Cannot contain ',' (comma) character. + + See Also + -------- + :meth:`SparkContext.addJobTag` + :meth:`SparkContext.getJobTags` + :meth:`SparkContext.clearJobTags` + :meth:`SparkContext.cancelJobsWithTag` + :meth:`SparkContext.setInterruptOnCancel` + + Examples + -------- + >>> sc.addJobTag("job_to_cancel1") + >>> sc.addJobTag("job_to_cancel2") + >>> sc.getJobTags() + {'job_to_cancel1', 'job_to_cancel2'} + >>> sc.removeJobTag("job_to_cancel1") + >>> sc.getJobTags() + {'job_to_cancel2'} + >>> sc.clearJobTags() + """ + self._jsc.removeJobTag(tag) + + def getJobTags(self) -> Set[str]: + """ + Get the tags that are currently set to be assigned to all the jobs started by this thread. + + Returns + ------- + set of str + the tags that are currently set to be assigned to all the jobs started by this thread. 
+ + See Also + -------- + :meth:`SparkContext.addJobTag` + :meth:`SparkContext.removeJobTag` + :meth:`SparkContext.clearJobTags` + :meth:`SparkContext.cancelJobsWithTag` + :meth:`SparkContext.setInterruptOnCancel` + + Examples + -------- + >>> sc.addJobTag("job_to_cancel") + >>> sc.getJobTags() + {'job_to_cancel'} + >>> sc.clearJobTags() + """ + return self._jsc.getJobTags() + + def clearJobTags(self) -> None: + """ + Clear the current thread's job tags. + + See Also + -------- + :meth:`SparkContext.addJobTag` + :meth:`SparkContext.removeJobTag` + :meth:`SparkContext.getJobTags` + :meth:`SparkContext.cancelJobsWithTag` + :meth:`SparkContext.setInterruptOnCancel` + + Examples + -------- + >>> sc.addJobTag("job_to_cancel") + >>> sc.clearJobTags() + >>> sc.getJobTags() + set() + """ + self._jsc.clearJobTags() + def setLocalProperty(self, key: str, value: str) -> None: """ Set a local property that affects jobs submitted from this thread, such as the @@ -2243,10 +2398,29 @@ def cancelJobGroup(self, groupId: str) -> None: See Also -------- :meth:`SparkContext.setJobGroup` - :meth:`SparkContext.cancelJobGroup` """ self._jsc.sc().cancelJobGroup(groupId) + def cancelJobsWithTag(self, tag: str) -> None: + """ + Cancel active jobs that have the specified tag. See + :meth:`SparkContext.addJobTag`. + + Parameters + ---------- + tag : str + The tag to be cancelled. Cannot contain ',' (comma) character. + + See Also + -------- + :meth:`SparkContext.addJobTag` + :meth:`SparkContext.removeJobTag` + :meth:`SparkContext.getJobTags` + :meth:`SparkContext.clearJobTags` + :meth:`SparkContext.setInterruptOnCancel` + """ + return self._jsc.cancelJobsWithTag(tag) + def cancelAllJobs(self) -> None: """ Cancel all jobs that have been scheduled or are running. @@ -2256,6 +2430,7 @@ def cancelAllJobs(self) -> None: See Also -------- :meth:`SparkContext.cancelJobGroup` + :meth:`SparkContext.cancelJobsWithTag` :meth:`SparkContext.runJob` """ self._jsc.sc().cancelAllJobs() diff --git a/python/pyspark/tests/test_pin_thread.py b/python/pyspark/tests/test_pin_thread.py index dd291b8a0cc9e..975b549808933 100644 --- a/python/pyspark/tests/test_pin_thread.py +++ b/python/pyspark/tests/test_pin_thread.py @@ -83,10 +83,25 @@ def test_local_property(): assert len(set(jvm_thread_ids)) == 10 def test_multiple_group_jobs(self): - # SPARK-22340 Add a mode to pin Python thread into JVM's - - group_a = "job_ids_to_cancel" - group_b = "job_ids_to_run" + # SPARK-22340: Add a mode to pin Python thread into JVM's + self.check_job_cancellation( + lambda job_group: self.sc.setJobGroup( + job_group, "test rdd collect with setting job group" + ), + lambda job_group: self.sc.cancelJobGroup(job_group), + ) + + def test_multiple_group_tags(self): + # SPARK-44194: Test pinned thread mode with job tags. + self.check_job_cancellation( + lambda job_tag: self.sc.addJobTag(job_tag), + lambda job_tag: self.sc.cancelJobsWithTag(job_tag), + ) + + def check_job_cancellation(self, setter, canceller): + + job_id_a = "job_ids_to_cancel" + job_id_b = "job_ids_to_run" threads = [] thread_ids = range(4) @@ -97,13 +112,13 @@ def test_multiple_group_jobs(self): # The index of the array is the thread index which job run in. is_job_cancelled = [False for _ in thread_ids] - def run_job(job_group, index): + def run_job(job_id, index): """ Executes a job with the group ``job_group``. Each job waits for 3 seconds and then exits. 
""" try: - self.sc.setJobGroup(job_group, "test rdd collect with setting job group") + setter(job_id) self.sc.parallelize([15]).map(lambda x: time.sleep(x)).collect() is_job_cancelled[index] = False except Exception: @@ -111,24 +126,24 @@ def run_job(job_group, index): is_job_cancelled[index] = True # Test if job succeeded when not cancelled. - run_job(group_a, 0) + run_job(job_id_a, 0) self.assertFalse(is_job_cancelled[0]) # Run jobs for i in thread_ids_to_cancel: - t = threading.Thread(target=run_job, args=(group_a, i)) + t = threading.Thread(target=run_job, args=(job_id_a, i)) t.start() threads.append(t) for i in thread_ids_to_run: - t = threading.Thread(target=run_job, args=(group_b, i)) + t = threading.Thread(target=run_job, args=(job_id_b, i)) t.start() threads.append(t) # Wait to make sure all jobs are executed. time.sleep(3) # And then, cancel one job group. - self.sc.cancelJobGroup(group_a) + canceller(job_id_a) # Wait until all threads launching jobs are finished. for t in threads: From b573cca90ea843f8b492c5b1a72463854d1568c2 Mon Sep 17 00:00:00 2001 From: Anish Shrigondekar Date: Tue, 4 Jul 2023 13:06:32 +0900 Subject: [PATCH 11/13] [SPARK-44288][SS] Set the column family options before passing to DBOptions in RocksDB state store provider ### What changes were proposed in this pull request? Set the column family options before passing to DBOptions in RocksDB state store provider ### Why are the changes needed? Address bug fix to ensure column family options around memory usage are passed correctly to dbOptions ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Existing tests Closes #41840 from anishshri-db/task/SPARK-44288. Authored-by: Anish Shrigondekar Signed-off-by: Jungtaek Lim --- .../spark/sql/execution/streaming/state/RocksDB.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala index a9c15cf7f7d74..65299ea37eff2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala @@ -92,9 +92,6 @@ class RocksDB( private val columnFamilyOptions = new ColumnFamilyOptions() - private val dbOptions = - new Options(new DBOptions(), columnFamilyOptions) // options to open the RocksDB - // Set RocksDB options around MemTable memory usage. By default, we let RocksDB // use its internal default values for these settings. if (conf.writeBufferSizeMB > 0L) { @@ -105,6 +102,9 @@ class RocksDB( columnFamilyOptions.setMaxWriteBufferNumber(conf.maxWriteBufferNumber) } + private val dbOptions = + new Options(new DBOptions(), columnFamilyOptions) // options to open the RocksDB + dbOptions.setCreateIfMissing(true) dbOptions.setTableFormatConfig(tableFormatConfig) dbOptions.setMaxOpenFiles(conf.maxOpenFiles) From 7bc28d54f83261b16eaa11201a7987d8d2c8dd1e Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Tue, 4 Jul 2023 08:07:55 +0300 Subject: [PATCH 12/13] [SPARK-44269][SQL] Assign names to the error class _LEGACY_ERROR_TEMP_[2310-2314] ### What changes were proposed in this pull request? The pr aims to assign names to the error class _LEGACY_ERROR_TEMP_[2310-2314]. ### Why are the changes needed? Improve the error framework. ### Does this PR introduce _any_ user-facing change? 'No'. ### How was this patch tested? 
Existing test cases updated and new test cases added.

Closes #41816 from beliefer/SPARK-44269.

Authored-by: Jiaan Geng
Signed-off-by: Max Gekk
---
 .../main/resources/error/error-classes.json   | 25 ++++---------------
 docs/sql-error-conditions.md                  |  6 +++++
 .../sql/catalyst/analysis/CheckAnalysis.scala | 11 +++-----
 .../analysis/AnalysisErrorSuite.scala         | 11 ++++----
 .../scala/org/apache/spark/sql/Dataset.scala  |  8 +++---
 .../spark/sql/DataFrameWriterV2Suite.scala    | 19 ++++++++++++++
 .../test/DataStreamReaderWriterSuite.scala    | 20 +++++++--------
 7 files changed, 54 insertions(+), 46 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index 753701cf581c2..6a72fc5449eed 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -91,6 +91,11 @@
     ],
     "sqlState" : "22003"
   },
+  "CALL_ON_STREAMING_DATASET_UNSUPPORTED" : {
+    "message" : [
+      "The method <methodName> can not be called on streaming Dataset/DataFrame."
+    ]
+  },
   "CANNOT_CAST_DATATYPE" : {
     "message" : [
       "Cannot cast <sourceType> to <targetType>."
     ],
@@ -5609,26 +5614,6 @@
       "The input '<input>' does not match the given number format: '<format>'."
     ]
   },
-  "_LEGACY_ERROR_TEMP_2311" : {
-    "message" : [
-      "'writeTo' can not be called on streaming Dataset/DataFrame."
-    ]
-  },
-  "_LEGACY_ERROR_TEMP_2312" : {
-    "message" : [
-      "'write' can not be called on streaming Dataset/DataFrame."
-    ]
-  },
-  "_LEGACY_ERROR_TEMP_2313" : {
-    "message" : [
-      "Hint not found: <name>."
-    ]
-  },
-  "_LEGACY_ERROR_TEMP_2314" : {
-    "message" : [
-      "cannot resolve '<sqlExpr>' due to argument data type mismatch: <msg>"
-    ]
-  },
   "_LEGACY_ERROR_TEMP_2315" : {
     "message" : [
       "cannot resolve '<sqlExpr>' due to data type mismatch: <msg>."
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 35af4db69aba8..1b799ed7b67c8 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -111,6 +111,12 @@ Unable to find batch `<batchMetadataFile>`.
 
 `<value1>` `<symbol>` `<value2>` caused overflow.
 
+### CALL_ON_STREAMING_DATASET_UNSUPPORTED
+
+SQLSTATE: none assigned
+
+The method `<methodName>` can not be called on streaming Dataset/DataFrame.
+ ### CANNOT_CAST_DATATYPE [SQLSTATE: 42846](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 2b4753a027d87..11387fde37e52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -209,9 +209,10 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB u.origin) case u: UnresolvedHint => - u.failAnalysis( - errorClass = "_LEGACY_ERROR_TEMP_2313", - messageParameters = Map("name" -> u.name)) + throw SparkException.internalError( + msg = s"Hint not found: ${toSQLId(u.name)}", + context = u.origin.getQueryContext, + summary = u.origin.context.summary) case command: V2PartitionCommand => command.table match { @@ -245,10 +246,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB hof.checkArgumentDataTypes() match { case checkRes: TypeCheckResult.DataTypeMismatch => hof.dataTypeMismatch(hof, checkRes) - case TypeCheckResult.TypeCheckFailure(message) => - hof.failAnalysis( - errorClass = "_LEGACY_ERROR_TEMP_2314", - messageParameters = Map("sqlExpr" -> hof.sql, "msg" -> message)) case checkRes: TypeCheckResult.InvalidFormat => hof.setTagValue(INVALID_FORMAT_ERROR, true) hof.invalidFormat(checkRes) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index fdaeadc544500..6c43f84e8d0b7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.Assertions._ import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -1166,10 +1165,12 @@ class AnalysisErrorSuite extends AnalysisTest { ) assert(plan.resolved) - val error = intercept[AnalysisException] { - SimpleAnalyzer.checkAnalysis(plan) - } - assert(error.message.contains(s"Hint not found: ${hintName}")) + checkError( + exception = intercept[SparkException] { + SimpleAnalyzer.checkAnalysis(plan) + }, + errorClass = "INTERNAL_ERROR", + parameters = Map("message" -> "Hint not found: `some_random_hint_that_does_not_exist`")) // UnresolvedHint be removed by batch `Remove Unresolved Hints` assertAnalysisSuccess(plan, true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3e0c692b00f8b..c87f95294bfa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -4036,8 +4036,8 @@ class Dataset[T] private[sql]( def write: DataFrameWriter[T] = { if (isStreaming) { logicalPlan.failAnalysis( - errorClass = "_LEGACY_ERROR_TEMP_2312", - messageParameters = Map.empty) + errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", + messageParameters = Map("methodName" -> toSQLId("write"))) } new DataFrameWriter[T](this) } @@ -4065,8 
+4065,8 @@ class Dataset[T] private[sql]( // TODO: streaming could be adapted to use this interface if (isStreaming) { logicalPlan.failAnalysis( - errorClass = "_LEGACY_ERROR_TEMP_2311", - messageParameters = Map.empty) + errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", + messageParameters = Map("methodName" -> toSQLId("writeTo"))) } new DataFrameWriterV2[T](table, this) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala index 507207a2fdd26..f8128c8c23e68 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWriterV2Suite.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.FakeSourceOne import org.apache.spark.sql.test.SharedSparkSession @@ -767,4 +768,22 @@ class DataFrameWriterV2Suite extends QueryTest with SharedSparkSession with Befo assert(table.partitioning === Seq(BucketTransform(LiteralValue(4, IntegerType), Seq(FieldReference(Seq("ts", "timezone")))))) } + + test("can not be called on streaming Dataset/DataFrame") { + val ds = MemoryStream[Int].toDS() + + checkError( + exception = intercept[AnalysisException] { + ds.write + }, + errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", + parameters = Map("methodName" -> "`write`")) + + checkError( + exception = intercept[AnalysisException] { + ds.writeTo("testcat.table_name") + }, + errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", + parameters = Map("methodName" -> "`writeTo`")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 07a9ec4fdce06..d03e8bcad937b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -121,16 +121,16 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } test("write cannot be called on streaming datasets") { - val e = intercept[AnalysisException] { - spark.readStream - .format("org.apache.spark.sql.streaming.test") - .load() - .write - .save() - } - Seq("'write'", "not", "streaming Dataset/DataFrame").foreach { s => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(s.toLowerCase(Locale.ROOT))) - } + checkError( + exception = intercept[AnalysisException] { + spark.readStream + .format("org.apache.spark.sql.streaming.test") + .load() + .write + .save() + }, + errorClass = "CALL_ON_STREAMING_DATASET_UNSUPPORTED", + parameters = Map("methodName" -> "`write`")) } test("resolve default source") { From 7fcabef28743099363c5cf21b90e987e6be90b12 Mon Sep 17 00:00:00 2001 From: Weichen Xu Date: Tue, 4 Jul 2023 14:01:49 +0800 Subject: [PATCH 13/13] [SPARK-44250][ML][PYTHON][CONNECT] Implement classification evaluator ### What changes were proposed in this pull request? 
Implement classification evaluator ### Why are the changes needed? Distributed ML <> spark connect project. ### Does this PR introduce _any_ user-facing change? Yes. `BinaryClassificationEvaluator` and `MulticlassClassificationEvaluator` are added. ### How was this patch tested? Closes #41793 from WeichenXu123/classification-evaluator. Authored-by: Weichen Xu Signed-off-by: Weichen Xu --- python/pyspark/ml/connect/evaluation.py | 161 ++++++++++++++---- .../connect/test_legacy_mode_evaluation.py | 77 ++++++++- 2 files changed, 202 insertions(+), 36 deletions(-) diff --git a/python/pyspark/ml/connect/evaluation.py b/python/pyspark/ml/connect/evaluation.py index c10599cf49f49..0606c7cad7df8 100644 --- a/python/pyspark/ml/connect/evaluation.py +++ b/python/pyspark/ml/connect/evaluation.py @@ -14,22 +14,61 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import numpy as np import pandas as pd -from typing import Any, Union +from typing import Any, Union, List, Tuple from pyspark.ml.param import Param, Params, TypeConverters -from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol +from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasProbabilityCol from pyspark.ml.connect.base import Evaluator from pyspark.ml.connect.io_utils import ParamsReadWrite from pyspark.ml.connect.util import aggregate_dataframe from pyspark.sql import DataFrame -import torch -import torcheval.metrics as torchmetrics +class _TorchMetricEvaluator(Evaluator): -class RegressionEvaluator(Evaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite): + metricName: Param[str] = Param( + Params._dummy(), + "metricName", + "metric name for the regression evaluator, valid values are 'mse' and 'r2'", + typeConverter=TypeConverters.toString, + ) + + def _get_torch_metric(self) -> Any: + raise NotImplementedError() + + def _get_input_cols(self) -> List[str]: + raise NotImplementedError() + + def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]: + raise NotImplementedError() + + def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float: + torch_metric = self._get_torch_metric() + + def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame": + torch_metric.update(*self._get_metric_update_inputs(pandas_df)) + return torch_metric + + def merge_agg_state(state1: Any, state2: Any) -> Any: + state1.merge_state([state2]) + return state1 + + def agg_state_to_result(state: Any) -> Any: + return state.compute().item() + + return aggregate_dataframe( + dataset, + self._get_input_cols(), + local_agg_fn, + merge_agg_state, + agg_state_to_result, + ) + + +class RegressionEvaluator(_TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite): """ Evaluator for Regression, which expects input columns prediction and label. Supported metrics are 'mse' and 'r2'. 
@@ -41,14 +80,9 @@ def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
         super().__init__()
         self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)
 
-    metricName: Param[str] = Param(
-        Params._dummy(),
-        "metricName",
-        "metric name for the regression evaluator, valid values are 'mse' and 'r2'",
-        typeConverter=TypeConverters.toString,
-    )
-
     def _get_torch_metric(self) -> Any:
+        import torcheval.metrics as torchmetrics
+
         metric_name = self.getOrDefault(self.metricName)
 
         if metric_name == "mse":
@@ -58,32 +92,89 @@ def _get_torch_metric(self) -> Any:
 
         raise ValueError(f"Unsupported regressor evaluator metric name: {metric_name}")
 
-    def _evaluate(self, dataset: Union["DataFrame", "pd.DataFrame"]) -> float:
-        prediction_col = self.getPredictionCol()
-        label_col = self.getLabelCol()
+    def _get_input_cols(self) -> List[str]:
+        return [self.getPredictionCol(), self.getLabelCol()]
 
-        torch_metric = self._get_torch_metric()
+    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
+        import torch
 
-        def local_agg_fn(pandas_df: "pd.DataFrame") -> "pd.DataFrame":
-            with torch.inference_mode():
-                preds_tensor = torch.tensor(pandas_df[prediction_col].values)
-                labels_tensor = torch.tensor(pandas_df[label_col].values)
-                torch_metric.update(preds_tensor, labels_tensor)
-                return torch_metric
+        preds_tensor = torch.tensor(dataset[self.getPredictionCol()].values)
+        labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
+        return preds_tensor, labels_tensor
 
-        def merge_agg_state(state1: Any, state2: Any) -> Any:
-            with torch.inference_mode():
-                state1.merge_state([state2])
-                return state1
 
-        def agg_state_to_result(state: Any) -> Any:
-            with torch.inference_mode():
-                return state.compute().item()
+class BinaryClassificationEvaluator(
+    _TorchMetricEvaluator, HasLabelCol, HasProbabilityCol, ParamsReadWrite
+):
+    """
+    Evaluator for binary classification, which expects input columns probability and label.
+    Supported metrics are 'areaUnderROC' and 'areaUnderPR'.
 
-        return aggregate_dataframe(
-            dataset,
-            [prediction_col, label_col],
-            local_agg_fn,
-            merge_agg_state,
-            agg_state_to_result,
+    .. versionadded:: 3.5.0
+    """
+
+    def __init__(self, metricName: str, labelCol: str, probabilityCol: str) -> None:
+        super().__init__()
+        self._set(metricName=metricName, labelCol=labelCol, probabilityCol=probabilityCol)
+
+    def _get_torch_metric(self) -> Any:
+        import torcheval.metrics as torchmetrics
+
+        metric_name = self.getOrDefault(self.metricName)
+
+        if metric_name == "areaUnderROC":
+            return torchmetrics.BinaryAUROC()
+        if metric_name == "areaUnderPR":
+            return torchmetrics.BinaryAUPRC()
+
+        raise ValueError(f"Unsupported binary classification evaluator metric name: {metric_name}")
+
+    def _get_input_cols(self) -> List[str]:
+        return [self.getProbabilityCol(), self.getLabelCol()]
+
+    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
+        import torch
+
+        values = np.stack(dataset[self.getProbabilityCol()].values)  # type: ignore[call-overload]
+        preds_tensor = torch.tensor(values)
+        if preds_tensor.dim() == 2:
+            preds_tensor = preds_tensor[:, 1]
+        labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
+        return preds_tensor, labels_tensor
+
+
+class MulticlassClassificationEvaluator(
+    _TorchMetricEvaluator, HasLabelCol, HasPredictionCol, ParamsReadWrite
+):
+    """
+    Evaluator for multiclass classification, which expects input columns prediction and label.
+    The supported metric is 'accuracy'.
+
+    .. versionadded:: 3.5.0
+    """
+
+    def __init__(self, metricName: str, labelCol: str, predictionCol: str) -> None:
+        super().__init__()
+        self._set(metricName=metricName, labelCol=labelCol, predictionCol=predictionCol)
+
+    def _get_torch_metric(self) -> Any:
+        import torcheval.metrics as torchmetrics
+
+        metric_name = self.getOrDefault(self.metricName)
+
+        if metric_name == "accuracy":
+            return torchmetrics.MulticlassAccuracy()
+
+        raise ValueError(
+            f"Unsupported multiclass classification evaluator metric name: {metric_name}"
         )
+
+    def _get_input_cols(self) -> List[str]:
+        return [self.getPredictionCol(), self.getLabelCol()]
+
+    def _get_metric_update_inputs(self, dataset: "pd.DataFrame") -> Tuple[Any, Any]:
+        import torch
+
+        preds_tensor = torch.tensor(dataset[self.getPredictionCol()].values)
+        labels_tensor = torch.tensor(dataset[self.getLabelCol()].values)
+        return preds_tensor, labels_tensor
diff --git a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
index 3db00d6661b7b..51c3bb26db898 100644
--- a/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
+++ b/python/pyspark/ml/tests/connect/test_legacy_mode_evaluation.py
@@ -18,7 +18,11 @@
 import unittest
 
 import numpy as np
 
-from pyspark.ml.connect.evaluation import RegressionEvaluator
+from pyspark.ml.connect.evaluation import (
+    RegressionEvaluator,
+    BinaryClassificationEvaluator,
+    MulticlassClassificationEvaluator,
+)
 from pyspark.sql import SparkSession
 
@@ -66,6 +70,77 @@ def test_regressor_evaluator(self):
         np.testing.assert_almost_equal(r2, expected_r2)
         np.testing.assert_almost_equal(r2_local, expected_r2)
 
+    def test_binary_classifier_evaluator(self):
+        df1 = self.spark.createDataFrame(
+            [
+                (1, 0.2, [0.8, 0.2]),
+                (0, 0.6, [0.4, 0.6]),
+                (1, 0.8, [0.2, 0.8]),
+                (1, 0.7, [0.3, 0.7]),
+                (0, 0.4, [0.6, 0.4]),
+                (0, 0.3, [0.7, 0.3]),
+            ],
+            schema=["label", "prob", "prob2"],
+        )
+
+        local_df1 = df1.toPandas()
+
+        for prob_col in ["prob", "prob2"]:
+            auroc_evaluator = BinaryClassificationEvaluator(
+                metricName="areaUnderROC",
+                labelCol="label",
+                probabilityCol=prob_col,
+            )
+
+            expected_auroc = 0.6667
+            auroc = auroc_evaluator.evaluate(df1)
+            auroc_local = auroc_evaluator.evaluate(local_df1)
+            np.testing.assert_almost_equal(auroc, expected_auroc, decimal=2)
+            np.testing.assert_almost_equal(auroc_local, expected_auroc, decimal=2)
+
+            auprc_evaluator = BinaryClassificationEvaluator(
+                metricName="areaUnderPR",
+                labelCol="label",
+                probabilityCol=prob_col,
+            )
+
+            expected_auprc = 0.8333
+            auprc = auprc_evaluator.evaluate(df1)
+            auprc_local = auprc_evaluator.evaluate(local_df1)
+            np.testing.assert_almost_equal(auprc, expected_auprc, decimal=2)
+            np.testing.assert_almost_equal(auprc_local, expected_auprc, decimal=2)
+
+    def test_multiclass_classifier_evaluator(self):
+        df1 = self.spark.createDataFrame(
+            [
+                (1, 1),
+                (1, 1),
+                (2, 3),
+                (0, 0),
+                (0, 1),
+                (3, 1),
+                (3, 3),
+                (2, 2),
+                (1, 0),
+                (2, 2),
+            ],
+            schema=["label", "prediction"],
+        )
+
+        local_df1 = df1.toPandas()
+
+        accuracy_evaluator = MulticlassClassificationEvaluator(
+            metricName="accuracy",
+            labelCol="label",
+            predictionCol="prediction",
+        )
+
+        expected_accuracy = 0.600
+        accuracy = accuracy_evaluator.evaluate(df1)
+        accuracy_local = accuracy_evaluator.evaluate(local_df1)
+        np.testing.assert_almost_equal(accuracy, expected_accuracy, decimal=2)
+        np.testing.assert_almost_equal(accuracy_local, expected_accuracy, decimal=2)
+
 
 @unittest.skipIf(not have_torcheval, "torcheval is required")
 class EvaluationTests(EvaluationTestsMixin, unittest.TestCase):
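
A minimal usage sketch of the two evaluators introduced by this patch (not part of the diff above; it assumes an active Spark session, the optional `torcheval` package installed, and made-up input data):

'''
from pyspark.sql import SparkSession
from pyspark.ml.connect.evaluation import (
    BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator,
)

spark = SparkSession.builder.getOrCreate()

# Binary metrics consume a label column plus a probability column, which may hold
# either P(label=1) or a [P(0), P(1)] pair (the evaluator slices out column 1).
binary_df = spark.createDataFrame(
    [(1, [0.2, 0.8]), (0, [0.6, 0.4]), (1, [0.3, 0.7]), (0, [0.9, 0.1])],
    schema=["label", "probability"],
)
auroc = BinaryClassificationEvaluator(
    metricName="areaUnderROC", labelCol="label", probabilityCol="probability"
).evaluate(binary_df)

# Multiclass accuracy compares integer predictions against integer labels.
multiclass_df = spark.createDataFrame(
    [(0, 0), (1, 1), (2, 1), (1, 1)], schema=["label", "prediction"]
)
accuracy = MulticlassClassificationEvaluator(
    metricName="accuracy", labelCol="label", predictionCol="prediction"
).evaluate(multiclass_df)

print(auroc, accuracy)
'''

Because `_evaluate` funnels everything through `aggregate_dataframe`, the same `evaluate` call works on both Spark DataFrames and local pandas DataFrames, which is exactly what the new unit tests exercise.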