diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 43ac6b50052ae..14d93a498fc59 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -264,20 +264,20 @@ jobs: with: distribution: zulu java-version: ${{ matrix.java }} - - name: Install Python 3.9 + - name: Install Python 3.11 uses: actions/setup-python@v5 # We should install one Python that is higher than 3+ for SQL and Yarn because: # - SQL component also has Python related tests, for example, IntegratedUDFTestUtils. # - Yarn has a Python specific test too, for example, YarnClusterSuite. if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') with: - python-version: '3.9' + python-version: '3.11' architecture: x64 - - name: Install Python packages (Python 3.9) + - name: Install Python packages (Python 3.11) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') run: | - python3.9 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==4.25.1' - python3.9 -m pip list + python3.11 -m pip install 'numpy>=1.20.0' pyarrow pandas scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.62.0' 'grpcio-status==1.62.0' 'protobuf==4.25.1' + python3.11 -m pip list # Run the tests. - name: Run tests env: ${{ fromJSON(inputs.envs) }} @@ -608,14 +608,14 @@ jobs: with: input: sql/connect/common/src/main against: 'https://github.com/apache/spark.git#branch=branch-3.5,subdir=connector/connect/common/src/main' - - name: Install Python 3.9 + - name: Install Python 3.11 uses: actions/setup-python@v5 with: - python-version: '3.9' + python-version: '3.11' - name: Install dependencies for Python CodeGen check run: | - python3.9 -m pip install 'black==23.9.1' 'protobuf==4.25.1' 'mypy==1.8.0' 'mypy-protobuf==3.3.0' - python3.9 -m pip list + python3.11 -m pip install 'black==23.9.1' 'protobuf==4.25.1' 'mypy==1.8.0' 'mypy-protobuf==3.3.0' + python3.11 -m pip list - name: Python CodeGen check run: ./dev/connect-check-protos.py diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index f7dd261c10fd2..49000c62d1063 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -57,6 +57,7 @@ Collate: 'types.R' 'utils.R' 'window.R' + 'zzz.R' RoxygenNote: 7.1.2 VignetteBuilder: knitr NeedsCompilation: no diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 29c05b0db7c2d..1b5faad376eaa 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -403,12 +403,6 @@ sparkR.session <- function( sparkPackages = "", enableHiveSupport = TRUE, ...) { - - if (Sys.getenv("SPARKR_SUPPRESS_DEPRECATION_WARNING") == "") { - warning( - "SparkR is deprecated from Apache Spark 4.0.0 and will be removed in a future version.") - } - sparkConfigMap <- convertNamedListToEnv(sparkConfig) namedParams <- list(...) if (length(namedParams) > 0) { diff --git a/R/pkg/R/zzz.R b/R/pkg/R/zzz.R new file mode 100644 index 0000000000000..947bd543b75e0 --- /dev/null +++ b/R/pkg/R/zzz.R @@ -0,0 +1,30 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# zzz.R - package startup message + +.onAttach <- function(...) { + if (Sys.getenv("SPARKR_SUPPRESS_DEPRECATION_WARNING") == "") { + packageStartupMessage( + paste0( + "Warning: ", + "SparkR is deprecated in Apache Spark 4.0.0 and will be removed in a future release. ", + "To continue using Spark in R, we recommend using sparklyr instead: ", + "https://spark.posit.co/get-started/" + ) + ) + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index fb610a5d96f17..d67697eaea38b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -1363,9 +1363,9 @@ public static UTF8String trimRight( public static UTF8String[] splitSQL(final UTF8String input, final UTF8String delim, final int limit, final int collationId) { - if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collationId).isUtf8BinaryType) { return input.split(delim, limit); - } else if (CollationFactory.fetchCollation(collationId).supportsLowercaseEquality) { + } else if (CollationFactory.fetchCollation(collationId).isUtf8LcaseType) { return lowercaseSplitSQL(input, delim, limit); } else { return icuSplitSQL(input, delim, limit, collationId); diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 01f6c7e0331b0..50bb93465921e 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -160,6 +160,18 @@ public static class Collation { */ public final boolean supportsSpaceTrimming; + /** + * Is Utf8 binary type as indicator if collation base type is UTF8 binary. Note currently only + * collations Utf8_Binary and Utf8_Binary_RTRIM are considered as Utf8 binary type. + */ + public final boolean isUtf8BinaryType; + + /** + * Is Utf8 lcase type as indicator if collation base type is UTF8 lcase. Note currently only + * collations Utf8_Lcase and Utf8_Lcase_RTRIM are considered as Utf8 Lcase type. 
+ */ + public final boolean isUtf8LcaseType; + public Collation( String collationName, String provider, @@ -168,9 +180,8 @@ public Collation( String version, ToLongFunction hashFunction, BiFunction equalsFunction, - boolean supportsBinaryEquality, - boolean supportsBinaryOrdering, - boolean supportsLowercaseEquality, + boolean isUtf8BinaryType, + boolean isUtf8LcaseType, boolean supportsSpaceTrimming) { this.collationName = collationName; this.provider = provider; @@ -178,14 +189,13 @@ public Collation( this.comparator = comparator; this.version = version; this.hashFunction = hashFunction; - this.supportsBinaryEquality = supportsBinaryEquality; - this.supportsBinaryOrdering = supportsBinaryOrdering; - this.supportsLowercaseEquality = supportsLowercaseEquality; + this.isUtf8BinaryType = isUtf8BinaryType; + this.isUtf8LcaseType = isUtf8LcaseType; this.equalsFunction = equalsFunction; this.supportsSpaceTrimming = supportsSpaceTrimming; - - // De Morgan's Law to check supportsBinaryOrdering => supportsBinaryEquality - assert(!supportsBinaryOrdering || supportsBinaryEquality); + this.supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsBinaryOrdering = !supportsSpaceTrimming && isUtf8BinaryType; + this.supportsLowercaseEquality = !supportsSpaceTrimming && isUtf8LcaseType; // No Collation can simultaneously support binary equality and lowercase equality assert(!supportsBinaryEquality || !supportsLowercaseEquality); @@ -567,9 +577,8 @@ protected Collation buildCollation() { "1.0", hashFunction, equalsFunction, - /* supportsBinaryEquality = */ true, - /* supportsBinaryOrdering = */ true, - /* supportsLowercaseEquality = */ false, + /* isUtf8BinaryType = */ true, + /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } else { Comparator comparator; @@ -595,9 +604,8 @@ protected Collation buildCollation() { "1.0", hashFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ true, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ true, spaceTrimming != SpaceTrimming.NONE); } } @@ -982,9 +990,8 @@ protected Collation buildCollation() { ICU_COLLATOR_VERSION, hashFunction, (s1, s2) -> comparator.compare(s1, s2) == 0, - /* supportsBinaryEquality = */ false, - /* supportsBinaryOrdering = */ false, - /* supportsLowercaseEquality = */ false, + /* isUtf8BinaryType = */ false, + /* isUtf8LcaseType = */ false, spaceTrimming != SpaceTrimming.NONE); } @@ -1191,9 +1198,9 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return input; - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return CollationAwareUTF8String.lowerCaseCodePoints(input); } else { CollationKey collationKey = collation.collator.getCollationKey( @@ -1207,9 +1214,9 @@ public static byte[] getCollationKeyBytes(UTF8String input, int collationId) { if (collation.supportsSpaceTrimming) { input = Collation.CollationSpec.applyTrimmingPolicy(input, collationId); } - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return input.getBytes(); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return 
CollationAwareUTF8String.lowerCaseCodePoints(input).getBytes(); } else { return collation.collator.getCollationKey( diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java index f05d9e512568f..978b663cc25c9 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationSupport.java @@ -37,9 +37,9 @@ public final class CollationSupport { public static class StringSplitSQL { public static UTF8String[] exec(final UTF8String s, final UTF8String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(s, d); } else { return execICU(s, d, collationId); @@ -48,9 +48,9 @@ public static UTF8String[] exec(final UTF8String s, final UTF8String d, final in public static String genCode(final String s, final String d, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringSplitSQL.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", s, d); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", s, d); } else { return String.format(expr + "ICU(%s, %s, %d)", s, d, collationId); @@ -71,9 +71,9 @@ public static UTF8String[] execICU(final UTF8String string, final UTF8String del public static class Contains { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -82,9 +82,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Contains.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -109,9 +109,9 @@ public static class StartsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -120,9 +120,9 @@ public static boolean exec(final 
UTF8String l, final UTF8String r, public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StartsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -146,9 +146,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class EndsWith { public static boolean exec(final UTF8String l, final UTF8String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(l, r); } else { return execICU(l, r, collationId); @@ -157,9 +157,9 @@ public static boolean exec(final UTF8String l, final UTF8String r, final int col public static String genCode(final String l, final String r, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.EndsWith.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", l, r); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", l, r); } else { return String.format(expr + "ICU(%s, %s, %d)", l, r, collationId); @@ -184,9 +184,9 @@ public static boolean execICU(final UTF8String l, final UTF8String r, public static class Upper { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -195,10 +195,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Upper.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -221,9 +221,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class Lower { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? 
execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -232,10 +232,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.Lower.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? "BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -258,9 +258,9 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class InitCap { public static UTF8String exec(final UTF8String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return useICU ? execBinaryICU(v) : execBinary(v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(v); } else { return execICU(v, collationId); @@ -270,10 +270,10 @@ public static UTF8String exec(final UTF8String v, final int collationId, boolean public static String genCode(final String v, final int collationId, boolean useICU) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.InitCap.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { String funcName = useICU ? 
"BinaryICU" : "Binary"; return String.format(expr + "%s(%s)", funcName, v); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s)", v); } else { return String.format(expr + "ICU(%s, %d)", v, collationId); @@ -296,7 +296,7 @@ public static UTF8String execICU(final UTF8String v, final int collationId) { public static class FindInSet { public static int exec(final UTF8String word, final UTF8String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(word, set); } else { return execCollationAware(word, set, collationId); @@ -305,7 +305,7 @@ public static int exec(final UTF8String word, final UTF8String set, final int co public static String genCode(final String word, final String set, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.FindInSet.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", word, set); } else { return String.format(expr + "CollationAware(%s, %s, %d)", word, set, collationId); @@ -324,9 +324,9 @@ public static class StringInstr { public static int exec(final UTF8String string, final UTF8String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring); } else { return execICU(string, substring, collationId); @@ -336,9 +336,9 @@ public static String genCode(final String string, final String substring, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringInstr.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", string, substring); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", string, substring); } else { return String.format(expr + "ICU(%s, %s, %d)", string, substring, collationId); @@ -360,9 +360,9 @@ public static class StringReplace { public static UTF8String exec(final UTF8String src, final UTF8String search, final UTF8String replace, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(src, search, replace); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(src, search, replace); } else { return execICU(src, search, replace, collationId); @@ -372,9 +372,9 @@ public static String genCode(final String src, final String search, final String final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringReplace.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", src, search, replace); - } else if 
(collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", src, search, replace); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", src, search, replace, collationId); @@ -398,9 +398,9 @@ public static class StringLocate { public static int exec(final UTF8String string, final UTF8String substring, final int start, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, substring, start); } else { return execICU(string, substring, start, collationId); @@ -410,9 +410,9 @@ public static String genCode(final String string, final String substring, final final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringLocate.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %d)", string, substring, start); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %d)", string, substring, start); } else { return String.format(expr + "ICU(%s, %s, %d, %d)", string, substring, start, collationId); @@ -436,9 +436,9 @@ public static class SubstringIndex { public static UTF8String exec(final UTF8String string, final UTF8String delimiter, final int count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(string, delimiter, count); } else { return execICU(string, delimiter, count, collationId); @@ -448,9 +448,9 @@ public static String genCode(final String string, final String delimiter, final String count, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.SubstringIndex.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s, %s)", string, delimiter, count); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s, %s)", string, delimiter, count); } else { return String.format(expr + "ICU(%s, %s, %s, %d)", string, delimiter, count, collationId); @@ -474,9 +474,9 @@ public static class StringTranslate { public static UTF8String exec(final UTF8String source, Map dict, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(source, dict); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(source, dict); } else { return execICU(source, dict, collationId); @@ -503,9 +503,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - 
if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -520,9 +520,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrim.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -559,9 +559,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -576,9 +576,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimLeft.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -614,9 +614,9 @@ public static UTF8String exec( final UTF8String trimString, final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return execBinary(srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return execLowercase(srcString, trimString); } else { return execICU(srcString, trimString, collationId); @@ -631,9 +631,9 @@ public static String genCode( final int collationId) { CollationFactory.Collation collation = CollationFactory.fetchCollation(collationId); String expr = "CollationSupport.StringTrimRight.exec"; - if (collation.supportsBinaryEquality) { + if (collation.isUtf8BinaryType) { return String.format(expr + "Binary(%s, %s)", srcString, trimString); - } else if (collation.supportsLowercaseEquality) { + } else if (collation.isUtf8LcaseType) { return String.format(expr + "Lowercase(%s, %s)", srcString, trimString); } else { return String.format(expr + "ICU(%s, %s, %d)", srcString, trimString, collationId); @@ -669,7 +669,7 @@ public static UTF8String execICU( public static boolean supportsLowercaseRegex(final int collationId) { // for regex, only Unicode case-insensitive matching is possible, // so UTF8_LCASE is treated as UNICODE_CI in this context - return CollationFactory.fetchCollation(collationId).supportsLowercaseEquality; + return 
CollationFactory.fetchCollation(collationId).isUtf8LcaseType; } static final int lowercaseRegexFlags = Pattern.UNICODE_CASE | Pattern.CASE_INSENSITIVE; diff --git a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala index a565d2d347636..df9af1579d4f1 100644 --- a/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala +++ b/common/unsafe/src/test/scala/org/apache/spark/unsafe/types/CollationFactorySuite.scala @@ -38,22 +38,22 @@ class CollationFactorySuite extends AnyFunSuite with Matchers { // scalastyle:ig assert(UTF8_BINARY_COLLATION_ID == 0) val utf8Binary = fetchCollation(UTF8_BINARY_COLLATION_ID) assert(utf8Binary.collationName == "UTF8_BINARY") - assert(utf8Binary.supportsBinaryEquality) + assert(utf8Binary.isUtf8BinaryType) assert(UTF8_LCASE_COLLATION_ID == 1) val utf8Lcase = fetchCollation(UTF8_LCASE_COLLATION_ID) assert(utf8Lcase.collationName == "UTF8_LCASE") - assert(!utf8Lcase.supportsBinaryEquality) + assert(!utf8Lcase.isUtf8BinaryType) assert(UNICODE_COLLATION_ID == (1 << 29)) val unicode = fetchCollation(UNICODE_COLLATION_ID) assert(unicode.collationName == "UNICODE") - assert(!unicode.supportsBinaryEquality) + assert(!unicode.isUtf8BinaryType) assert(UNICODE_CI_COLLATION_ID == ((1 << 29) | (1 << 17))) val unicodeCi = fetchCollation(UNICODE_CI_COLLATION_ID) assert(unicodeCi.collationName == "UNICODE_CI") - assert(!unicodeCi.supportsBinaryEquality) + assert(!unicodeCi.isUtf8BinaryType) } test("UTF8_BINARY and ICU root locale collation names") { diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index cb9f54cced57d..21bcee21afd1d 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -376,6 +376,12 @@ ], "sqlState" : "429BB" }, + "CANNOT_REMOVE_RESERVED_PROPERTY" : { + "message" : [ + "Cannot remove reserved property: ." + ], + "sqlState" : "42000" + }, "CANNOT_RENAME_ACROSS_SCHEMA" : { "message" : [ "Renaming a across schemas is not allowed." @@ -606,6 +612,12 @@ ], "sqlState" : "42711" }, + "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH" : { + "message" : [ + "Some values in field are incompatible with the column array type. Expected type ." + ], + "sqlState" : "0A000" + }, "COLUMN_NOT_DEFINED_IN_TABLE" : { "message" : [ " column is not defined in table , defined table columns are: ." @@ -3821,6 +3833,12 @@ ], "sqlState" : "42617" }, + "PARSE_MODE_UNSUPPORTED" : { + "message" : [ + "The function doesn't support the mode. Acceptable modes are PERMISSIVE and FAILFAST." + ], + "sqlState" : "42601" + }, "PARSE_SYNTAX_ERROR" : { "message" : [ "Syntax error at or near ." @@ -6034,11 +6052,6 @@ "DataType '' is not supported by ." ] }, - "_LEGACY_ERROR_TEMP_1099" : { - "message" : [ - "() doesn't support the mode. Acceptable modes are and ." - ] - }, "_LEGACY_ERROR_TEMP_1103" : { "message" : [ "Unsupported component type in arrays." @@ -6853,11 +6866,6 @@ " is not implemented." ] }, - "_LEGACY_ERROR_TEMP_2042" : { - "message" : [ - ". If necessary set to false to bypass this error." - ] - }, "_LEGACY_ERROR_TEMP_2045" : { "message" : [ "Unsupported table change: " @@ -6950,11 +6958,6 @@ "Missing database location." ] }, - "_LEGACY_ERROR_TEMP_2069" : { - "message" : [ - "Cannot remove reserved property: ." 
- ] - }, "_LEGACY_ERROR_TEMP_2070" : { "message" : [ "Writing job failed." @@ -8095,11 +8098,6 @@ "No handler for UDF/UDAF/UDTF '': " ] }, - "_LEGACY_ERROR_TEMP_3085" : { - "message" : [ - "from_avro() doesn't support the mode. Acceptable modes are and ." - ] - }, "_LEGACY_ERROR_TEMP_3086" : { "message" : [ "Cannot persist into Hive metastore as table property keys may not start with 'spark.sql.': " @@ -8705,6 +8703,21 @@ "Doesn't support month or year interval: " ] }, + "_LEGACY_ERROR_TEMP_3300" : { + "message" : [ + "error while calling spill() on : " + ] + }, + "_LEGACY_ERROR_TEMP_3301" : { + "message" : [ + "Not enough memory to grow pointer array" + ] + }, + "_LEGACY_ERROR_TEMP_3302" : { + "message" : [ + "No enough memory for aggregation" + ] + }, "_LEGACY_ERROR_USER_RAISED_EXCEPTION" : { "message" : [ "" diff --git a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala index 0b85b208242cb..9c8b2d0375588 100644 --- a/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala +++ b/connector/avro/src/main/scala/org/apache/spark/sql/avro/AvroDataToCatalyst.scala @@ -24,10 +24,10 @@ import org.apache.avro.generic.GenericDatumReader import org.apache.avro.io.{BinaryDecoder, DecoderFactory} import org.apache.spark.SparkException -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ private[sql] case class AvroDataToCatalyst( @@ -80,12 +80,9 @@ private[sql] case class AvroDataToCatalyst( @transient private lazy val parseMode: ParseMode = { val mode = avroOptions.parseMode if (mode != PermissiveMode && mode != FailFastMode) { - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3085", - messageParameters = Map( - "name" -> mode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + throw QueryCompilationErrors.parseModeUnsupportedError( + prettyName, mode + ) } mode } @@ -123,12 +120,9 @@ private[sql] case class AvroDataToCatalyst( s"Current parse Mode: ${FailFastMode.name}. 
To process malformed records as null " + "result, try setting the option 'mode' as 'PERMISSIVE'.", e) case _ => - throw new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_3085", - messageParameters = Map( - "name" -> parseMode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + throw QueryCompilationErrors.parseModeUnsupportedError( + prettyName, parseMode + ) } } } diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala index a7f7abadcf485..096cdfe0b9ee4 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroFunctionsSuite.scala @@ -106,6 +106,17 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession { functions.from_avro( $"avro", avroTypeStruct, Map("mode" -> "PERMISSIVE").asJava)), expected) + + checkError( + exception = intercept[AnalysisException] { + avroStructDF.select( + functions.from_avro( + $"avro", avroTypeStruct, Map("mode" -> "DROPMALFORMED").asJava)).collect() + }, + condition = "PARSE_MODE_UNSUPPORTED", + parameters = Map( + "funcName" -> "`from_avro`", + "mode" -> "DROPMALFORMED")) } test("roundtrip in to_avro and from_avro - array with null") { diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala index 9d3c7d1eca328..6bd33356cab3d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSQLServerDatabaseOnDocker.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc class MsSQLServerDatabaseOnDocker extends DatabaseOnDocker { override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", - "mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04") + "mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04") override val env = Map( "SA_PASSWORD" -> "Sapass123", "ACCEPT_EULA" -> "Y" diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala index 90cd68e6e1d24..62f088ebc2b6d 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala @@ -31,10 +31,10 @@ import org.apache.spark.sql.types.{BinaryType, DecimalType} import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests * "docker-integration-tests/testOnly org.apache.spark.sql.jdbc.MsSqlServerIntegrationSuite" * }}} diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala 
b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index aaaaa28558342..d884ad4c62466 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -27,10 +27,10 @@ import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MsSqlServerIntegrationSuite" * }}} */ diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala index 9fb3bc4fba945..724c394a4f052 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -26,10 +26,10 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.tags.DockerTest /** - * To run this test suite for a specific version (e.g., 2022-CU14-ubuntu-22.04): + * To run this test suite for a specific version (e.g., 2022-CU15-ubuntu-22.04): * {{{ * ENABLE_DOCKER_INTEGRATION_TESTS=1 - * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU14-ubuntu-22.04 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2022-CU15-ubuntu-22.04 * ./build/sbt -Pdocker-integration-tests "testOnly *v2.MsSqlServerNamespaceSuite" * }}} */ diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 05f02a402353b..f70b500f974a4 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkSQLException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog @@ -65,9 +65,104 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCT |) """.stripMargin ).executeUpdate() - connection.prepareStatement( - "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + + connection.prepareStatement("CREATE TABLE array_test_table (int_array int[]," + + "float_array FLOAT8[], timestamp_array TIMESTAMP[], string_array TEXT[]," + + "datetime_array TIMESTAMPTZ[], array_of_int_arrays INT[][])").executeUpdate() + + val query = + """ + INSERT INTO array_test_table + (int_array, float_array, timestamp_array, string_array, + 
datetime_array, array_of_int_arrays) + VALUES + ( + ARRAY[1, 2, 3], -- Array of integers + ARRAY[1.1, 2.2, 3.3], -- Array of floats + ARRAY['2023-01-01 12:00'::timestamp, '2023-06-01 08:30'::timestamp], + ARRAY['hello', 'world'], -- Array of strings + ARRAY['2023-10-04 12:00:00+00'::timestamptz, + '2023-12-01 14:15:00+00'::timestamptz], + ARRAY[ARRAY[1, 2]] -- Array of arrays of integers + ), + ( + ARRAY[10, 20, 30], -- Another set of data + ARRAY[10.5, 20.5, 30.5], + ARRAY['2022-01-01 09:15'::timestamp, '2022-03-15 07:45'::timestamp], + ARRAY['postgres', 'arrays'], + ARRAY['2022-11-22 09:00:00+00'::timestamptz, + '2022-12-31 23:59:59+00'::timestamptz], + ARRAY[ARRAY[10, 20]] + ); + """ + connection.prepareStatement(query).executeUpdate() + + connection.prepareStatement("CREATE TABLE array_int (col int[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_bigint(col bigint[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_smallint (col smallint[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_boolean (col boolean[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_float (col real[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_double (col float8[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_timestamp (col timestamp[])").executeUpdate() + connection.prepareStatement("CREATE TABLE array_timestamptz (col timestamptz[])") + .executeUpdate() + + connection.prepareStatement("INSERT INTO array_int VALUES (array[array[10]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_bigint VALUES (array[array[10]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_smallint VALUES (array[array[10]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_boolean VALUES (array[array[true]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_float VALUES (array[array[10.5]])") + .executeUpdate() + connection.prepareStatement("INSERT INTO array_double VALUES (array[array[10.1]])") .executeUpdate() + connection.prepareStatement("INSERT INTO array_timestamp VALUES (" + + "array[array['2022-01-01 09:15'::timestamp]])").executeUpdate() + connection.prepareStatement("INSERT INTO array_timestamptz VALUES " + + "(array[array['2022-01-01 09:15'::timestamptz]])").executeUpdate() + connection.prepareStatement( + "CREATE TABLE datetime (name VARCHAR(32), date1 DATE, time1 TIMESTAMP)") + .executeUpdate() + } + + test("Test multi-dimensional column types") { + // This test is used to verify that the multi-dimensional + // column types are supported by the JDBC V2 data source. 
+ // We do not verify any result output + // + val df = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", "array_test_table") + .load() + df.collect() + + val array_tables = Array( + ("array_int", "\"ARRAY\""), + ("array_bigint", "\"ARRAY\""), + ("array_smallint", "\"ARRAY\""), + ("array_boolean", "\"ARRAY\""), + ("array_float", "\"ARRAY\""), + ("array_double", "\"ARRAY\""), + ("array_timestamp", "\"ARRAY\""), + ("array_timestamptz", "\"ARRAY\"") + ) + + array_tables.foreach { case (dbtable, arrayType) => + checkError( + exception = intercept[SparkSQLException] { + val df = spark.read.format("jdbc") + .option("url", jdbcUrl) + .option("dbtable", dbtable) + .load() + df.collect() + }, + condition = "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH", + parameters = Map("pos" -> "0", "type" -> arrayType), + sqlState = Some("0A000") + ) + } } override def dataPreparation(connection: Connection): Unit = { diff --git a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java index fa71eb066ff89..0e35ebecfd270 100644 --- a/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java +++ b/core/src/main/java/org/apache/spark/memory/SparkOutOfMemoryError.java @@ -32,14 +32,6 @@ public final class SparkOutOfMemoryError extends OutOfMemoryError implements Spa String errorClass; Map messageParameters; - public SparkOutOfMemoryError(String s) { - super(s); - } - - public SparkOutOfMemoryError(OutOfMemoryError e) { - super(e.getMessage()); - } - public SparkOutOfMemoryError(String errorClass, Map messageParameters) { super(SparkThrowableHelper.getMessage(errorClass, messageParameters)); this.errorClass = errorClass; diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index df224bc902bff..bd9f58bf7415f 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -21,13 +21,7 @@ import java.io.InterruptedIOException; import java.io.IOException; import java.nio.channels.ClosedByInterruptException; -import java.util.Arrays; -import java.util.ArrayList; -import java.util.BitSet; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.TreeMap; +import java.util.*; import com.google.common.annotations.VisibleForTesting; @@ -291,8 +285,12 @@ private long trySpillAndAcquire( logger.error("error while calling spill() on {}", e, MDC.of(LogKeys.MEMORY_CONSUMER$.MODULE$, consumerToSpill)); // checkstyle.off: RegexpSinglelineJava - throw new SparkOutOfMemoryError("error while calling spill() on " + consumerToSpill + " : " - + e.getMessage()); + throw new SparkOutOfMemoryError( + "_LEGACY_ERROR_TEMP_3300", + new HashMap() {{ + put("consumerToSpill", consumerToSpill.toString()); + put("message", e.getMessage()); + }}); // checkstyle.on: RegexpSinglelineJava } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index 7579c0aefb250..761ced66f78cf 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -18,6 +18,7 @@ package org.apache.spark.util.collection.unsafe.sort; import java.util.Comparator; +import 
java.util.HashMap; import java.util.LinkedList; import javax.annotation.Nullable; @@ -215,7 +216,7 @@ public void expandPointerArray(LongArray newArray) { if (array != null) { if (newArray.size() < array.size()) { // checkstyle.off: RegexpSinglelineJava - throw new SparkOutOfMemoryError("Not enough memory to grow pointer array"); + throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3301", new HashMap()); // checkstyle.on: RegexpSinglelineJava } Platform.copyMemory( diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index 805e7ca467497..fa13092dc47aa 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -21,7 +21,7 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.lang.Thread.UncaughtExceptionHandler import java.net.URL import java.nio.ByteBuffer -import java.util.Properties +import java.util.{HashMap, Properties} import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean @@ -522,7 +522,13 @@ class ExecutorSuite extends SparkFunSuite testThrowable(new OutOfMemoryError(), depthToCheck, isFatal = true) testThrowable(new InterruptedException(), depthToCheck, isFatal = false) testThrowable(new RuntimeException("test"), depthToCheck, isFatal = false) - testThrowable(new SparkOutOfMemoryError("test"), depthToCheck, isFatal = false) + testThrowable( + new SparkOutOfMemoryError( + "_LEGACY_ERROR_USER_RAISED_EXCEPTION", + new HashMap[String, String]() { + put("errorMessage", "test") + }), + depthToCheck, isFatal = false) } // Verify we can handle the cycle in the exception chain diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index 4239180ba6c37..fb2bb83cb7fc4 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -201,10 +201,10 @@ class AsyncRDDActionsSuite extends SparkFunSuite with TimeLimits { test("FutureAction result, timeout") { val f = sc.parallelize(1 to 100, 4) - .mapPartitions(itr => { Thread.sleep(20); itr }) + .mapPartitions(itr => { Thread.sleep(200); itr }) .countAsync() intercept[TimeoutException] { - ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) + ThreadUtils.awaitResult(f, Duration(2, "milliseconds")) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 849832c57edaa..f00fb0d2cfa3f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -101,7 +101,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext with val rdd2 = rdd.barrier().mapPartitions { it => val context = BarrierTaskContext.get() // Sleep for a random time before global sync. - Thread.sleep(Random.nextInt(1000)) + Thread.sleep(Random.nextInt(500)) context.barrier() val time1 = System.currentTimeMillis() // Sleep for a random time before global sync. 
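Illustrative only, not part of the patch: the core changes above drop the plain-message constructors from SparkOutOfMemoryError, so it is now raised with a registered error condition plus a parameter map (as in TaskMemoryManager and ExecutorSuite above). A minimal Scala sketch of that call shape, using only the two-argument constructor and the "_LEGACY_ERROR_USER_RAISED_EXCEPTION" condition that appear in this diff; the wrapper object name is invented for the example:

  import java.util.HashMap
  import org.apache.spark.memory.SparkOutOfMemoryError

  object SparkOomSketch {  // hypothetical object, for illustration only
    def main(args: Array[String]): Unit = {
      // Message parameters expected by the error-class framework.
      val params = new HashMap[String, String]()
      params.put("errorMessage", "test")  // parameter used by _LEGACY_ERROR_USER_RAISED_EXCEPTION in ExecutorSuite above
      // The message text is resolved from error-conditions.json for the given condition.
      val oom = new SparkOutOfMemoryError("_LEGACY_ERROR_USER_RAISED_EXCEPTION", params)
      println(oom.getMessage)
    }
  }

Similarly, for the collation refactor earlier in this patch, the equality/ordering flags are no longer constructor arguments but are derived from the new base-type flags and space trimming (supportsBinaryEquality = !supportsSpaceTrimming && isUtf8BinaryType). A REPL-style sketch under that assumption, using only fields and constants visible in this diff:

  import org.apache.spark.sql.catalyst.util.CollationFactory

  val utf8Binary = CollationFactory.fetchCollation(CollationFactory.UTF8_BINARY_COLLATION_ID)
  // Base-type flag added by this patch; true only for UTF8_BINARY and UTF8_BINARY_RTRIM.
  assert(utf8Binary.isUtf8BinaryType)
  // Derived flag: binary equality additionally requires that no space trimming is applied.
  assert(utf8Binary.supportsBinaryEquality ==
    (!utf8Binary.supportsSpaceTrimming && utf8Binary.isUtf8BinaryType))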
diff --git a/dev/infra/Dockerfile b/dev/infra/Dockerfile index 1619b009e9364..1edeed775880b 100644 --- a/dev/infra/Dockerfile +++ b/dev/infra/Dockerfile @@ -24,7 +24,7 @@ LABEL org.opencontainers.image.ref.name="Apache Spark Infra Image" # Overwrite this label to avoid exposing the underlying Ubuntu OS version label LABEL org.opencontainers.image.version="" -ENV FULL_REFRESH_DATE 20241002 +ENV FULL_REFRESH_DATE 20241007 ENV DEBIAN_FRONTEND noninteractive ENV DEBCONF_NONINTERACTIVE_SEEN true @@ -91,10 +91,10 @@ RUN mkdir -p /usr/local/pypy/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3.9 && \ ln -sf /usr/local/pypy/pypy3.9/bin/pypy /usr/local/bin/pypy3 RUN curl -sS https://bootstrap.pypa.io/get-pip.py | pypy3 -RUN pypy3 -m pip install 'numpy==1.26.4' 'six==1.16.0' 'pandas==2.2.3' scipy coverage matplotlib lxml +RUN pypy3 -m pip install numpy 'six==1.16.0' 'pandas==2.2.3' scipy coverage matplotlib lxml -ARG BASIC_PIP_PKGS="numpy==1.26.4 pyarrow>=15.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" +ARG BASIC_PIP_PKGS="numpy pyarrow>=15.0.0 six==1.16.0 pandas==2.2.3 scipy plotly>=4.8 mlflow>=2.8.1 coverage matplotlib openpyxl memory-profiler>=0.61.0 scikit-learn>=1.3.2" # Python deps for Spark Connect ARG CONNECT_PIP_PKGS="grpcio==1.62.0 grpcio-status==1.62.0 protobuf==4.25.1 googleapis-common-protos==1.56.4 graphviz==0.20.3" @@ -152,6 +152,6 @@ RUN python3.13 -m pip install lxml numpy>=2.1 && \ python3.13 -m pip cache purge # Remove unused installation packages to free up disk space -RUN apt-get remove --purge -y 'gfortran-11' 'humanity-icon-theme' 'nodejs-doc' || true +RUN apt-get remove --purge -y 'humanity-icon-theme' 'nodejs-doc' RUN apt-get autoremove --purge -y RUN apt-get clean diff --git a/pom.xml b/pom.xml index 2b89454873782..cab7f7f595434 100644 --- a/pom.xml +++ b/pom.xml @@ -1115,11 +1115,6 @@ jersey-client ${jersey.version} - - javax.ws.rs - javax.ws.rs-api - 2.0.1 - javax.xml.bind jaxb-api diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 737efa8f7846b..a87e0af0b542f 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1072,20 +1072,9 @@ object DependencyOverrides { object ExcludedDependencies { lazy val settings = Seq( libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") }, - // SPARK-33705: Due to sbt compiler issues, it brings exclusions defined in maven pom back to - // the classpath directly and assemble test scope artifacts to assembly/target/scala-xx/jars, - // which is also will be added to the classpath of some unit tests that will build a subprocess - // to run `spark-submit`, e.g. HiveThriftServer2Test. - // - // These artifacts are for the jersey-1 API but Spark use jersey-2 ones, so it cause test - // flakiness w/ jar conflicts issues. - // - // Also jersey-1 is only used by yarn module(see resource-managers/yarn/pom.xml) for testing - // purpose only. Here we exclude them from the whole project scope and add them w/ yarn only. 
excludeDependencies ++= Seq( - ExclusionRule(organization = "com.sun.jersey"), ExclusionRule(organization = "ch.qos.logback"), - ExclusionRule("javax.ws.rs", "jsr311-api")) + ExclusionRule("javax.servlet", "javax.servlet-api")) ) } @@ -1229,10 +1218,6 @@ object YARN { val hadoopProvidedProp = "spark.yarn.isHadoopProvided" lazy val settings = Seq( - excludeDependencies --= Seq( - ExclusionRule(organization = "com.sun.jersey"), - ExclusionRule("javax.servlet", "javax.servlet-api"), - ExclusionRule("javax.ws.rs", "jsr311-api")), Compile / unmanagedResources := (Compile / unmanagedResources).value.filter(!_.getName.endsWith(s"$propFileName")), genConfigProperties := { diff --git a/project/plugins.sbt b/project/plugins.sbt index 67d739452d8da..b2d0177e6a411 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -24,7 +24,7 @@ libraryDependencies += "com.puppycrawl.tools" % "checkstyle" % "10.17.0" // checkstyle uses guava 33.1.0-jre. libraryDependencies += "com.google.guava" % "guava" % "33.1.0-jre" -addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.2.0") +addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "2.3.0") addSbtPlugin("com.github.sbt" % "sbt-eclipse" % "6.2.0") diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 937753b50bb13..b89755d9c18a5 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -699,7 +699,7 @@ class LinearSVC( >>> model_path = temp_path + "/svm_model" >>> model.save(model_path) >>> model2 = LinearSVCModel.load(model_path) - >>> model.coefficients[0] == model2.coefficients[0] + >>> bool(model.coefficients[0] == model2.coefficients[0]) True >>> model.intercept == model2.intercept True @@ -1210,7 +1210,7 @@ class LogisticRegression( >>> model_path = temp_path + "/lr_model" >>> blorModel.save(model_path) >>> model2 = LogisticRegressionModel.load(model_path) - >>> blorModel.coefficients[0] == model2.coefficients[0] + >>> bool(blorModel.coefficients[0] == model2.coefficients[0]) True >>> blorModel.intercept == model2.intercept True @@ -2038,9 +2038,9 @@ class RandomForestClassifier( >>> result = model.transform(test0).head() >>> result.prediction 0.0 - >>> numpy.argmax(result.probability) + >>> int(numpy.argmax(result.probability)) 0 - >>> numpy.argmax(result.newRawPrediction) + >>> int(numpy.argmax(result.newRawPrediction)) 0 >>> result.leafId DenseVector([0.0, 0.0, 0.0]) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index d08e241b41d23..d7cc27e274279 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -266,7 +266,7 @@ class LinearRegression( True >>> abs(model.transform(test0).head().newPrediction - (-1.0)) < 0.001 True - >>> abs(model.coefficients[0] - 1.0) < 0.001 + >>> bool(abs(model.coefficients[0] - 1.0) < 0.001) True >>> abs(model.intercept - 0.0) < 0.001 True @@ -283,11 +283,11 @@ class LinearRegression( >>> model_path = temp_path + "/lr_model" >>> model.save(model_path) >>> model2 = LinearRegressionModel.load(model_path) - >>> model.coefficients[0] == model2.coefficients[0] + >>> bool(model.coefficients[0] == model2.coefficients[0]) True - >>> model.intercept == model2.intercept + >>> bool(model.intercept == model2.intercept) True - >>> model.transform(test0).take(1) == model2.transform(test0).take(1) + >>> bool(model.transform(test0).take(1) == model2.transform(test0).take(1)) True >>> model.numFeatures 1 @@ -2542,7 +2542,7 @@ class GeneralizedLinearRegression( >>> model2 = 
GeneralizedLinearRegressionModel.load(model_path) >>> model.intercept == model2.intercept True - >>> model.coefficients[0] == model2.coefficients[0] + >>> bool(model.coefficients[0] == model2.coefficients[0]) True >>> model.transform(df).take(1) == model2.transform(df).take(1) True diff --git a/python/pyspark/ml/tests/test_functions.py b/python/pyspark/ml/tests/test_functions.py index 7df0a26394140..e67e46ded67bd 100644 --- a/python/pyspark/ml/tests/test_functions.py +++ b/python/pyspark/ml/tests/test_functions.py @@ -18,6 +18,7 @@ import numpy as np +from pyspark.loose_version import LooseVersion from pyspark.ml.functions import predict_batch_udf from pyspark.sql.functions import array, struct, col from pyspark.sql.types import ArrayType, DoubleType, IntegerType, StructType, StructField, FloatType @@ -193,6 +194,10 @@ def predict(inputs): batch_sizes = preds["preds"].to_numpy() self.assertTrue(all(batch_sizes <= batch_size)) + # TODO(SPARK-49793): enable the test below + @unittest.skipIf( + LooseVersion(np.__version__) >= LooseVersion("2"), "Caching does not work with numpy 2" + ) def test_caching(self): def make_predict_fn(): # emulate loading a model, this should only be invoked once (per worker process) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index e8713d81c4d62..888beff663523 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -706,7 +706,7 @@ class CrossValidator( >>> cvModel = cv.fit(dataset) >>> cvModel.getNumFolds() 3 - >>> cvModel.avgMetrics[0] + >>> float(cvModel.avgMetrics[0]) 0.5 >>> path = tempfile.mkdtemp() >>> model_path = path + "/model" diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 1e1795d9fb3d4..bf8fd04dc2837 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -172,9 +172,9 @@ class LogisticRegressionModel(LinearClassificationModel): >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) >>> sameModel = LogisticRegressionModel.load(sc, path) - >>> sameModel.predict(numpy.array([0.0, 1.0])) + >>> int(sameModel.predict(numpy.array([0.0, 1.0]))) 1 - >>> sameModel.predict(SparseVector(2, {0: 1.0})) + >>> int(sameModel.predict(SparseVector(2, {0: 1.0}))) 0 >>> from shutil import rmtree >>> try: @@ -555,7 +555,7 @@ class SVMModel(LinearClassificationModel): >>> svm.predict(sc.parallelize([[1.0]])).collect() [1] >>> svm.clearThreshold() - >>> svm.predict(numpy.array([1.0])) + >>> float(svm.predict(numpy.array([1.0]))) 1.44... >>> sparse_data = [ @@ -573,9 +573,9 @@ class SVMModel(LinearClassificationModel): >>> path = tempfile.mkdtemp() >>> svm.save(sc, path) >>> sameModel = SVMModel.load(sc, path) - >>> sameModel.predict(SparseVector(2, {1: 1.0})) + >>> int(sameModel.predict(SparseVector(2, {1: 1.0}))) 1 - >>> sameModel.predict(SparseVector(2, {0: -1.0})) + >>> int(sameModel.predict(SparseVector(2, {0: -1.0}))) 0 >>> from shutil import rmtree >>> try: @@ -756,11 +756,11 @@ class NaiveBayesModel(Saveable, Loader["NaiveBayesModel"]): ... LabeledPoint(1.0, [1.0, 0.0]), ... ] >>> model = NaiveBayes.train(sc.parallelize(data)) - >>> model.predict(numpy.array([0.0, 1.0])) + >>> float(model.predict(numpy.array([0.0, 1.0]))) 0.0 - >>> model.predict(numpy.array([1.0, 0.0])) + >>> float(model.predict(numpy.array([1.0, 0.0]))) 1.0 - >>> model.predict(sc.parallelize([[1.0, 0.0]])).collect() + >>> list(map(float, model.predict(sc.parallelize([[1.0, 0.0]])).collect())) [1.0] >>> sparse_data = [ ... 
LabeledPoint(0.0, SparseVector(2, {1: 0.0})), @@ -768,15 +768,18 @@ class NaiveBayesModel(Saveable, Loader["NaiveBayesModel"]): ... LabeledPoint(1.0, SparseVector(2, {0: 1.0})) ... ] >>> model = NaiveBayes.train(sc.parallelize(sparse_data)) - >>> model.predict(SparseVector(2, {1: 1.0})) + >>> float(model.predict(SparseVector(2, {1: 1.0}))) 0.0 - >>> model.predict(SparseVector(2, {0: 1.0})) + >>> float(model.predict(SparseVector(2, {0: 1.0}))) 1.0 >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> model.save(sc, path) >>> sameModel = NaiveBayesModel.load(sc, path) - >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0})) + >>> bool(( + ... sameModel.predict(SparseVector(2, {0: 1.0})) == + ... model.predict(SparseVector(2, {0: 1.0})) + ... )) True >>> from shutil import rmtree >>> try: diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 24884f4853371..915a55595cb53 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -554,9 +554,9 @@ class PCA: ... Vectors.dense([4.0, 0.0, 0.0, 6.0, 7.0])] >>> model = PCA(2).fit(sc.parallelize(data)) >>> pcArray = model.transform(Vectors.sparse(5, [(1, 1.0), (3, 7.0)])).toArray() - >>> pcArray[0] + >>> float(pcArray[0]) 1.648... - >>> pcArray[1] + >>> float(pcArray[1]) -4.013... """ diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/random.py index 80bbd717071dc..dbe1048a64b36 100644 --- a/python/pyspark/mllib/random.py +++ b/python/pyspark/mllib/random.py @@ -134,9 +134,9 @@ def normalRDD( >>> stats = x.stats() >>> stats.count() 1000 - >>> abs(stats.mean() - 0.0) < 0.1 + >>> bool(abs(stats.mean() - 0.0) < 0.1) True - >>> abs(stats.stdev() - 1.0) < 0.1 + >>> bool(abs(stats.stdev() - 1.0) < 0.1) True """ return callMLlibFunc("normalRDD", sc._jsc, size, numPartitions, seed) @@ -186,10 +186,10 @@ def logNormalRDD( >>> stats = x.stats() >>> stats.count() 1000 - >>> abs(stats.mean() - expMean) < 0.5 + >>> bool(abs(stats.mean() - expMean) < 0.5) True >>> from math import sqrt - >>> abs(stats.stdev() - expStd) < 0.5 + >>> bool(abs(stats.stdev() - expStd) < 0.5) True """ return callMLlibFunc( @@ -238,7 +238,7 @@ def poissonRDD( >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + >>> bool(abs(stats.stdev() - sqrt(mean)) < 0.5) True """ return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed) @@ -285,7 +285,7 @@ def exponentialRDD( >>> abs(stats.mean() - mean) < 0.5 True >>> from math import sqrt - >>> abs(stats.stdev() - sqrt(mean)) < 0.5 + >>> bool(abs(stats.stdev() - sqrt(mean)) < 0.5) True """ return callMLlibFunc("exponentialRDD", sc._jsc, float(mean), size, numPartitions, seed) @@ -336,9 +336,9 @@ def gammaRDD( >>> stats = x.stats() >>> stats.count() 1000 - >>> abs(stats.mean() - expMean) < 0.5 + >>> bool(abs(stats.mean() - expMean) < 0.5) True - >>> abs(stats.stdev() - expStd) < 0.5 + >>> bool(abs(stats.stdev() - expStd) < 0.5) True """ return callMLlibFunc( @@ -384,7 +384,7 @@ def uniformVectorRDD( >>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect()) >>> mat.shape (10, 10) - >>> mat.max() <= 1.0 and mat.min() >= 0.0 + >>> bool(mat.max() <= 1.0 and mat.min() >= 0.0) True >>> RandomRDDs.uniformVectorRDD(sc, 10, 10, 4).getNumPartitions() 4 @@ -430,9 +430,9 @@ def normalVectorRDD( >>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) - >>> abs(mat.mean() - 0.0) < 0.1 + >>> 
bool(abs(mat.mean() - 0.0) < 0.1) True - >>> abs(mat.std() - 1.0) < 0.1 + >>> bool(abs(mat.std() - 1.0) < 0.1) True """ return callMLlibFunc("normalVectorRDD", sc._jsc, numRows, numCols, numPartitions, seed) @@ -488,9 +488,9 @@ def logNormalVectorRDD( >>> mat = np.matrix(m) >>> mat.shape (100, 100) - >>> abs(mat.mean() - expMean) < 0.1 + >>> bool(abs(mat.mean() - expMean) < 0.1) True - >>> abs(mat.std() - expStd) < 0.1 + >>> bool(abs(mat.std() - expStd) < 0.1) True """ return callMLlibFunc( @@ -545,13 +545,13 @@ def poissonVectorRDD( >>> import numpy as np >>> mean = 100.0 >>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1) - >>> mat = np.mat(rdd.collect()) + >>> mat = np.asmatrix(rdd.collect()) >>> mat.shape (100, 100) - >>> abs(mat.mean() - mean) < 0.5 + >>> bool(abs(mat.mean() - mean) < 0.5) True >>> from math import sqrt - >>> abs(mat.std() - sqrt(mean)) < 0.5 + >>> bool(abs(mat.std() - sqrt(mean)) < 0.5) True """ return callMLlibFunc( @@ -599,13 +599,13 @@ def exponentialVectorRDD( >>> import numpy as np >>> mean = 0.5 >>> rdd = RandomRDDs.exponentialVectorRDD(sc, mean, 100, 100, seed=1) - >>> mat = np.mat(rdd.collect()) + >>> mat = np.asmatrix(rdd.collect()) >>> mat.shape (100, 100) - >>> abs(mat.mean() - mean) < 0.5 + >>> bool(abs(mat.mean() - mean) < 0.5) True >>> from math import sqrt - >>> abs(mat.std() - sqrt(mean)) < 0.5 + >>> bool(abs(mat.std() - sqrt(mean)) < 0.5) True """ return callMLlibFunc( @@ -662,9 +662,9 @@ def gammaVectorRDD( >>> mat = np.matrix(RandomRDDs.gammaVectorRDD(sc, shape, scale, 100, 100, seed=1).collect()) >>> mat.shape (100, 100) - >>> abs(mat.mean() - expMean) < 0.1 + >>> bool(abs(mat.mean() - expMean) < 0.1) True - >>> abs(mat.std() - expStd) < 0.1 + >>> bool(abs(mat.std() - expStd) < 0.1) True """ return callMLlibFunc( diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index f1003327912d0..87f05bc0979b8 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -144,9 +144,9 @@ class LinearRegressionModelBase(LinearModel): -------- >>> from pyspark.mllib.linalg import SparseVector >>> lrmb = LinearRegressionModelBase(np.array([1.0, 2.0]), 0.1) - >>> abs(lrmb.predict(np.array([-1.03, 7.777])) - 14.624) < 1e-6 + >>> bool(abs(lrmb.predict(np.array([-1.03, 7.777])) - 14.624) < 1e-6) True - >>> abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6 + >>> bool(abs(lrmb.predict(SparseVector(2, {0: -1.03, 1: 7.777})) - 14.624) < 1e-6) True """ @@ -190,23 +190,23 @@ class LinearRegressionModel(LinearRegressionModelBase): ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... 
initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(lrm.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True - >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + >>> bool(abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5) True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) >>> sameModel = LinearRegressionModel.load(sc, path) - >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(sameModel.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(sameModel.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> from shutil import rmtree >>> try: @@ -221,16 +221,16 @@ class LinearRegressionModel(LinearRegressionModelBase): ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... miniBatchFraction=1.0, initialWeights=np.array([1.0]), regParam=0.1, regType="l2", ... intercept=True, validateData=True) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True """ @@ -402,23 +402,23 @@ class LassoModel(LinearRegressionModelBase): ... ] >>> lrm = LassoWithSGD.train( ... sc.parallelize(data), iterations=10, initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(lrm.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True - >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + >>> bool(abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5) True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) >>> sameModel = LassoModel.load(sc, path) - >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(sameModel.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(sameModel.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> from shutil import rmtree >>> try: @@ -433,16 +433,16 @@ class LassoModel(LinearRegressionModelBase): ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... 
initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> lrm = LassoWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... regParam=0.01, miniBatchFraction=1.0, initialWeights=np.array([1.0]), intercept=True, ... validateData=True) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True """ @@ -580,23 +580,23 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... ] >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(lrm.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True - >>> abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5 + >>> bool(abs(lrm.predict(sc.parallelize([[1.0]])).collect()[0] - 1) < 0.5) True >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> lrm.save(sc, path) >>> sameModel = RidgeRegressionModel.load(sc, path) - >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(sameModel.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + >>> bool(abs(sameModel.predict(np.array([1.0])) - 1) < 0.5) True - >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> from shutil import rmtree >>> try: @@ -611,16 +611,16 @@ class RidgeRegressionModel(LinearRegressionModelBase): ... ] >>> lrm = LinearRegressionWithSGD.train(sc.parallelize(data), iterations=10, ... initialWeights=np.array([1.0])) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True >>> lrm = RidgeRegressionWithSGD.train(sc.parallelize(data), iterations=10, step=1.0, ... regParam=0.01, miniBatchFraction=1.0, initialWeights=np.array([1.0]), intercept=True, ... 
validateData=True) - >>> abs(lrm.predict(np.array([0.0])) - 0) < 0.5 + >>> bool(abs(lrm.predict(np.array([0.0])) - 0) < 0.5) True - >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + >>> bool(abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5) True """ @@ -764,19 +764,19 @@ class IsotonicRegressionModel(Saveable, Loader["IsotonicRegressionModel"]): -------- >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)] >>> irm = IsotonicRegression.train(sc.parallelize(data)) - >>> irm.predict(3) + >>> float(irm.predict(3)) 2.0 - >>> irm.predict(5) + >>> float(irm.predict(5)) 16.5 - >>> irm.predict(sc.parallelize([3, 5])).collect() + >>> list(map(float, irm.predict(sc.parallelize([3, 5])).collect())) [2.0, 16.5] >>> import os, tempfile >>> path = tempfile.mkdtemp() >>> irm.save(sc, path) >>> sameModel = IsotonicRegressionModel.load(sc, path) - >>> sameModel.predict(3) + >>> float(sameModel.predict(3)) 2.0 - >>> sameModel.predict(5) + >>> float(sameModel.predict(5)) 16.5 >>> from shutil import rmtree >>> try: diff --git a/python/pyspark/pandas/generic.py b/python/pyspark/pandas/generic.py index 6e63cff1d37b9..55f15fd2eb1a2 100644 --- a/python/pyspark/pandas/generic.py +++ b/python/pyspark/pandas/generic.py @@ -2631,7 +2631,7 @@ def first_valid_index(self) -> Optional[Union[Scalar, Tuple[Scalar, ...]]]: 500 5.0 dtype: float64 - >>> s.first_valid_index() + >>> int(s.first_valid_index()) 300 Support for MultiIndex @@ -2950,7 +2950,7 @@ def get(self, key: Any, default: Optional[Any] = None) -> Any: 20 1 b 20 2 b - >>> df.x.get(10) + >>> int(df.x.get(10)) 0 >>> df.x.get(20) @@ -3008,7 +3008,7 @@ def squeeze(self, axis: Optional[Axis] = None) -> Union[Scalar, "DataFrame", "Se 0 2 dtype: int64 - >>> even_primes.squeeze() + >>> int(even_primes.squeeze()) 2 Squeezing objects with more than one value in every axis does nothing: @@ -3066,7 +3066,7 @@ def squeeze(self, axis: Optional[Axis] = None) -> Union[Scalar, "DataFrame", "Se Squeezing all axes will project directly into a scalar: - >>> df_1a.squeeze() + >>> int(df_1a.squeeze()) 3 """ if axis is not None: diff --git a/python/pyspark/pandas/indexing.py b/python/pyspark/pandas/indexing.py index b5bf65a4907b7..c93366a31e315 100644 --- a/python/pyspark/pandas/indexing.py +++ b/python/pyspark/pandas/indexing.py @@ -122,7 +122,7 @@ class AtIndexer(IndexerLike): Get value at specified row/column pair - >>> psdf.at[4, 'B'] + >>> int(psdf.at[4, 'B']) 2 Get array if an index occurs multiple times @@ -202,7 +202,7 @@ class iAtIndexer(IndexerLike): Get value at specified row/column pair - >>> df.iat[1, 2] + >>> int(df.iat[1, 2]) 1 Get value within a series @@ -214,7 +214,7 @@ class iAtIndexer(IndexerLike): 30 3 dtype: int64 - >>> psser.iat[1] + >>> int(psser.iat[1]) 2 """ @@ -853,7 +853,7 @@ class LocIndexer(LocIndexerLike): Single label for column. - >>> df.loc['cobra', 'shield'] + >>> int(df.loc['cobra', 'shield']) 2 List of labels for row. diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py index 6063641e22e3b..90c361547b814 100644 --- a/python/pyspark/pandas/internal.py +++ b/python/pyspark/pandas/internal.py @@ -909,14 +909,7 @@ def attach_sequence_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDa @staticmethod def attach_distributed_column(sdf: PySparkDataFrame, column_name: str) -> PySparkDataFrame: - scols = [scol_for(sdf, column) for column in sdf.columns] - # Does not add an alias to avoid having some changes in protobuf definition for now. 
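The `RandomRDDs` doctests above also swap `np.mat` for `np.asmatrix`; a tiny sketch of the replacement, assuming `np.mat` is the alias dropped from NumPy 2's main namespace:

```python
import numpy as np

mat = np.asmatrix([[1.0, 2.0], [3.0, 4.0]])   # drop-in replacement for np.mat
print(mat.shape)                              # (2, 2)
print(float(mat.mean()))                      # 2.5 -- float() keeps the doctest stable
```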
- # The alias is more for query strings in DataFrame.explain, and they are cosmetic changes. - if is_remote(): - return sdf.select(F.monotonically_increasing_id().alias(column_name), *scols) - jvm = sdf.sparkSession._jvm - jcol = jvm.PythonSQLUtils.distributedIndex() - return sdf.select(PySparkColumn(jcol).alias(column_name), *scols) + return sdf.select(SF.distributed_id().alias(column_name), "*") @staticmethod def attach_distributed_sequence_column( diff --git a/python/pyspark/pandas/plot/core.py b/python/pyspark/pandas/plot/core.py index 7333fae1ad432..12c17a06f153b 100644 --- a/python/pyspark/pandas/plot/core.py +++ b/python/pyspark/pandas/plot/core.py @@ -841,7 +841,7 @@ def barh(self, x=None, y=None, **kwargs): elif isinstance(self.data, DataFrame): return self(kind="barh", x=x, y=y, **kwargs) - def box(self, **kwds): + def box(self, precision=0.01, **kwds): """ Make a box plot of the DataFrame columns. @@ -857,14 +857,13 @@ def box(self, **kwds): Parameters ---------- - **kwds : optional - Additional keyword arguments are documented in - :meth:`pyspark.pandas.Series.plot`. - precision: scalar, default = 0.01 This argument is used by pandas-on-Spark to compute approximate statistics for building a boxplot. Use *smaller* values to get more precise - statistics (matplotlib-only). + statistics. + **kwds : optional + Additional keyword arguments are documented in + :meth:`pyspark.pandas.Series.plot`. Returns ------- @@ -902,7 +901,7 @@ def box(self, **kwds): from pyspark.pandas import DataFrame, Series if isinstance(self.data, (Series, DataFrame)): - return self(kind="box", **kwds) + return self(kind="box", precision=precision, **kwds) def hist(self, bins=10, **kwds): """ diff --git a/python/pyspark/pandas/series.py b/python/pyspark/pandas/series.py index ff941b692f95f..7e276860fbab1 100644 --- a/python/pyspark/pandas/series.py +++ b/python/pyspark/pandas/series.py @@ -4558,7 +4558,7 @@ def pop(self, item: Name) -> Union["Series", Scalar]: C 2 dtype: int64 - >>> s.pop('A') + >>> int(s.pop('A')) 0 >>> s @@ -5821,7 +5821,7 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: A scalar `where`. - >>> s.asof(20) + >>> float(s.asof(20)) 2.0 For a sequence `where`, a Series is returned. The first value is @@ -5836,12 +5836,12 @@ def asof(self, where: Union[Any, List]) -> Union[Scalar, "Series"]: Missing values are not considered. The following is ``2.0``, not NaN, even though NaN is at the index location for ``30``. - >>> s.asof(30) + >>> float(s.asof(30)) 2.0 >>> s = ps.Series([1, 2, np.nan, 4], index=[10, 30, 20, 40]) >>> with ps.option_context("compute.eager_check", False): - ... s.asof(20) + ... float(s.asof(20)) ... 
1.0 """ diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py index bdd11559df3b6..53146a163b1ef 100644 --- a/python/pyspark/pandas/spark/functions.py +++ b/python/pyspark/pandas/spark/functions.py @@ -79,6 +79,10 @@ def null_index(col: Column) -> Column: return _invoke_internal_function_over_columns("null_index", col) +def distributed_id() -> Column: + return _invoke_internal_function_over_columns("distributed_id") + + def distributed_sequence_id() -> Column: return _invoke_internal_function_over_columns("distributed_sequence_id") diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py index a51156e895c62..f6c1278c0dc7a 100644 --- a/python/pyspark/sql/tests/test_functions.py +++ b/python/pyspark/sql/tests/test_functions.py @@ -83,7 +83,9 @@ def test_function_parity(self): missing_in_py = jvm_fn_set.difference(py_fn_set) # Functions that we expect to be missing in python until they are added to pyspark - expected_missing_in_py = set() + expected_missing_in_py = set( + ["is_valid_utf8", "make_valid_utf8", "validate_utf8", "try_validate_utf8"] + ) self.assertEqual( expected_missing_in_py, missing_in_py, "Missing functions in pyspark not as expected" diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index 770a550030f51..5a10aa797c1b1 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -29,43 +29,8 @@ Spark Project YARN yarn - 1.19 - - - hadoop-3 - - true - - - - org.apache.hadoop - hadoop-client-runtime - ${hadoop.version} - ${hadoop.deps.scope} - - - org.apache.hadoop - hadoop-client-minicluster - ${hadoop.version} - test - - - - org.bouncycastle - bcprov-jdk18on - test - - - org.bouncycastle - bcpkix-jdk18on - test - - - - - org.apache.spark @@ -102,6 +67,35 @@ org.apache.hadoop hadoop-client-api ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-client-runtime + ${hadoop.version} + ${hadoop.deps.scope} + + + org.apache.hadoop + hadoop-client-minicluster + ${hadoop.version} + test + + + + javax.xml.bind + jaxb-api + test + + + org.bouncycastle + bcprov-jdk18on + test + + + org.bouncycastle + bcpkix-jdk18on + test @@ -135,22 +129,6 @@ - - - org.eclipse.jetty.orbit - javax.servlet.jsp - 2.2.0.v201112011158 - test - - - org.eclipse.jetty.orbit - javax.servlet.jsp.jstl - 1.2.0.v201105211821 - test - - org.mockito mockito-core @@ -166,65 +144,6 @@ byte-buddy-agent test - - - - com.sun.jersey - jersey-core - test - ${jersey-1.version} - - - com.sun.jersey - jersey-json - test - ${jersey-1.version} - - - com.sun.jersey - jersey-server - test - ${jersey-1.version} - - - com.sun.jersey.contribs - jersey-guice - test - ${jersey-1.version} - - - com.sun.jersey - jersey-servlet - test - ${jersey-1.version} - - - - - ${hive.group} - hive-exec - ${hive.classifier} - provided - - - ${hive.group} - hive-metastore - provided - - - org.apache.thrift - libthrift - provided - - - org.apache.thrift - libfb303 - provided - diff --git a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala index 4838bc5298bb3..4a9a20efd3a56 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/functions.scala @@ -3911,6 +3911,44 @@ object functions { def encode(value: Column, charset: String): Column = Column.fn("encode", value, lit(charset)) + /** + * Returns true if the input is a valid UTF-8 string, otherwise returns false. 
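The four UTF-8 Column functions whose Scaladoc begins here (the other three follow just below) expose JVM-side SQL functions that the parity test above now lists as expected to be missing from PySpark. A plain-Python sketch of the semantics the docs describe, not Spark's implementation:

```python
def is_valid_utf8_like(b: bytes) -> bool:
    # is_valid_utf8: true iff the bytes decode cleanly as UTF-8.
    try:
        b.decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False

def make_valid_utf8_like(b: bytes) -> str:
    # make_valid_utf8: invalid sequences become U+FFFD; validate_utf8 would raise
    # instead, and try_validate_utf8 would yield NULL.
    return b.decode("utf-8", errors="replace")

good, bad = "Spark".encode("utf-8"), b"Spark\xff"          # 0xFF never occurs in valid UTF-8
print(is_valid_utf8_like(good), is_valid_utf8_like(bad))   # True False
print(make_valid_utf8_like(bad))                           # Spark\ufffd
```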
+ * + * @group string_funcs + * @since 4.0.0 + */ + def is_valid_utf8(str: Column): Column = + Column.fn("is_valid_utf8", str) + + /** + * Returns a new string in which all invalid UTF-8 byte sequences, if any, are replaced by the + * Unicode replacement character (U+FFFD). + * + * @group string_funcs + * @since 4.0.0 + */ + def make_valid_utf8(str: Column): Column = + Column.fn("make_valid_utf8", str) + + /** + * Returns the input value if it corresponds to a valid UTF-8 string, or emits a + * SparkIllegalArgumentException exception otherwise. + * + * @group string_funcs + * @since 4.0.0 + */ + def validate_utf8(str: Column): Column = + Column.fn("validate_utf8", str) + + /** + * Returns the input value if it corresponds to a valid UTF-8 string, or NULL otherwise. + * + * @group string_funcs + * @since 4.0.0 + */ + def try_validate_utf8(str: Column): Column = + Column.fn("try_validate_utf8", str) + /** * Formats numeric column x to a format like '#,###,###.##', rounded to d decimal places with * HALF_EVEN round mode, and returns the result as a string column. diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index 721e6a60befe2..12a2879794b10 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -58,7 +58,7 @@ public int numElements() { private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) { if (data.hasNull()) { for (int i = 0; i < length; i++) { - if (data.isNullAt(i)) { + if (data.isNullAt(offset + i)) { arrayData.setNullAt(i); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5d41c07b47842..49f3092390536 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -39,7 +39,6 @@ import org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.AlwaysProcess import org.apache.spark.sql.catalyst.trees.CurrentOrigin.withOrigin import org.apache.spark.sql.catalyst.trees.TreePattern._ @@ -203,6 +202,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor with CheckAnalysis with SQLConfHelper with ColumnResolutionHelper { private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog + private val relationResolution = new RelationResolution(catalogManager) override protected def validatePlanChanges( previousPlan: LogicalPlan, @@ -972,30 +972,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } - private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty - private def isReferredTempViewName(nameParts: Seq[String]): Boolean = { - AnalysisContext.get.referredTempViewNames.exists { n => - (n.length == nameParts.length) && n.zip(nameParts).forall { - case (a, b) => resolver(a, b) - } - } - } - - // If we are resolving database objects (relations, functions, etc.) 
insides views, we may need to - // expand single or multi-part identifiers with the current catalog and namespace of when the - // view was created. - private def expandIdentifier(nameParts: Seq[String]): Seq[String] = { - if (!isResolvingView || isReferredTempViewName(nameParts)) return nameParts - - if (nameParts.length == 1) { - AnalysisContext.get.catalogAndNamespace :+ nameParts.head - } else if (catalogManager.isCatalogRegistered(nameParts.head)) { - nameParts - } else { - AnalysisContext.get.catalogAndNamespace.head +: nameParts - } - } - /** * Adds metadata columns to output for child relations when nodes are missing resolved attributes. * @@ -1122,7 +1098,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case i @ InsertIntoStatement(table, _, _, _, _, _, _) => val relation = table match { case u: UnresolvedRelation if !u.isStreaming => - resolveRelation(u).getOrElse(u) + relationResolution.resolveRelation(u).getOrElse(u) case other => other } @@ -1139,7 +1115,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor case write: V2WriteCommand => write.table match { case u: UnresolvedRelation if !u.isStreaming => - resolveRelation(u).map(unwrapRelationPlan).map { + relationResolution.resolveRelation(u).map(unwrapRelationPlan).map { case v: View => throw QueryCompilationErrors.writeIntoViewNotAllowedError( v.desc.identifier, write) case r: DataSourceV2Relation => write.withNewTable(r) @@ -1154,12 +1130,12 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } case u: UnresolvedRelation => - resolveRelation(u).map(resolveViews).getOrElse(u) + relationResolution.resolveRelation(u).map(resolveViews).getOrElse(u) case r @ RelationTimeTravel(u: UnresolvedRelation, timestamp, version) if timestamp.forall(ts => ts.resolved && !SubqueryExpression.hasSubquery(ts)) => val timeTravelSpec = TimeTravelSpec.create(timestamp, version, conf.sessionLocalTimeZone) - resolveRelation(u, timeTravelSpec).getOrElse(r) + relationResolution.resolveRelation(u, timeTravelSpec).getOrElse(r) case u @ UnresolvedTable(identifier, cmd, suggestAlternative) => lookupTableOrView(identifier).map { @@ -1194,29 +1170,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor }.getOrElse(u) } - private def lookupTempView(identifier: Seq[String]): Option[TemporaryViewRelation] = { - // We are resolving a view and this name is not a temp view when that view was created. We - // return None earlier here. - if (isResolvingView && !isReferredTempViewName(identifier)) return None - v1SessionCatalog.getRawLocalOrGlobalTempView(identifier) - } - - private def resolveTempView( - identifier: Seq[String], - isStreaming: Boolean = false, - isTimeTravel: Boolean = false): Option[LogicalPlan] = { - lookupTempView(identifier).map { v => - val tempViewPlan = v1SessionCatalog.getTempViewRelation(v) - if (isStreaming && !tempViewPlan.isStreaming) { - throw QueryCompilationErrors.readNonStreamingTempViewError(identifier.quoted) - } - if (isTimeTravel) { - throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(identifier)) - } - tempViewPlan - } - } - /** * Resolves relations to `ResolvedTable` or `Resolved[Temp/Persistent]View`. This is * for resolving DDL and misc commands. 
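The `ColumnarArray.setNullBits` change above fixes an off-by-offset bug: the array is a view over `[offset, offset + length)` of the backing column vector, so null checks must read `offset + i`, not `i`. A small sketch of the difference (plain lists here, not Spark's API):

```python
backing_nulls = [False, True, False, False, True]   # null bitmap of the whole vector
offset, length = 2, 3                               # the slice this array exposes

fixed = [backing_nulls[offset + i] for i in range(length)]   # [False, False, True]
buggy = [backing_nulls[i] for i in range(length)]            # [False, True, False] -- wrong slice
print(fixed, buggy)
```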
@@ -1224,10 +1177,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private def lookupTableOrView( identifier: Seq[String], viewOnly: Boolean = false): Option[LogicalPlan] = { - lookupTempView(identifier).map { tempView => + relationResolution.lookupTempView(identifier).map { tempView => ResolvedTempView(identifier.asIdentifier, tempView.tableMeta) }.orElse { - expandIdentifier(identifier) match { + relationResolution.expandIdentifier(identifier) match { case CatalogAndIdentifier(catalog, ident) => if (viewOnly && !CatalogV2Util.isSessionCatalog(catalog)) { throw QueryCompilationErrors.catalogOperationNotSupported(catalog, "views") @@ -1246,113 +1199,6 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor } } } - - private def createRelation( - catalog: CatalogPlugin, - ident: Identifier, - table: Option[Table], - options: CaseInsensitiveStringMap, - isStreaming: Boolean): Option[LogicalPlan] = { - table.map { - // To utilize this code path to execute V1 commands, e.g. INSERT, - // either it must be session catalog, or tracksPartitionsInCatalog - // must be false so it does not require use catalog to manage partitions. - // Obviously we cannot execute V1Table by V1 code path if the table - // is not from session catalog and the table still requires its catalog - // to manage partitions. - case v1Table: V1Table if CatalogV2Util.isSessionCatalog(catalog) - || !v1Table.catalogTable.tracksPartitionsInCatalog => - if (isStreaming) { - if (v1Table.v1Table.tableType == CatalogTableType.VIEW) { - throw QueryCompilationErrors.permanentViewNotSupportedByStreamingReadingAPIError( - ident.quoted) - } - SubqueryAlias( - catalog.name +: ident.asMultipartIdentifier, - UnresolvedCatalogRelation(v1Table.v1Table, options, isStreaming = true)) - } else { - v1SessionCatalog.getRelation(v1Table.v1Table, options) - } - - case table => - if (isStreaming) { - val v1Fallback = table match { - case withFallback: V2TableWithV1Fallback => - Some(UnresolvedCatalogRelation(withFallback.v1Table, isStreaming = true)) - case _ => None - } - SubqueryAlias( - catalog.name +: ident.asMultipartIdentifier, - StreamingRelationV2(None, table.name, table, options, table.columns.toAttributes, - Some(catalog), Some(ident), v1Fallback)) - } else { - SubqueryAlias( - catalog.name +: ident.asMultipartIdentifier, - DataSourceV2Relation.create(table, Some(catalog), Some(ident), options)) - } - } - } - - /** - * Resolves relations to v1 relation if it's a v1 table from the session catalog, or to v2 - * relation. This is for resolving DML commands and SELECT queries. 
- */ - private def resolveRelation( - u: UnresolvedRelation, - timeTravelSpec: Option[TimeTravelSpec] = None): Option[LogicalPlan] = { - val timeTravelSpecFromOptions = TimeTravelSpec.fromOptions( - u.options, - conf.getConf(SQLConf.TIME_TRAVEL_TIMESTAMP_KEY), - conf.getConf(SQLConf.TIME_TRAVEL_VERSION_KEY), - conf.sessionLocalTimeZone - ) - if (timeTravelSpec.nonEmpty && timeTravelSpecFromOptions.nonEmpty) { - throw new AnalysisException("MULTIPLE_TIME_TRAVEL_SPEC", Map.empty[String, String]) - } - val finalTimeTravelSpec = timeTravelSpec.orElse(timeTravelSpecFromOptions) - resolveTempView(u.multipartIdentifier, u.isStreaming, finalTimeTravelSpec.isDefined).orElse { - expandIdentifier(u.multipartIdentifier) match { - case CatalogAndIdentifier(catalog, ident) => - val key = - ((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq, - finalTimeTravelSpec) - AnalysisContext.get.relationCache.get(key).map { cache => - val cachedRelation = cache.transform { - case multi: MultiInstanceRelation => - val newRelation = multi.newInstance() - newRelation.copyTagsFrom(multi) - newRelation - } - u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => - val cachedConnectRelation = cachedRelation.clone() - cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) - cachedConnectRelation - }.getOrElse(cachedRelation) - }.orElse { - val writePrivilegesString = - Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) - val table = CatalogV2Util.loadTable( - catalog, ident, finalTimeTravelSpec, writePrivilegesString) - val loaded = createRelation( - catalog, ident, table, u.clearWritePrivileges.options, u.isStreaming) - loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) - u.getTagValue(LogicalPlan.PLAN_ID_TAG).map { planId => - loaded.map { loadedRelation => - val loadedConnectRelation = loadedRelation.clone() - loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) - loadedConnectRelation - } - }.getOrElse(loaded) - } - case _ => None - } - } - } - - /** Consumes an unresolved relation and resolves it to a v1 or v2 relation or temporary view. 
*/ - def resolveRelationOrTempView(u: UnresolvedRelation): LogicalPlan = { - EliminateSubqueryAliases(resolveRelation(u).getOrElse(u)) - } } /** Handle INSERT INTO for DSv2 */ @@ -2135,7 +1981,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor if (ResolveFunctions.lookupBuiltinOrTempFunction(nameParts, Some(f)).isDefined) { f } else { - val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(nameParts) val fullName = normalizeFuncName((catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq) if (externalFunctionNameSet.contains(fullName)) { @@ -2186,7 +2033,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor ResolvedNonPersistentFunc(nameParts.head, V1Function(info)) } }.getOrElse { - val CatalogAndIdentifier(catalog, ident) = expandIdentifier(nameParts) + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(nameParts) val fullName = catalog.name +: ident.namespace :+ ident.name CatalogV2Util.loadFunction(catalog, ident).map { func => ResolvedPersistentFunc(catalog.asFunctionCatalog, ident, func) @@ -2198,7 +2046,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor withPosition(u) { try { val resolvedFunc = resolveBuiltinOrTempTableFunction(u.name, u.functionArgs).getOrElse { - val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.name) + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(u.name) if (CatalogV2Util.isSessionCatalog(catalog)) { v1SessionCatalog.resolvePersistentTableFunction( ident.asFunctionIdentifier, u.functionArgs) @@ -2355,7 +2204,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor private[analysis] def resolveFunction(u: UnresolvedFunction): Expression = { withPosition(u) { resolveBuiltinOrTempFunction(u.nameParts, u.arguments, u).getOrElse { - val CatalogAndIdentifier(catalog, ident) = expandIdentifier(u.nameParts) + val CatalogAndIdentifier(catalog, ident) = + relationResolution.expandIdentifier(u.nameParts) if (CatalogV2Util.isSessionCatalog(catalog)) { resolveV1Function(ident.asFunctionIdentifier, u.arguments, u) } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d03d8114e9976..abe61619a2331 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -895,9 +895,20 @@ object FunctionRegistry { /** Registry for internal functions used by Connect and the Column API. 
*/ private[sql] val internal: SimpleFunctionRegistry = new SimpleFunctionRegistry - private def registerInternalExpression[T <: Expression : ClassTag](name: String): Unit = { - val (info, builder) = FunctionRegistryBase.build(name, None) - internal.internalRegisterFunction(FunctionIdentifier(name), info, builder) + private def registerInternalExpression[T <: Expression : ClassTag]( + name: String, + setAlias: Boolean = false): Unit = { + val (info, builder) = FunctionRegistryBase.build[T](name, None) + val newBuilder = if (setAlias) { + (expressions: Seq[Expression]) => { + val expr = builder(expressions) + expr.setTagValue(FUNC_ALIAS, name) + expr + } + } else { + builder + } + internal.internalRegisterFunction(FunctionIdentifier(name), info, newBuilder) } registerInternalExpression[Product]("product") @@ -911,6 +922,7 @@ object FunctionRegistry { registerInternalExpression[Days]("days") registerInternalExpression[Hours]("hours") registerInternalExpression[UnwrapUDT]("unwrap_udt") + registerInternalExpression[MonotonicallyIncreasingID]("distributed_id", setAlias = true) registerInternalExpression[DistributedSequenceID]("distributed_sequence_id") registerInternalExpression[PandasProduct]("pandas_product") registerInternalExpression[PandasStddev]("pandas_stddev") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala new file mode 100644 index 0000000000000..08be456f090e2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.catalog.{ + CatalogTableType, + TemporaryViewRelation, + UnresolvedCatalogRelation +} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} +import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.connector.catalog.{ + CatalogManager, + CatalogPlugin, + CatalogV2Util, + Identifier, + LookupCatalog, + Table, + V1Table, + V2TableWithV1Fallback +} +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ +import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors} +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.ArrayImplicits._ + +class RelationResolution(override val catalogManager: CatalogManager) + extends DataTypeErrorsBase + with Logging + with LookupCatalog + with SQLConfHelper { + val v1SessionCatalog = catalogManager.v1SessionCatalog + + /** + * If we are resolving database objects (relations, functions, etc.) inside views, we may need to + * expand single or multi-part identifiers with the current catalog and namespace of when the + * view was created. + */ + def expandIdentifier(nameParts: Seq[String]): Seq[String] = { + if (!isResolvingView || isReferredTempViewName(nameParts)) { + return nameParts + } + + if (nameParts.length == 1) { + AnalysisContext.get.catalogAndNamespace :+ nameParts.head + } else if (catalogManager.isCatalogRegistered(nameParts.head)) { + nameParts + } else { + AnalysisContext.get.catalogAndNamespace.head +: nameParts + } + } + + /** + * Lookup temporary view by `identifier`. Returns `None` if the view wasn't found. + */ + def lookupTempView(identifier: Seq[String]): Option[TemporaryViewRelation] = { + // We are resolving a view and this name is not a temp view when that view was created. We + // return None earlier here. + if (isResolvingView && !isReferredTempViewName(identifier)) { + return None + } + + v1SessionCatalog.getRawLocalOrGlobalTempView(identifier) + } + + /** + * Resolve relation `u` to v1 relation if it's a v1 table from the session catalog, or to v2 + * relation. This is for resolving DML commands and SELECT queries. 
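`expandIdentifier` above is moved essentially unchanged out of `Analyzer`: inside a view it qualifies names with the catalog and namespace captured when the view was created. A Python sketch of just that branching, leaving out the early return for non-view resolution and referred temp views:

```python
def expand_identifier(name_parts, catalog_and_namespace, is_catalog_registered):
    if len(name_parts) == 1:
        # Bare name: qualify with the captured catalog and namespace.
        return catalog_and_namespace + name_parts
    if is_catalog_registered(name_parts[0]):
        return name_parts                      # already catalog-qualified
    # Namespace-qualified but no catalog: prepend the captured catalog only.
    return catalog_and_namespace[:1] + name_parts

ctx = ["spark_catalog", "db1"]
print(expand_identifier(["t"], ctx, lambda c: False))         # ['spark_catalog', 'db1', 't']
print(expand_identifier(["db2", "t"], ctx, lambda c: False))  # ['spark_catalog', 'db2', 't']
```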
+ */ + def resolveRelation( + u: UnresolvedRelation, + timeTravelSpec: Option[TimeTravelSpec] = None): Option[LogicalPlan] = { + val timeTravelSpecFromOptions = TimeTravelSpec.fromOptions( + u.options, + conf.getConf(SQLConf.TIME_TRAVEL_TIMESTAMP_KEY), + conf.getConf(SQLConf.TIME_TRAVEL_VERSION_KEY), + conf.sessionLocalTimeZone + ) + if (timeTravelSpec.nonEmpty && timeTravelSpecFromOptions.nonEmpty) { + throw new AnalysisException("MULTIPLE_TIME_TRAVEL_SPEC", Map.empty[String, String]) + } + val finalTimeTravelSpec = timeTravelSpec.orElse(timeTravelSpecFromOptions) + resolveTempView( + u.multipartIdentifier, + u.isStreaming, + finalTimeTravelSpec.isDefined + ).orElse { + expandIdentifier(u.multipartIdentifier) match { + case CatalogAndIdentifier(catalog, ident) => + val key = + ( + (catalog.name +: ident.namespace :+ ident.name).toImmutableArraySeq, + finalTimeTravelSpec + ) + AnalysisContext.get.relationCache + .get(key) + .map { cache => + val cachedRelation = cache.transform { + case multi: MultiInstanceRelation => + val newRelation = multi.newInstance() + newRelation.copyTagsFrom(multi) + newRelation + } + u.getTagValue(LogicalPlan.PLAN_ID_TAG) + .map { planId => + val cachedConnectRelation = cachedRelation.clone() + cachedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + cachedConnectRelation + } + .getOrElse(cachedRelation) + } + .orElse { + val writePrivilegesString = + Option(u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES)) + val table = + CatalogV2Util.loadTable(catalog, ident, finalTimeTravelSpec, writePrivilegesString) + val loaded = createRelation( + catalog, + ident, + table, + u.clearWritePrivileges.options, + u.isStreaming + ) + loaded.foreach(AnalysisContext.get.relationCache.update(key, _)) + u.getTagValue(LogicalPlan.PLAN_ID_TAG) + .map { planId => + loaded.map { loadedRelation => + val loadedConnectRelation = loadedRelation.clone() + loadedConnectRelation.setTagValue(LogicalPlan.PLAN_ID_TAG, planId) + loadedConnectRelation + } + } + .getOrElse(loaded) + } + case _ => None + } + } + } + + private def createRelation( + catalog: CatalogPlugin, + ident: Identifier, + table: Option[Table], + options: CaseInsensitiveStringMap, + isStreaming: Boolean): Option[LogicalPlan] = { + table.map { + // To utilize this code path to execute V1 commands, e.g. INSERT, + // either it must be session catalog, or tracksPartitionsInCatalog + // must be false so it does not require use catalog to manage partitions. + // Obviously we cannot execute V1Table by V1 code path if the table + // is not from session catalog and the table still requires its catalog + // to manage partitions. 
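`resolveRelation` above checks `resolveTempView` before falling back to `expandIdentifier` and the catalog, which keeps the long-standing rule that temporary views shadow catalog tables with the same single-part name. An illustrative PySpark session (table and view names are made up):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.range(3).write.mode("overwrite").saveAsTable("events")   # catalog table "events"
spark.range(100).createOrReplaceTempView("events")             # temp view, same name

print(spark.table("events").count())                           # 100 -- the temp view wins
print(spark.table("spark_catalog.default.events").count())     # 3 -- qualifying reaches the table
```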
+ case v1Table: V1Table + if CatalogV2Util.isSessionCatalog(catalog) + || !v1Table.catalogTable.tracksPartitionsInCatalog => + if (isStreaming) { + if (v1Table.v1Table.tableType == CatalogTableType.VIEW) { + throw QueryCompilationErrors.permanentViewNotSupportedByStreamingReadingAPIError( + ident.quoted + ) + } + SubqueryAlias( + catalog.name +: ident.asMultipartIdentifier, + UnresolvedCatalogRelation(v1Table.v1Table, options, isStreaming = true) + ) + } else { + v1SessionCatalog.getRelation(v1Table.v1Table, options) + } + + case table => + if (isStreaming) { + val v1Fallback = table match { + case withFallback: V2TableWithV1Fallback => + Some(UnresolvedCatalogRelation(withFallback.v1Table, isStreaming = true)) + case _ => None + } + SubqueryAlias( + catalog.name +: ident.asMultipartIdentifier, + StreamingRelationV2( + None, + table.name, + table, + options, + table.columns.toAttributes, + Some(catalog), + Some(ident), + v1Fallback + ) + ) + } else { + SubqueryAlias( + catalog.name +: ident.asMultipartIdentifier, + DataSourceV2Relation.create(table, Some(catalog), Some(ident), options) + ) + } + } + } + + private def resolveTempView( + identifier: Seq[String], + isStreaming: Boolean = false, + isTimeTravel: Boolean = false): Option[LogicalPlan] = { + lookupTempView(identifier).map { v => + val tempViewPlan = v1SessionCatalog.getTempViewRelation(v) + if (isStreaming && !tempViewPlan.isStreaming) { + throw QueryCompilationErrors.readNonStreamingTempViewError(identifier.quoted) + } + if (isTimeTravel) { + throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(identifier)) + } + tempViewPlan + } + } + + private def isResolvingView: Boolean = AnalysisContext.get.catalogAndNamespace.nonEmpty + + private def isReferredTempViewName(nameParts: Seq[String]): Boolean = { + val resolver = conf.resolver + AnalysisContext.get.referredTempViewNames.exists { n => + (n.length == nameParts.length) && n.zip(nameParts).forall { + case (a, b) => resolver(a, b) + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 4212c12e96f77..9c617b51df62f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -32,11 +32,11 @@ import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeWithCaseA import org.apache.spark.sql.types.{DataType, MapType, StringType, StructType, VariantType} import org.apache.spark.unsafe.types.UTF8String -object ExprUtils extends QueryErrorsBase { +object ExprUtils extends EvalHelper with QueryErrorsBase { def evalTypeExpr(exp: Expression): DataType = { if (exp.foldable) { - exp.eval() match { + prepareForEval(exp).eval() match { case s: UTF8String if s != null => val dataType = DataType.parseTypeWithFallback( s.toString, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6a57ba2aaa569..bb32e518ec39a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -1347,9 +1347,21 @@ trait CommutativeExpression extends Expression { /** Collects adjacent commutative operations. 
*/ private def gatherCommutative( e: Expression, - f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = e match { - case c: CommutativeExpression if f.isDefinedAt(c) => f(c).flatMap(gatherCommutative(_, f)) - case other => other.canonicalized :: Nil + f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = { + val resultBuffer = scala.collection.mutable.Buffer[Expression]() + val stack = scala.collection.mutable.Stack[Expression](e) + + // [SPARK-49977]: Use iterative approach to avoid creating many temporary List objects + // for deep expression trees through recursion. + while (stack.nonEmpty) { + stack.pop() match { + case c: CommutativeExpression if f.isDefinedAt(c) => + stack.pushAll(f(c)) + case other => + resultBuffer += other.canonicalized + } + } + resultBuffer.toSeq } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 0a4882bfada17..3270c6e87e2cd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.trees.UnaryLike -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils, UnsafeRowUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase} import org.apache.spark.sql.types._ import org.apache.spark.util.BoundedPriorityQueue @@ -145,6 +145,7 @@ case class CollectList( """, group = "agg_funcs", since = "2.0.0") +// TODO: Make CollectSet collation aware case class CollectSet( child: Expression, mutableAggBufferOffset: Int = 0, @@ -178,14 +179,15 @@ case class CollectSet( } override def checkInputDataTypes(): TypeCheckResult = { - if (!child.dataType.existsRecursively(_.isInstanceOf[MapType])) { + if (!child.dataType.existsRecursively(_.isInstanceOf[MapType]) && + UnsafeRowUtils.isBinaryStable(child.dataType)) { TypeCheckResult.TypeCheckSuccess } else { DataTypeMismatch( errorSubClass = "UNSUPPORTED_INPUT_TYPE", messageParameters = Map( "functionName" -> toSQLId(prettyName), - "dataType" -> toSQLType(MapType) + "dataType" -> (s"${toSQLType(MapType)} " + "or \"COLLATED STRING\"") ) ) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala index 13676733a9bad..d18630f542020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/intervalExpressions.scala @@ -336,7 +336,7 @@ case class MakeInterval( val iu = IntervalUtils.getClass.getName.stripSuffix("$") val secFrac = sec.getOrElse("0") val failOnErrorBranch = if (failOnError) { - "throw QueryExecutionErrors.arithmeticOverflowError(e);" + """throw QueryExecutionErrors.arithmeticOverflowError(e.getMessage(), "", null);""" } else { s"${ev.isNull} = true;" } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala new file mode 100644 index 0000000000000..6291e62304a38 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions.json + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.ExprUtils +import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils +import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions} +import org.apache.spark.sql.catalyst.util.{FailFastMode, FailureSafeParser, PermissiveMode} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType, VariantType} +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.Utils + +object JsonExpressionEvalUtils { + + def schemaOfJson( + jsonFactory: JsonFactory, + jsonOptions: JSONOptions, + jsonInferSchema: JsonInferSchema, + json: UTF8String): UTF8String = { + val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => + parser.nextToken() + // To match with schema inference from JSON datasource. + jsonInferSchema.inferField(parser) match { + case st: StructType => + jsonInferSchema.canonicalizeType(st, jsonOptions).getOrElse(StructType(Nil)) + case at: ArrayType if at.elementType.isInstanceOf[StructType] => + jsonInferSchema + .canonicalizeType(at.elementType, jsonOptions) + .map(ArrayType(_, containsNull = at.containsNull)) + .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) + case other: DataType => + jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse( + SQLConf.get.defaultStringType) + } + } + + UTF8String.fromString(dt.sql) + } +} + +class JsonToStructsEvaluator( + options: Map[String, String], + nullableSchema: DataType, + nameOfCorruptRecord: String, + timeZoneId: Option[String], + variantAllowDuplicateKeys: Boolean) extends Serializable { + + // This converts parsed rows to the desired output by the given schema. 
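The new `JsonExpressionEvalUtils.schemaOfJson` helper above carries the inference logic presumably invoked by the `schema_of_json` function once `SchemaOfJson` becomes `RuntimeReplaceable` further below. For reference, its user-facing behaviour (a small PySpark sketch; the exact DDL formatting varies by version):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import schema_of_json, lit

spark = SparkSession.builder.master("local[1]").getOrCreate()
spark.range(1).select(
    schema_of_json(lit('{"a": 1, "b": [1.5], "c": "x"}')).alias("ddl")
).show(truncate=False)
# e.g. STRUCT<a: BIGINT, b: ARRAY<DOUBLE>, c: STRING>
```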
+ @transient + private lazy val converter = nullableSchema match { + case _: StructType => + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null + case _: ArrayType => + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null + case _: MapType => + (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null + } + + @transient + private lazy val parser = { + val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord) + val mode = parsedOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw QueryCompilationErrors.parseModeUnsupportedError("from_json", mode) + } + val (parserSchema, actualSchema) = nullableSchema match { + case s: StructType => + ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord) + (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) + case other => + (StructType(Array(StructField("value", other))), other) + } + + val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false) + val createParser = CreateJacksonParser.utf8String _ + + new FailureSafeParser[UTF8String]( + input => rawParser.parse(input, createParser, identity[UTF8String]), + mode, + parserSchema, + parsedOptions.columnNameOfCorruptRecord) + } + + final def evaluate(json: UTF8String): Any = { + if (json == null) return null + nullableSchema match { + case _: VariantType => + VariantExpressionEvalUtils.parseJson(json, + allowDuplicateKeys = variantAllowDuplicateKeys) + case _ => + converter(parser.parse(json)) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index e01531cc821c9..6eef3d6f9d7df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -30,9 +30,8 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper -import org.apache.spark.sql.catalyst.expressions.json.JsonExpressionUtils +import org.apache.spark.sql.catalyst.expressions.json.{JsonExpressionEvalUtils, JsonExpressionUtils, JsonToStructsEvaluator} import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke -import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.catalyst.json._ import org.apache.spark.sql.catalyst.trees.TreePattern.{JSON_TO_STRUCT, TreePattern} import org.apache.spark.sql.catalyst.util._ @@ -639,7 +638,6 @@ case class JsonToStructs( variantAllowDuplicateKeys: Boolean = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)) extends UnaryExpression with TimeZoneAwareExpression - with CodegenFallback with ExpectsInputTypes with NullIntolerant with QueryErrorsBase { @@ -647,7 +645,7 @@ case class JsonToStructs( // The JSON input data might be missing certain fields. We force the nullability // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder // can generate incorrect files if values are missing in columns declared as non-nullable. 
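
Aside on the extraction above: JsonToStructsEvaluator is a plain Serializable class, so both the interpreted path and generated code can reach the same parsing logic through a single object reference. A minimal usage sketch, assuming a one-field struct schema, the default "_corrupt_record" column name, and a UTC session time zone (illustrative only, not part of the patch):

  import org.apache.spark.sql.catalyst.expressions.json.JsonToStructsEvaluator
  import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
  import org.apache.spark.unsafe.types.UTF8String

  val schema = StructType(Seq(StructField("a", IntegerType)))
  val evaluator = new JsonToStructsEvaluator(
    options = Map.empty,
    nullableSchema = schema.asNullable,
    nameOfCorruptRecord = "_corrupt_record",
    timeZoneId = Some("UTC"),
    variantAllowDuplicateKeys = false)

  // For a struct schema the evaluator returns an InternalRow; a null input returns null.
  val row = evaluator.evaluate(UTF8String.fromString("""{"a": 1}"""))
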
- val nullableSchema = schema.asNullable + private val nullableSchema: DataType = schema.asNullable override def nullable: Boolean = true @@ -680,53 +678,35 @@ case class JsonToStructs( messageParameters = Map("schema" -> toSQLType(nullableSchema))) } - // This converts parsed rows to the desired output by the given schema. - @transient - lazy val converter = nullableSchema match { - case _: StructType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next() else null - case _: ArrayType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getArray(0) else null - case _: MapType => - (rows: Iterator[InternalRow]) => if (rows.hasNext) rows.next().getMap(0) else null - } - - val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) - @transient lazy val parser = { - val parsedOptions = new JSONOptions(options, timeZoneId.get, nameOfCorruptRecord) - val mode = parsedOptions.parseMode - if (mode != PermissiveMode && mode != FailFastMode) { - throw QueryCompilationErrors.parseModeUnsupportedError("from_json", mode) - } - val (parserSchema, actualSchema) = nullableSchema match { - case s: StructType => - ExprUtils.verifyColumnNameOfCorruptRecord(s, parsedOptions.columnNameOfCorruptRecord) - (s, StructType(s.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))) - case other => - (StructType(Array(StructField("value", other))), other) - } - - val rawParser = new JacksonParser(actualSchema, parsedOptions, allowArrayAsStructs = false) - val createParser = CreateJacksonParser.utf8String _ - - new FailureSafeParser[UTF8String]( - input => rawParser.parse(input, createParser, identity[UTF8String]), - mode, - parserSchema, - parsedOptions.columnNameOfCorruptRecord) - } - override def dataType: DataType = nullableSchema override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) - override def nullSafeEval(json: Any): Any = nullableSchema match { - case _: VariantType => - VariantExpressionEvalUtils.parseJson(json.asInstanceOf[UTF8String], - allowDuplicateKeys = variantAllowDuplicateKeys) - case _ => - converter(parser.parse(json.asInstanceOf[UTF8String])) + @transient + private val nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD) + + @transient + private lazy val evaluator = new JsonToStructsEvaluator( + options, nullableSchema, nameOfCorruptRecord, timeZoneId, variantAllowDuplicateKeys) + + override def nullSafeEval(json: Any): Any = evaluator.evaluate(json.asInstanceOf[UTF8String]) + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val refEvaluator = ctx.addReferenceObj("evaluator", evaluator) + val eval = child.genCode(ctx) + val resultType = CodeGenerator.boxedType(dataType) + val resultTerm = ctx.freshName("result") + ev.copy(code = + code""" + |${eval.code} + |$resultType $resultTerm = ($resultType) $refEvaluator.evaluate(${eval.value}); + |boolean ${ev.isNull} = $resultTerm == null; + |${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $resultTerm; + |} + |""".stripMargin) } override def inputTypes: Seq[AbstractDataType] = StringTypeWithCaseAccentSensitivity :: Nil @@ -878,7 +858,9 @@ case class StructsToJson( case class SchemaOfJson( child: Expression, options: Map[String, String]) - extends UnaryExpression with CodegenFallback with QueryErrorsBase { + extends UnaryExpression + with RuntimeReplaceable + with QueryErrorsBase { def this(child: 
Expression) = this(child, Map.empty[String, String]) @@ -919,26 +901,20 @@ case class SchemaOfJson( } } - override def eval(v: InternalRow): Any = { - val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => - parser.nextToken() - // To match with schema inference from JSON datasource. - jsonInferSchema.inferField(parser) match { - case st: StructType => - jsonInferSchema.canonicalizeType(st, jsonOptions).getOrElse(StructType(Nil)) - case at: ArrayType if at.elementType.isInstanceOf[StructType] => - jsonInferSchema - .canonicalizeType(at.elementType, jsonOptions) - .map(ArrayType(_, containsNull = at.containsNull)) - .getOrElse(ArrayType(StructType(Nil), containsNull = at.containsNull)) - case other: DataType => - jsonInferSchema.canonicalizeType(other, jsonOptions).getOrElse( - SQLConf.get.defaultStringType) - } - } + @transient private lazy val jsonFactoryObjectType = ObjectType(classOf[JsonFactory]) + @transient private lazy val jsonOptionsObjectType = ObjectType(classOf[JSONOptions]) + @transient private lazy val jsonInferSchemaObjectType = ObjectType(classOf[JsonInferSchema]) - UTF8String.fromString(dt.sql) - } + override def replacement: Expression = StaticInvoke( + JsonExpressionEvalUtils.getClass, + dataType, + "schemaOfJson", + Seq(Literal(jsonFactory, jsonFactoryObjectType), + Literal(jsonOptions, jsonOptionsObjectType), + Literal(jsonInferSchema, jsonInferSchemaObjectType), + child), + Seq(jsonFactoryObjectType, jsonOptionsObjectType, jsonInferSchemaObjectType, child.dataType) + ) override def prettyName: String = "schema_of_json" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 3cec83facd01d..16bdaa1f7f708 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, UnresolvedSeed} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch -import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLType} +import org.apache.spark.sql.catalyst.expressions.ExpectsInputTypes.{ordinalNumber, toSQLExpr, toSQLId, toSQLType} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral} import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.trees.{BinaryLike, TernaryLike, UnaryLike} @@ -263,7 +263,7 @@ case class Uniform(min: Expression, max: Expression, seedExpression: Expression, result = DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> name, + "inputName" -> toSQLId(name), "inputType" -> requiredType, "inputExpr" -> toSQLExpr(expr))) } else expr.dataType match { @@ -374,14 +374,14 @@ case class RandStr( var result: TypeCheckResult = TypeCheckResult.TypeCheckSuccess def requiredType = "INT or SMALLINT" Seq((length, "length", 0), - (seedExpression, "seedExpression", 1)).foreach { + (seedExpression, "seed", 1)).foreach { case (expr: Expression, name: String, index: Int) => if (result == TypeCheckResult.TypeCheckSuccess) { if (!expr.foldable) { result = 
DataTypeMismatch( errorSubClass = "NON_FOLDABLE_INPUT", messageParameters = Map( - "inputName" -> name, + "inputName" -> toSQLId(name), "inputType" -> requiredType, "inputExpr" -> toSQLExpr(expr))) } else expr.dataType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala index 832af340c3397..d23d43acc217b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala @@ -111,7 +111,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup // Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule. case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p) case LeftSemi if isRightEmpty | isFalseCondition => empty(p) - case LeftAnti if isRightEmpty | isFalseCondition => p.left + case LeftAnti if (isRightEmpty | isFalseCondition) && canExecuteWithoutJoin(p.left) => + p.left case FullOuter if isLeftEmpty && isRightEmpty => empty(p) case LeftOuter | FullOuter if isRightEmpty && canExecuteWithoutJoin(p.left) => Project(p.left.output ++ nullValueProjectList(p.right), p.left) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala index 40b8bccafaad2..118dd92c3ed54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala @@ -207,7 +207,7 @@ object UnsafeRowUtils { def isBinaryStable(dataType: DataType): Boolean = !dataType.existsRecursively { case st: StringType => val collation = CollationFactory.fetchCollation(st.collationId) - (!collation.supportsBinaryEquality || collation.supportsSpaceTrimming) + (!collation.supportsBinaryEquality) case _ => false } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index e942fd506fdb9..ddad2657ff474 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical.{Assignment, InputParameter, Join, LogicalPlan, SerdeInfo, Window} import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} -import org.apache.spark.sql.catalyst.util.{quoteIdentifier, FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ParseMode} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.functions.{BoundFunction, UnboundFunction} @@ -1342,12 +1342,10 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat def parseModeUnsupportedError(funcName: String, mode: ParseMode): Throwable = { new AnalysisException( - errorClass = "_LEGACY_ERROR_TEMP_1099", + errorClass = "PARSE_MODE_UNSUPPORTED", messageParameters = Map( - "funcName" -> funcName, - "mode" -> 
mode.name, - "permissiveMode" -> PermissiveMode.name, - "failFastMode" -> FailFastMode.name)) + "funcName" -> toSQLId(funcName), + "mode" -> mode.name)) } def nonFoldableArgumentError( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 6e64e7e9e39bf..edc1b909292df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -599,16 +599,6 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE messageParameters = Map("methodName" -> methodName)) } - def arithmeticOverflowError(e: ArithmeticException): SparkArithmeticException = { - new SparkArithmeticException( - errorClass = "_LEGACY_ERROR_TEMP_2042", - messageParameters = Map( - "message" -> e.getMessage, - "ansiConfig" -> toSQLConf(SQLConf.ANSI_ENABLED.key)), - context = Array.empty, - summary = "") - } - def binaryArithmeticCauseOverflowError( eval1: Short, symbol: String, @@ -870,7 +860,7 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def cannotRemoveReservedPropertyError(property: String): SparkUnsupportedOperationException = { new SparkUnsupportedOperationException( - errorClass = "_LEGACY_ERROR_TEMP_2069", + errorClass = "CANNOT_REMOVE_RESERVED_PROPERTY", messageParameters = Map("property" -> property)) } @@ -1112,7 +1102,8 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE def cannotAcquireMemoryToBuildUnsafeHashedRelationError(): Throwable = { new SparkOutOfMemoryError( - "_LEGACY_ERROR_TEMP_2107") + "_LEGACY_ERROR_TEMP_2107", + new java.util.HashMap[String, String]()) } def rowLargerThan256MUnsupportedError(): SparkUnsupportedOperationException = { @@ -1257,6 +1248,12 @@ private[sql] object QueryExecutionErrors extends QueryErrorsBase with ExecutionE "dataType" -> toSQLType(dataType))) } + def wrongDatatypeInSomeRows(pos: Int, dataType: DataType): SparkSQLException = { + new SparkSQLException( + errorClass = "COLUMN_ARRAY_ELEMENT_TYPE_MISMATCH", + messageParameters = Map("pos" -> pos.toString(), "type" -> toSQLType(dataType))) + } + def rootConverterReturnNullError(): SparkRuntimeException = { new SparkRuntimeException( errorClass = "INVALID_JSON_ROOT_FIELD", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala index fe1ee0e6f338b..81dd8242c600b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CsvExpressionsSuite.scala @@ -149,12 +149,17 @@ class CsvExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with P test("unsupported mode") { val csvData = "---" val schema = StructType(StructField("a", DoubleType) :: Nil) - val exception = intercept[TestFailedException] { - checkEvaluation( - CsvToStructs(schema, Map("mode" -> DropMalformedMode.name), Literal(csvData), UTC_OPT), - InternalRow(null)) - }.getCause - assert(exception.getMessage.contains("from_csv() doesn't support the DROPMALFORMED mode")) + + checkError( + exception = intercept[TestFailedException] { + checkEvaluation( + CsvToStructs(schema, Map("mode" -> DropMalformedMode.name), Literal(csvData), UTC_OPT), + 
InternalRow(null)) + }.getCause.asInstanceOf[AnalysisException], + condition = "PARSE_MODE_UNSUPPORTED", + parameters = Map( + "funcName" -> "`from_csv`", + "mode" -> "DROPMALFORMED")) } test("infer schema of CSV strings") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 6f3890cafd2ac..92ef24bb8ec63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -636,7 +636,7 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(murmur3Hash1, interpretedHash1) checkEvaluation(murmur3Hash2, interpretedHash2) - if (CollationFactory.fetchCollation(collation).supportsBinaryEquality) { + if (CollationFactory.fetchCollation(collation).isUtf8BinaryType) { assert(interpretedHash1 != interpretedHash2) } else { assert(interpretedHash1 == interpretedHash2) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala index 7caf23490a0ce..78bc77b9dc2ab 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/IntervalExpressionsSuite.scala @@ -266,7 +266,7 @@ class IntervalExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val intervalExpr = MakeInterval(Literal(years), Literal(months), Literal(weeks), Literal(days), Literal(hours), Literal(minutes), Literal(Decimal(secFrac, Decimal.MAX_LONG_DIGITS, 6))) - checkExceptionInExpression[ArithmeticException](intervalExpr, EmptyRow, "") + checkExceptionInExpression[ArithmeticException](intervalExpr, EmptyRow, "ARITHMETIC_OVERFLOW") } withSQLConf(SQLConf.ANSI_ENABLED.key -> "true") { diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain index 8ec799bc58084..b400aeeca5af2 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json.explain @@ -1,2 +1,2 @@ -Project [schema_of_json([{"col":01}]) AS schema_of_json([{"col":01}])#0] +Project [static_invoke(JsonExpressionEvalUtils.schemaOfJson(com.fasterxml.jackson.core.JsonFactory, org.apache.spark.sql.catalyst.json.JSONOptions, org.apache.spark.sql.catalyst.json.JsonInferSchema, [{"col":01}])) AS schema_of_json([{"col":01}])#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain index 13867949177a4..b400aeeca5af2 100644 --- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain +++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_json_with_options.explain @@ -1,2 +1,2 @@ -Project [schema_of_json([{"col":01}], 
(allowNumericLeadingZeros,true)) AS schema_of_json([{"col":01}])#0] +Project [static_invoke(JsonExpressionEvalUtils.schemaOfJson(com.fasterxml.jackson.core.JsonFactory, org.apache.spark.sql.catalyst.json.JSONOptions, org.apache.spark.sql.catalyst.json.JsonInferSchema, [{"col":01}])) AS schema_of_json([{"col":01}])#0] +- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0] diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala index 7a0c067ab430b..445f40d25edcd 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListener.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} +import java.util.concurrent.atomic.AtomicReference import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -41,7 +42,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { // The server side listener that is responsible to stream streaming query events back to client. // There is only one listener per sessionHolder, but each listener is responsible for all events // of all streaming queries in the SparkSession. - var streamingQueryServerSideListener: Option[SparkConnectListenerBusListener] = None + var streamingQueryServerSideListener: AtomicReference[SparkConnectListenerBusListener] = + new AtomicReference() // The cache for QueryStartedEvent, key is query runId and value is the actual QueryStartedEvent. // Events for corresponding query will be sent back to client with // the WriteStreamOperationStart response, so that the client can handle the event before @@ -50,10 +52,8 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { val streamingQueryStartedEventCache : ConcurrentMap[String, StreamingQueryListener.QueryStartedEvent] = new ConcurrentHashMap() - val lock = new Object() - - def isServerSideListenerRegistered: Boolean = lock.synchronized { - streamingQueryServerSideListener.isDefined + def isServerSideListenerRegistered: Boolean = { + streamingQueryServerSideListener.getAcquire() != null } /** @@ -65,10 +65,10 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { * @param responseObserver * the responseObserver created from the first long running executeThread. */ - def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = lock.synchronized { + def init(responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { val serverListener = new SparkConnectListenerBusListener(this, responseObserver) sessionHolder.session.streams.addListener(serverListener) - streamingQueryServerSideListener = Some(serverListener) + streamingQueryServerSideListener.setRelease(serverListener) } /** @@ -77,13 +77,13 @@ private[sql] class ServerSideListenerHolder(val sessionHolder: SessionHolder) { * exception. It removes the listener from the session, clears the cache. Also it sends back the * final ResultComplete response. 
*/ - def cleanUp(): Unit = lock.synchronized { - streamingQueryServerSideListener.foreach { listener => + def cleanUp(): Unit = { + var listener = streamingQueryServerSideListener.getAndSet(null) + if (listener != null) { sessionHolder.session.streams.removeListener(listener) listener.sendResultComplete() + streamingQueryStartedEventCache.clear() } - streamingQueryStartedEventCache.clear() - streamingQueryServerSideListener = None } } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala index 48492bac62344..3da2548b456e8 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCache.scala @@ -17,11 +17,8 @@ package org.apache.spark.sql.connect.service -import java.util.concurrent.Executors -import java.util.concurrent.ScheduledExecutorService -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, Executors, ScheduledExecutorService, TimeUnit} import java.util.concurrent.atomic.AtomicReference -import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} @@ -61,36 +58,34 @@ private[connect] class SparkConnectStreamingQueryCache( sessionHolder: SessionHolder, query: StreamingQuery, tags: Set[String], - operationId: String): Unit = queryCacheLock.synchronized { - taggedQueriesLock.synchronized { - val value = QueryCacheValue( - userId = sessionHolder.userId, - sessionId = sessionHolder.sessionId, - session = sessionHolder.session, - query = query, - operationId = operationId, - expiresAtMs = None) - - val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) - tags.foreach { tag => - taggedQueries - .getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]) - .addOne(queryKey) - } - - queryCache.put(queryKey, value) match { - case Some(existing) => // Query is being replace. Not really expected. + operationId: String): Unit = { + val value = QueryCacheValue( + userId = sessionHolder.userId, + sessionId = sessionHolder.sessionId, + session = sessionHolder.session, + query = query, + operationId = operationId, + expiresAtMs = None) + + val queryKey = QueryCacheKey(query.id.toString, query.runId.toString) + tags.foreach { tag => addTaggedQuery(tag, queryKey) } + + queryCache.compute( + queryKey, + (key, existing) => { + if (existing != null) { // The query is being replaced: allowed, though not expected. logWarning(log"Replacing existing query in the cache (unexpected). " + log"Query Id: ${MDC(QUERY_ID, query.id)}.Existing value ${MDC(OLD_VALUE, existing)}, " + log"new value ${MDC(NEW_VALUE, value)}.") - case None => + } else { logInfo( log"Adding new query to the cache. Query Id ${MDC(QUERY_ID, query.id)}, " + log"value ${MDC(QUERY_CACHE_VALUE, value)}.") - } + } + value + }) - schedulePeriodicChecks() // Starts the scheduler thread if it hasn't started. - } + schedulePeriodicChecks() // Start the scheduler thread if it has not been started. 
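
The ServerSideListenerHolder above swaps the Option-plus-lock bookkeeping for an AtomicReference: init publishes with setRelease, the registered check reads with getAcquire, and cleanUp claims the listener with getAndSet(null) so removal and sendResultComplete run at most once even if callers race. A compact sketch of that discipline with a placeholder listener type (assumed names, not the real classes):

  import java.util.concurrent.atomic.AtomicReference

  final class DemoListener { def close(): Unit = println("listener closed") }

  final class ListenerSlot {
    private val slot = new AtomicReference[DemoListener]()

    def register(l: DemoListener): Unit = slot.setRelease(l)   // publish the listener
    def isRegistered: Boolean = slot.getAcquire() != null

    // getAndSet(null) makes clean-up idempotent: only one caller sees the
    // non-null listener, so close() runs exactly once.
    def cleanUp(): Unit = {
      val l = slot.getAndSet(null)
      if (l != null) l.close()
    }
  }
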
} /** @@ -104,44 +99,35 @@ private[connect] class SparkConnectStreamingQueryCache( runId: String, tags: Set[String], session: SparkSession): Option[QueryCacheValue] = { - taggedQueriesLock.synchronized { - val key = QueryCacheKey(queryId, runId) - val result = getCachedQuery(QueryCacheKey(queryId, runId), session) - tags.foreach { tag => - taggedQueries.getOrElseUpdate(tag, new mutable.ArrayBuffer[QueryCacheKey]).addOne(key) - } - result - } + val queryKey = QueryCacheKey(queryId, runId) + val result = getCachedQuery(QueryCacheKey(queryId, runId), session) + tags.foreach { tag => addTaggedQuery(tag, queryKey) } + result } /** * Similar with [[getCachedQuery]] but it gets queries tagged previously. */ def getTaggedQuery(tag: String, session: SparkSession): Seq[QueryCacheValue] = { - taggedQueriesLock.synchronized { - taggedQueries - .get(tag) - .map { k => - k.flatMap(getCachedQuery(_, session)).toSeq - } - .getOrElse(Seq.empty[QueryCacheValue]) - } + val queryKeySet = Option(taggedQueries.get(tag)) + queryKeySet + .map(_.flatMap(k => getCachedQuery(k, session))) + .getOrElse(Seq.empty[QueryCacheValue]) } private def getCachedQuery( key: QueryCacheKey, session: SparkSession): Option[QueryCacheValue] = { - queryCacheLock.synchronized { - queryCache.get(key).flatMap { v => - if (v.session == session) { - v.expiresAtMs.foreach { _ => - // Extend the expiry time as the client is accessing it. - val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis - queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) - } - Some(v) - } else None // Should be rare, may be client is trying access from a different session. - } + val value = Option(queryCache.get(key)) + value.flatMap { v => + if (v.session == session) { + v.expiresAtMs.foreach { _ => + // Extend the expiry time as the client is accessing it. + val expiresAtMs = clock.getTimeMillis() + stoppedQueryInactivityTimeout.toMillis + queryCache.put(key, v.copy(expiresAtMs = Some(expiresAtMs))) + } + Some(v) + } else None // Should be rare, may be client is trying access from a different session. } } @@ -154,7 +140,7 @@ private[connect] class SparkConnectStreamingQueryCache( sessionHolder: SessionHolder, blocking: Boolean = true): Seq[String] = { val operationIds = new mutable.ArrayBuffer[String]() - for ((k, v) <- queryCache) { + queryCache.forEach((k, v) => { if (v.userId.equals(sessionHolder.userId) && v.sessionId.equals(sessionHolder.sessionId)) { if (v.query.isActive && Option(v.session.streams.get(k.queryId)).nonEmpty) { logInfo( @@ -178,29 +164,27 @@ private[connect] class SparkConnectStreamingQueryCache( } } } - } + }) operationIds.toSeq } // Visible for testing private[service] def getCachedValue(queryId: String, runId: String): Option[QueryCacheValue] = - queryCache.get(QueryCacheKey(queryId, runId)) + Option(queryCache.get(QueryCacheKey(queryId, runId))) // Visible for testing. 
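
With the query cache accessed as a ConcurrentMap above, a miss surfaces as null rather than None and iteration goes through forEach, so the lookups wrap get in Option before chaining Scala combinators. A tiny sketch of both idioms (placeholder keys and values, illustrative only):

  import java.util.concurrent.ConcurrentHashMap

  val cache = new ConcurrentHashMap[String, String]()
  cache.put("query-1", "RUNNING")

  val hit  = Option(cache.get("query-1"))   // Some("RUNNING")
  val miss = Option(cache.get("query-2"))   // None

  // Weakly consistent iteration, no external lock required.
  cache.forEach((k, v) => println(s"$k -> $v"))
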
- private[service] def shutdown(): Unit = queryCacheLock.synchronized { + private[service] def shutdown(): Unit = { val executor = scheduledExecutor.getAndSet(null) if (executor != null) { ThreadUtils.shutdown(executor, FiniteDuration(1, TimeUnit.MINUTES)) } } - @GuardedBy("queryCacheLock") - private val queryCache = new mutable.HashMap[QueryCacheKey, QueryCacheValue] - private val queryCacheLock = new Object + private val queryCache: ConcurrentMap[QueryCacheKey, QueryCacheValue] = + new ConcurrentHashMap[QueryCacheKey, QueryCacheValue] - @GuardedBy("queryCacheLock") - private val taggedQueries = new mutable.HashMap[String, mutable.ArrayBuffer[QueryCacheKey]] - private val taggedQueriesLock = new Object + private[service] val taggedQueries: ConcurrentMap[String, QueryCacheKeySet] = + new ConcurrentHashMap[String, QueryCacheKeySet] private var scheduledExecutor: AtomicReference[ScheduledExecutorService] = new AtomicReference[ScheduledExecutorService]() @@ -228,62 +212,109 @@ private[connect] class SparkConnectStreamingQueryCache( } } + private def addTaggedQuery(tag: String, queryKey: QueryCacheKey): Unit = { + taggedQueries.compute( + tag, + (k, v) => { + if (v == null || !v.addKey(queryKey)) { + // Create a new QueryCacheKeySet if the entry is absent or being removed. + var keys = mutable.HashSet.empty[QueryCacheKey] + keys.add(queryKey) + new QueryCacheKeySet(keys = keys) + } else { + v + } + }) + } + /** * Periodic maintenance task to do the following: * - Update status of query if it is inactive. Sets an expiry time for such queries * - Drop expired queries from the cache. */ - private def periodicMaintenance(): Unit = taggedQueriesLock.synchronized { + private def periodicMaintenance(): Unit = { + val nowMs = clock.getTimeMillis() - queryCacheLock.synchronized { - val nowMs = clock.getTimeMillis() + queryCache.forEach((k, v) => { + val id = k.queryId + val runId = k.runId + v.expiresAtMs match { - for ((k, v) <- queryCache) { - val id = k.queryId - val runId = k.runId - v.expiresAtMs match { + case Some(ts) if nowMs >= ts => // Expired. Drop references. + logInfo( + log"Removing references for id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + + log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") + queryCache.remove(k) - case Some(ts) if nowMs >= ts => // Expired. Drop references. - logInfo( - log"Removing references for id: ${MDC(QUERY_ID, id)} " + - log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)} after expiry period") - queryCache.remove(k) + case Some(_) => // Inactive query waiting for expiration. Do nothing. + logInfo( + log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + + log"session ${MDC(SESSION_ID, v.sessionId)}") + + case None => // Active query, check if it is stopped. Enable timeout if it is stopped. + val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty - case Some(_) => // Inactive query waiting for expiration. Do nothing. + if (!isActive) { logInfo( - log"Waiting for the expiration for id: ${MDC(QUERY_ID, id)} " + + log"Marking query id: ${MDC(QUERY_ID, id)} " + log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)}") - - case None => // Active query, check if it is stopped. Enable timeout if it is stopped. 
- val isActive = v.query.isActive && Option(v.session.streams.get(id)).nonEmpty - - if (!isActive) { - logInfo( - log"Marking query id: ${MDC(QUERY_ID, id)} " + - log"runId: ${MDC(QUERY_RUN_ID, runId)} in " + - log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") - val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis - queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) - // To consider: Clean up any runner registered for this query with the session holder - // for this session. Useful in case listener events are delayed (such delays are - // seen in practice, especially when users have heavy processing inside listeners). - // Currently such workers would be cleaned up when the connect session expires. - } - } + log"session ${MDC(SESSION_ID, v.sessionId)} inactive.") + val expiresAtMs = nowMs + stoppedQueryInactivityTimeout.toMillis + queryCache.put(k, v.copy(expiresAtMs = Some(expiresAtMs))) + // To consider: Clean up any runner registered for this query with the session holder + // for this session. Useful in case listener events are delayed (such delays are + // seen in practice, especially when users have heavy processing inside listeners). + // Currently such workers would be cleaned up when the connect session expires. + } } + }) - taggedQueries.toArray.foreach { case (key, value) => - value.zipWithIndex.toArray.foreach { case (queryKey, i) => - if (queryCache.contains(queryKey)) { - value.remove(i) - } + // Removes any tagged queries that do not correspond to cached queries. + taggedQueries.forEach((key, value) => { + if (value.filter(k => queryCache.containsKey(k))) { + taggedQueries.remove(key, value) + } + }) + } + + case class QueryCacheKeySet(keys: mutable.HashSet[QueryCacheKey]) { + + /** Tries to add the key if the set is not empty, otherwise returns false. */ + def addKey(key: QueryCacheKey): Boolean = { + keys.synchronized { + if (keys.isEmpty) { + // The entry is about to be removed. + return false } + keys.add(key) + true + } + } - if (value.isEmpty) { - taggedQueries.remove(key) + /** Removes the key and returns true if the set is empty. */ + def removeKey(key: QueryCacheKey): Boolean = { + keys.synchronized { + if (keys.remove(key)) { + return keys.isEmpty } + false + } + } + + /** Removes entries that do not satisfy the predicate. */ + def filter(pred: QueryCacheKey => Boolean): Boolean = { + keys.synchronized { + keys.filterInPlace(k => pred(k)) + keys.isEmpty + } + } + + /** Iterates over entries, apply the function individually, and then flatten the result. 
*/ + def flatMap[T](function: QueryCacheKey => Option[T]): Seq[T] = { + keys.synchronized { + keys.flatMap(k => function(k)).toSeq } } } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala index d856ffaabc316..2404dea21d91e 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectListenerBusListenerSuite.scala @@ -202,7 +202,8 @@ class SparkConnectListenerBusListenerSuite val listenerHolder = sessionHolder.streamingServersideListenerHolder eventually(timeout(5.seconds), interval(500.milliseconds)) { assert( - sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.isEmpty) + sessionHolder.streamingServersideListenerHolder.streamingQueryServerSideListener.get() == + null) assert(spark.streams.listListeners().size === listenerCntBeforeThrow) assert(listenerHolder.streamingQueryStartedEventCache.isEmpty) } diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala index 512a0a80c4a91..729a995f46145 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectStreamingQueryCacheSuite.scala @@ -48,6 +48,7 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug val queryId = UUID.randomUUID().toString val runId = UUID.randomUUID().toString + val tag = "test_tag" val mockSession = mock[SparkSession] val mockQuery = mock[StreamingQuery] val mockStreamingQueryManager = mock[StreamingQueryManager] @@ -67,13 +68,16 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug // Register the query. - sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set.empty[String], "") + sessionMgr.registerNewStreamingQuery(sessionHolder, mockQuery, Set(tag), "") sessionMgr.getCachedValue(queryId, runId) match { case Some(v) => assert(v.sessionId == sessionHolder.sessionId) assert(v.expiresAtMs.isEmpty, "No expiry time should be set for active query") + val taggedQueries = sessionMgr.getTaggedQuery(tag, mockSession) + assert(taggedQueries.contains(v)) + case None => assert(false, "Query should be found") } @@ -127,6 +131,9 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug assert(sessionMgr.getCachedValue(queryId, runId).map(_.query).contains(mockQuery)) assert( sessionMgr.getCachedValue(queryId, restartedRunId).map(_.query).contains(restartedQuery)) + eventually(timeout(1.minute)) { + assert(sessionMgr.taggedQueries.containsKey(tag)) + } // Advance time by 1 minute and verify the first query is dropped from the cache. 
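
The addTaggedQuery/QueryCacheKeySet pair above closes the remove-vs-add race without a global lock: compute runs under the map entry's lock, and addKey returns false once a set has been drained, telling the writer to install a fresh set instead of resurrecting one that periodic maintenance is about to discard. A stripped-down sketch of that handshake using plain strings (assumed, simplified types; not the actual classes):

  import java.util.concurrent.ConcurrentHashMap
  import scala.collection.mutable

  // Empty means "being removed": addKey refuses, and the caller builds a replacement.
  final class KeySet(first: String) {
    private val keys = mutable.HashSet(first)
    def addKey(k: String): Boolean = keys.synchronized {
      if (keys.isEmpty) false else { keys.add(k); true }
    }
    def removeKey(k: String): Boolean = keys.synchronized {
      if (keys.remove(k)) keys.isEmpty else false
    }
  }

  val tags = new ConcurrentHashMap[String, KeySet]()

  def addTagged(tag: String, key: String): Unit =
    tags.compute(tag, (t, existing) => {
      if (existing == null || !existing.addKey(key)) new KeySet(key) // fresh set wins
      else existing
    })
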
clock.advance(1.minute.toMillis) @@ -144,8 +151,11 @@ class SparkConnectStreamingQueryCacheSuite extends SparkFunSuite with MockitoSug clock.advance(1.minute.toMillis) eventually(timeout(1.minute)) { assert(sessionMgr.getCachedValue(queryId, restartedRunId).isEmpty) + assert(sessionMgr.getTaggedQuery(tag, mockSession).isEmpty) + } + eventually(timeout(1.minute)) { + assert(!sessionMgr.taggedQueries.containsKey(tag)) } - sessionMgr.shutdown() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 08395ef4c347c..a66a6e54a7c8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -154,12 +154,6 @@ private[sql] object PythonSQLUtils extends Logging { def namedArgumentExpression(name: String, e: Column): Column = NamedArgumentExpression(name, e) - def distributedIndex(): Column = { - val expr = MonotonicallyIncreasingID() - expr.setTagValue(FunctionRegistry.FUNC_ALIAS, "distributed_index") - expr - } - @scala.annotation.varargs def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index ffab67b7cae24..77efc4793359f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -340,6 +340,7 @@ case class AdaptiveSparkPlanExec( }(AdaptiveSparkPlanExec.executionContext) } catch { case e: Throwable => + stage.error.set(Some(e)) cleanUpAndThrowException(Seq(e), Some(stage.id)) } } @@ -355,6 +356,7 @@ case class AdaptiveSparkPlanExec( case StageSuccess(stage, res) => stage.resultOption.set(Some(res)) case StageFailure(stage, ex) => + stage.error.set(Some(ex)) errors.append(ex) } @@ -600,6 +602,7 @@ case class AdaptiveSparkPlanExec( newStages = Seq(newStage)) case q: QueryStageExec => + assertStageNotFailed(q) CreateStageResult(newPlan = q, allChildStagesMaterialized = q.isMaterialized, newStages = Seq.empty) @@ -815,6 +818,15 @@ case class AdaptiveSparkPlanExec( } } + private def assertStageNotFailed(stage: QueryStageExec): Unit = { + if (stage.hasFailed) { + throw stage.error.get().get match { + case fatal: SparkFatalException => fatal.throwable + case other => other + } + } + } + /** * Cancel all running stages with best effort and throw an Exception containing all stage * materialization errors and stage cancellation errors. 
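
The adaptive execution changes above record a stage's materialization failure in an AtomicReference and consult it (via assertStageNotFailed) before handing a cached or reused stage back into planning; the QueryStageExec diff below adds the underlying _error field and copies it into reused exchanges. A minimal sketch of the latch itself, independent of the real class hierarchy (illustrative only):

  import java.util.concurrent.atomic.AtomicReference

  final class StageStatus {
    // None means the stage has not failed; a Some is set when materialization throws.
    private val error = new AtomicReference[Option[Throwable]](None)

    def recordFailure(t: Throwable): Unit = error.set(Some(t))
    def hasFailed: Boolean = error.get().isDefined

    // Fail fast instead of replanning around a stage already known to be broken.
    def assertNotFailed(): Unit = error.get().foreach(t => throw t)
  }
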
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala index 51595e20ae5f8..2391fe740118d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryStageExec.scala @@ -93,6 +93,13 @@ abstract class QueryStageExec extends LeafExecNode { private[adaptive] def resultOption: AtomicReference[Option[Any]] = _resultOption final def isMaterialized: Boolean = resultOption.get().isDefined + @transient + @volatile + protected var _error = new AtomicReference[Option[Throwable]](None) + + def error: AtomicReference[Option[Throwable]] = _error + final def hasFailed: Boolean = _error.get().isDefined + override def output: Seq[Attribute] = plan.output override def outputPartitioning: Partitioning = plan.outputPartitioning override def outputOrdering: Seq[SortOrder] = plan.outputOrdering @@ -203,6 +210,7 @@ case class ShuffleQueryStageExec( ReusedExchangeExec(newOutput, shuffle), _canonicalized) reuse._resultOption = this._resultOption + reuse._error = this._error reuse } @@ -249,6 +257,7 @@ case class BroadcastQueryStageExec( ReusedExchangeExec(newOutput, broadcast), _canonicalized) reuse._resultOption = this._resultOption + reuse._error = this._error reuse } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 8f2b7ca5cba25..750b74aab384f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -682,7 +682,7 @@ case class HashAggregateExec( | $unsafeRowKeys, $unsafeRowKeyHash); | if ($unsafeRowBuffer == null) { | // failed to allocate the first page - | throw new $oomeClassName("No enough memory for aggregation"); + | throw new $oomeClassName("_LEGACY_ERROR_TEMP_3302", new java.util.HashMap()); | } |} """.stripMargin diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala index 3b1f349520f39..19a36483abe6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashMapGenerator.scala @@ -173,9 +173,9 @@ abstract class HashMapGenerator( ${hashBytes(bytes)} """ } - case st: StringType if st.supportsBinaryEquality && !st.usesTrimCollation => + case st: StringType if st.supportsBinaryEquality => hashBytes(s"$input.getBytes()") - case st: StringType if !st.supportsBinaryEquality || st.usesTrimCollation => + case st: StringType if !st.supportsBinaryEquality => hashLong(s"CollationFactory.fetchCollation(${st.collationId})" + s".hashFunction.applyAsLong($input)") case CalendarIntervalType => hashInt(s"$input.hashCode()") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 1ebf0d143bd1f..2f1cda9d0f9be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.aggregate +import java.util + import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.memory.SparkOutOfMemoryError @@ -210,7 +212,7 @@ class TungstenAggregationIterator( if (buffer == null) { // failed to allocate the first page // scalastyle:off throwerror - throw new SparkOutOfMemoryError("No enough memory for aggregation") + throw new SparkOutOfMemoryError("_LEGACY_ERROR_TEMP_3302", new util.HashMap()) // scalastyle:on throwerror } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 7946068b9452e..6e79a2f2a3267 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -585,14 +585,26 @@ object JdbcUtils extends Logging with SQLConfHelper { arr => new GenericArrayData(elementConversion(et0)(arr)) } + case IntegerType => arrayConverter[Int]((i: Int) => i) + case FloatType => arrayConverter[Float]((f: Float) => f) + case DoubleType => arrayConverter[Double]((d: Double) => d) + case ShortType => arrayConverter[Short]((s: Short) => s) + case BooleanType => arrayConverter[Boolean]((b: Boolean) => b) + case LongType => arrayConverter[Long]((l: Long) => l) + case _ => (array: Object) => array.asInstanceOf[Array[Any]] } (rs: ResultSet, row: InternalRow, pos: Int) => - val array = nullSafeConvert[java.sql.Array]( - input = rs.getArray(pos + 1), - array => new GenericArrayData(elementConversion(et)(array.getArray))) - row.update(pos, array) + try { + val array = nullSafeConvert[java.sql.Array]( + input = rs.getArray(pos + 1), + array => new GenericArrayData(elementConversion(et)(array.getArray()))) + row.update(pos, array) + } catch { + case e: java.lang.ClassCastException => + throw QueryExecutionErrors.wrongDatatypeInSomeRows(pos, dt) + } case NullType => (_: ResultSet, row: InternalRow, pos: Int) => row.update(pos, null) diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out index 133cd6a60a4fb..31919381c99b6 100644 --- a/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/analyzer-results/random.sql.out @@ -188,7 +188,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seed", + "inputName" : "`seed`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(10, 20, col)\"" }, @@ -211,7 +211,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "min", + "inputName" : "`min`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(col, 10, 0)\"" }, @@ -436,7 +436,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "length", + "inputName" : "`length`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, @@ -459,7 +459,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : 
{ "inputExpr" : "\"col\"", - "inputName" : "seedExpression", + "inputName" : "`seed`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, diff --git a/sql/core/src/test/resources/sql-tests/results/random.sql.out b/sql/core/src/test/resources/sql-tests/results/random.sql.out index 0b4e5e078ee15..01638abdcec6e 100644 --- a/sql/core/src/test/resources/sql-tests/results/random.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/random.sql.out @@ -240,7 +240,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seed", + "inputName" : "`seed`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(10, 20, col)\"" }, @@ -265,7 +265,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "min", + "inputName" : "`min`", "inputType" : "integer or floating-point", "sqlExpr" : "\"uniform(col, 10, 0)\"" }, @@ -520,7 +520,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "length", + "inputName" : "`length`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(col, 0)\"" }, @@ -545,7 +545,7 @@ org.apache.spark.sql.catalyst.ExtendedAnalysisException "sqlState" : "42K09", "messageParameters" : { "inputExpr" : "\"col\"", - "inputName" : "seedExpression", + "inputName" : "`seed`", "inputType" : "INT or SMALLINT", "sqlExpr" : "\"randstr(10, col)\"" }, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index ce6818652d2b5..d568cd77050fd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -2819,16 +2819,24 @@ class CollationSQLExpressionsSuite } } - test("collect_set supports collation") { + test("collect_set does not support collation") { val collation = "UNICODE" val query = s"SELECT collect_set(col) FROM VALUES ('a'), ('b'), ('a') AS tab(col);" withSQLConf(SqlApiConf.DEFAULT_COLLATION -> collation) { - val result = sql(query).collect().head.getSeq[String](0).toSet - val expected = Set("a", "b") - assert(result == expected) - // check result row data type - val dataType = ArrayType(StringType(collation), false) - assert(sql(query).schema.head.dataType == dataType) + checkError( + exception = intercept[AnalysisException] { + sql(query) + }, + condition = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", + sqlState = Some("42K09"), + parameters = Map( + "functionName" -> "`collect_set`", + "dataType" -> "\"MAP\" or \"COLLATED STRING\"", + "sqlExpr" -> "\"collect_set(col)\""), + context = ExpectedContext( + fragment = "collect_set(col)", + start = 7, + stop = 22)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 4234d73c1794d..b6da0b169f050 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1333,7 +1333,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { for (codeGen <- Seq("NO_CODEGEN", "CODEGEN_ONLY")) { val collationSetup = if (collation.isEmpty) "" else " COLLATE " + collation val 
supportsBinaryEquality = collation.isEmpty || collation == "UNICODE" || - CollationFactory.fetchCollation(collation).supportsBinaryEquality + CollationFactory.fetchCollation(collation).isUtf8BinaryType test(s"Group by on map containing$collationSetup strings ($codeGen)") { val tableName = "t" @@ -1558,7 +1558,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: HashJoin => b.leftKeys.head }.head.isInstanceOf[CollationKey]) @@ -1615,7 +1615,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function.asInstanceOf[LambdaFunction]. @@ -1676,7 +1676,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(collectFirst(queryPlan) { case b: BroadcastHashJoinExec => b.leftKeys.head }.head.asInstanceOf[ArrayTransform].function. @@ -1735,7 +1735,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. - if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) @@ -1794,7 +1794,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ) // Only if collation doesn't support binary equality, collation key should be injected. 
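
The assertions above now branch on isUtf8BinaryType, i.e. on the collation's base type, rather than on whether equality happens to be binary; that is why the test still lists UNICODE as a separate disjunct. A small check of the distinction, assuming the collation names resolve as they do elsewhere in this suite (illustrative only):

  import org.apache.spark.sql.catalyst.util.CollationFactory

  assert(CollationFactory.fetchCollation("UTF8_BINARY").isUtf8BinaryType)
  assert(!CollationFactory.fetchCollation("UTF8_LCASE").isUtf8BinaryType)
  // Not UTF8_BINARY-based, which is why the test above special-cases UNICODE.
  assert(!CollationFactory.fetchCollation("UNICODE").isUtf8BinaryType)
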
- if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) { + if (!CollationFactory.fetchCollation(t.collation).isUtf8BinaryType) { assert(queryPlan.toString().contains("collationkey")) } else { assert(!queryPlan.toString().contains("collationkey")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala index e6907b8656482..970ed5843b3c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CsvFunctionsSuite.scala @@ -352,12 +352,10 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_csv($"value", schema, Map("mode" -> "DROPMALFORMED"))).collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "from_csv", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> "FAILFAST")) + "funcName" -> "`from_csv`", + "mode" -> "DROPMALFORMED")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e80c3b23a7db3..25f4d9f62354a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -648,7 +648,7 @@ class DataFrameAggregateSuite extends QueryTest condition = "DATATYPE_MISMATCH.UNSUPPORTED_INPUT_TYPE", parameters = Map( "functionName" -> "`collect_set`", - "dataType" -> "\"MAP\"", + "dataType" -> "\"MAP\" or \"COLLATED STRING\"", "sqlExpr" -> "\"collect_set(b)\"" ), context = ExpectedContext( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 8c1cc6c3bea1d..48ea0e01a4372 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -39,6 +39,15 @@ import org.apache.spark.unsafe.types.CalendarInterval class DataFrameComplexTypeSuite extends QueryTest with SharedSparkSession { import testImplicits._ + test("ArrayTransform with scan input") { + withTempPath { f => + spark.sql("select array(array(1, null, 3), array(4, 5, null), array(null, 8, 9)) as a") + .write.parquet(f.getAbsolutePath) + val df = spark.read.parquet(f.getAbsolutePath).selectExpr("transform(a, (x, i) -> x)") + checkAnswer(df, Row(Seq(Seq(1, null, 3), Seq(4, 5, null), Seq(null, 8, 9)))) + } + } + test("UDF on struct") { val f = udf((a: String) => a) val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 47691e1ccd40f..39c839ae5a518 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -478,7 +478,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "length", + "inputName" -> "`length`", "inputType" -> "INT or SMALLINT", "inputExpr" -> "\"a\"", "sqlExpr" -> 
"\"randstr(a, 10)\""), @@ -530,7 +530,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { intercept[AnalysisException](df.select(expr)), condition = "DATATYPE_MISMATCH.NON_FOLDABLE_INPUT", parameters = Map( - "inputName" -> "min", + "inputName" -> "`min`", "inputType" -> "integer or floating-point", "inputExpr" -> "\"a\"", "sqlExpr" -> "\"uniform(a, 10)\""), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 7b19ad988d308..84408d8e2495d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -861,12 +861,10 @@ class JsonFunctionsSuite extends QueryTest with SharedSparkSession { exception = intercept[AnalysisException] { df.select(from_json($"value", schema, Map("mode" -> "DROPMALFORMED"))).collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "from_json", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> "FAILFAST")) + "funcName" -> "`from_json`", + "mode" -> "DROPMALFORMED")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index ec240d71b851f..c94f57a11426a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -352,6 +352,44 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession { // scalastyle:on } + test("UTF-8 string is valid") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(is_valid_utf8($"a")), Row(true)) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(is_valid_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(is_valid_utf8($"a")), Row(false)) + // scalastyle:on + } + + test("UTF-8 string make valid") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(make_valid_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(make_valid_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(make_valid_utf8($"a")), Row("\uFFFD")) + // scalastyle:on + } + + test("UTF-8 string validate") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(validate_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(validate_utf8($"b")), Row(null)) + checkError( + exception = intercept[SparkIllegalArgumentException] { + Seq(Array[Byte](-1)).toDF("a").select(validate_utf8($"a")).collect() + }, + condition = "INVALID_UTF8_STRING", + parameters = Map("str" -> "\\xFF") + ) + // scalastyle:on + } + + test("UTF-8 string try validate") { + // scalastyle:off + checkAnswer(Seq("大千世界").toDF("a").select(try_validate_utf8($"a")), Row("大千世界")) + checkAnswer(Seq(("abc", null)).toDF("a", "b").select(try_validate_utf8($"b")), Row(null)) + checkAnswer(Seq(Array[Byte](-1)).toDF("a").select(try_validate_utf8($"a")), Row(null)) + // scalastyle:on + } + test("string translate") { val df = Seq(("translate", "")).toDF("a", "b") checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala index 19d4ac23709b6..fe5c6ef004920 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantEndToEndSuite.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.sql -import org.apache.spark.SparkThrowable +import org.apache.spark.{SparkException, SparkRuntimeException} import org.apache.spark.sql.QueryTest.sameRows import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} @@ -359,16 +359,24 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession { val expectedMetadata: Array[Byte] = Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c') assert(actual === new VariantVal(expectedValue, expectedMetadata)) } - withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") { - val df = Seq(json).toDF("j") - .selectExpr("from_json(j,'variant')") - checkError( - exception = intercept[SparkThrowable] { + // Check whether the parse_json and from_json expressions throw the correct exception. + Seq("from_json(j, 'variant')", "parse_json(j)").foreach { expr => + withSQLConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS.key -> "false") { + val df = Seq(json).toDF("j").selectExpr(expr) + val exception = intercept[SparkException] { df.collect() - }, - condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", - parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST") - ) + } + checkError( + exception = exception, + condition = "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION", + parameters = Map("badRecord" -> json, "failFastMode" -> "FAILFAST") + ) + checkError( + exception = exception.getCause.asInstanceOf[SparkRuntimeException], + condition = "VARIANT_DUPLICATE_KEY", + parameters = Map("key" -> "a") + ) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index c5e64c96b2c8a..1df045764d8b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -936,7 +936,8 @@ class AdaptiveQueryExecSuite val error = intercept[SparkException] { joined.collect() } - assert(error.getMessage() contains "coalesce test error") + assert((Seq(error) ++ Option(error.getCause) ++ error.getSuppressed()).exists( + e => e.getMessage() != null && e.getMessage().contains("coalesce test error"))) val adaptivePlan = joined.queryExecution.executedPlan.asInstanceOf[AdaptiveSparkPlanExec] @@ -2829,6 +2830,38 @@ class AdaptiveQueryExecSuite assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) assert(findTopLevelUnion(adaptivePlan).size == 0) } + + withSQLConf( + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "100") { + withTempView("t1", "t2", "t3", "t4") { + Seq(1).toDF().createOrReplaceTempView("t1") + spark.range(100).createOrReplaceTempView("t2") + spark.range(2).createOrReplaceTempView("t3") + spark.range(2).createOrReplaceTempView("t4") + val (_, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT tt2.value + |FROM ( + | SELECT value + | FROM t1 + | WHERE NOT EXISTS ( + | SELECT 1 + | FROM ( + | SELECT t2.id + | FROM t2 + | JOIN t3 ON t2.id = t3.id + | AND t2.id > 100 + | ) tt + | WHERE t1.value = tt.id + | ) + | AND t1.value = 1 + |) tt2 + | LEFT JOIN t4 ON tt2.value = t4.id + |""".stripMargin + ) + assert(findTopLevelBroadcastNestedLoopJoin(adaptivePlan).size == 1) + } + } } test("SPARK-39915: 
Dataset.repartition(N) may not create N partitions") { @@ -3032,6 +3065,27 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-49979: AQE hang forever when collecting twice on a failed AQE plan") { + val func: Long => Boolean = (i : Long) => { + throw new Exception("SPARK-49979") + } + withUserDefinedFunction("func" -> true) { + spark.udf.register("func", func) + val df1 = spark.range(1024).select($"id".as("key1")) + val df2 = spark.range(2048).select($"id".as("key2")) + .withColumn("group_key", $"key2" % 1024) + val df = df1.filter(expr("func(key1)")).hint("MERGE").join(df2, $"key1" === $"key2") + .groupBy($"group_key").agg("key1" -> "count") + intercept[Throwable] { + df.collect() + } + // second collect should not hang forever + intercept[Throwable] { + df.collect() + } + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index c88f51a6b7d06..8091d6e64fdc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -1173,7 +1173,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { exception = intercept[SparkUnsupportedOperationException] { catalog.alterNamespace(testNs, NamespaceChange.removeProperty(p)) }, - condition = "_LEGACY_ERROR_TEMP_2069", + condition = "CANNOT_REMOVE_RESERVED_PROPERTY", parameters = Map("property" -> p)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala index 91f21c4a2ed34..059e4aadef2bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/xml/XmlSuite.scala @@ -1315,12 +1315,10 @@ class XmlSuite spark.sql(s"""SELECT schema_of_xml('1', map('mode', 'DROPMALFORMED'))""") .collect() }, - condition = "_LEGACY_ERROR_TEMP_1099", + condition = "PARSE_MODE_UNSUPPORTED", parameters = Map( - "funcName" -> "schema_of_xml", - "mode" -> "DROPMALFORMED", - "permissiveMode" -> "PERMISSIVE", - "failFastMode" -> FailFastMode.name) + "funcName" -> "`schema_of_xml`", + "mode" -> "DROPMALFORMED") ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index aca968745d198..0cc4f7bf2548e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -504,6 +504,12 @@ class ColumnVectorSuite extends SparkFunSuite with SQLHelper { val arr = new ColumnarArray(testVector, 0, testVector.capacity) assert(arr.toSeq(testVector.dataType) == expected) assert(arr.copy().toSeq(testVector.dataType) == expected) + + if (expected.nonEmpty) { + val withOffset = new ColumnarArray(testVector, 1, testVector.capacity - 1) + assert(withOffset.toSeq(testVector.dataType) == expected.tail) + assert(withOffset.copy().toSeq(testVector.dataType) == expected.tail) + } } testVectors("getInts with dictionary and nulls", 3, IntegerType) { testVector =>