diff --git a/assembly/pom.xml b/assembly/pom.xml index 62888c64f7ceb..1c0a0e2a23786 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 046abd63ba1cf..cd4bca564165f 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index b833b1fcbc008..186ded95e30ac 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 12082c9b0dfe6..4453d4f3a39b3 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index a9accba20c2d7..08d185e39ec01 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index c3d8242e71752..a2ff08edf2762 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 38aa65c957e40..41536893ea953 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 725a378a13e8c..4ba2298503336 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index fcdabb29537bd..3265a4f8b0fe7 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/docs/sql-data-sources-jdbc.md b/docs/sql-data-sources-jdbc.md index 315f47696475c..e9af0ba274d7b 100644 --- a/docs/sql-data-sources-jdbc.md +++ b/docs/sql-data-sources-jdbc.md @@ -9,9 +9,9 @@ license: | The ASF licenses this file to You under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -191,7 +191,7 @@ logging into the data sources. write - + cascadeTruncate the default cascading truncate behaviour of the JDBC database in question, specified in the isCascadeTruncate in each JDBCDialect @@ -241,7 +241,25 @@ logging into the data sources. pushDownAggregate false - The option to enable or disable aggregate push-down into the JDBC data source. 
The default value is false, in which case Spark will not push down aggregates to the JDBC data source. Otherwise, if sets to true, aggregates will be pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the related filters can be pushed down. Spark assumes that the data source can't fully complete the aggregate and does a final aggregate over the data source output. + The option to enable or disable aggregate push-down into the V2 JDBC data source. The default value is false, in which case Spark will not push down aggregates to the JDBC data source. Otherwise, if set to true, aggregates will be pushed down to the JDBC data source. Aggregate push-down is usually turned off when the aggregate is performed faster by Spark than by the JDBC data source. Please note that aggregates can be pushed down if and only if all the aggregate functions and the related filters can be pushed down. If numPartitions equals 1, or the GROUP BY key is the same as partitionColumn, Spark will push the aggregate down to the data source completely and not apply a final aggregate over the data source output. Otherwise, Spark will apply a final aggregate over the data source output. + + read + + + + pushDownLimit + false + + The option to enable or disable LIMIT push-down into the V2 JDBC data source. The LIMIT push-down also includes LIMIT + SORT, a.k.a. the Top N operator. The default value is false, in which case Spark does not push down LIMIT or LIMIT with SORT to the JDBC data source. Otherwise, if set to true, LIMIT or LIMIT with SORT is pushed down to the JDBC data source. If numPartitions is greater than 1, Spark still applies LIMIT or LIMIT with SORT on the result from the data source even if LIMIT or LIMIT with SORT is pushed down. Otherwise, if LIMIT or LIMIT with SORT is pushed down and numPartitions equals 1, Spark will not apply LIMIT or LIMIT with SORT on the result from the data source. + + read + + + + pushDownTableSample + false + + The option to enable or disable TABLESAMPLE push-down into the V2 JDBC data source. The default value is false, in which case Spark does not push down TABLESAMPLE to the JDBC data source. Otherwise, if set to true, TABLESAMPLE is pushed down to the JDBC data source. read @@ -288,7 +306,7 @@ logging into the data sources. Note that kerberos authentication with keytab is not always supported by the JDBC driver.
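For reference, a minimal sketch of how these push-down options can be enabled through a V2 JDBC catalog, mirroring the sparkConf settings used by the integration suites later in this change; the catalog name, URL, credentials, and table are illustrative placeholders rather than values taken from this patch:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog

// Register a V2 JDBC catalog named "postgresql" with aggregate, LIMIT, and
// TABLESAMPLE push-down enabled. URL and credentials are placeholders.
val spark = SparkSession.builder()
  .appName("jdbc-pushdown-sketch")
  .master("local[*]")
  .config("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName)
  .config("spark.sql.catalog.postgresql.url",
    "jdbc:postgresql://localhost:5432/postgres?user=postgres&password=secret")
  .config("spark.sql.catalog.postgresql.pushDownAggregate", "true")
  .config("spark.sql.catalog.postgresql.pushDownLimit", "true")
  .config("spark.sql.catalog.postgresql.pushDownTableSample", "true")
  .getOrCreate()

// Aggregates, LIMIT, and TABLESAMPLE over postgresql.<schema>.<table> can then be
// pushed to the database when the dialect supports them.
spark.sql("SELECT dept, AVG(bonus) FROM postgresql.public.employee GROUP BY dept").show()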
Before using keytab and principal configuration options, please make sure the following requirements are met: -* The included JDBC driver version supports kerberos authentication with keytab. +* The included JDBC driver version supports kerberos authentication with keytab. * There is a built-in connection provider which supports the used database. There is a built-in connection providers for the following databases: diff --git a/examples/pom.xml b/examples/pom.xml index 25a1a9131e65f..d0f2f1d724f2b 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/external/avro/pom.xml b/external/avro/pom.xml index bc12b51b13b5e..daddd8770efc0 100644 --- a/external/avro/pom.xml +++ b/external/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala index 144e9ad129feb..d0f38c12427c3 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScan.scala @@ -62,10 +62,6 @@ case class AvroScan( pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case a: AvroScan => super.equals(a) && dataSchema == a.dataSchema && options == a.options && equivalentFilters(pushedFilters, a.pushedFilters) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala index 9420608bb22ce..8fae89a945826 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/v2/avro/AvroScanBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.v2.avro import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -31,7 +31,7 @@ class AvroScanBuilder ( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { AvroScan( @@ -41,17 +41,16 @@ class AvroScanBuilder ( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.avroFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): 
Array[Filter] = _pushedFilters } diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 14d8da6a1613e..fb29aec8a3403 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml @@ -162,5 +162,10 @@ mssql-jdbc test + + mysql + mysql-connector-java + test + diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala index cb0dd1e37e9ff..35711e57d0b72 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2IntegrationSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection +import java.util.Locale import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -36,8 +37,9 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class DB2IntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "db2" + override val namespaceOpt: Option[String] = Some("DB2INST1") override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", "ibmcom/db2:11.5.4.0") override val env = Map( @@ -59,8 +61,13 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.db2", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.db2.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.db2.pushDownAggregate", "true") - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INTEGER, name VARCHAR(10), salary DECIMAL(20, 2), bonus DOUBLE)") + .executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -86,4 +93,17 @@ class DB2IntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + testVarPop() + testVarPop(true) + testVarSamp() + testVarSamp(true) + testStddevPop() + testStddevPop(true) + testStddevSamp() + testStddevSamp(true) + testCovarPop() + testCovarSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala new file mode 100644 index 0000000000000..f0e98fc2722b0 --- /dev/null +++ 
b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DB2NamespaceSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., ibmcom/db2:11.5.6.0a): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 DB2_DOCKER_IMAGE_NAME=ibmcom/db2:11.5.6.0a + * ./build/sbt -Pdocker-integration-tests "testOnly *v2.DB2NamespaceSuite" + * }}} + */ +@DockerTest +class DB2NamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("DB2_DOCKER_IMAGE_NAME", "ibmcom/db2:11.5.6.0a") + override val env = Map( + "DB2INST1_PASSWORD" -> "rootpass", + "LICENSE" -> "accept", + "DBNAME" -> "db2foo", + "ARCHIVE_LOGS" -> "false", + "AUTOCONFIG" -> "false" + ) + override val usesIpc = false + override val jdbcPort: Int = 50000 + override val privileged = true + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:db2://$ip:$port/db2foo:user=db2inst1;password=rootpass;retrieveMessagesFromServerOnGetMessage=true;" //scalastyle:ignore + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.ibm.db2.jcc.DB2Driver").asJava) + + catalog.initialize("db2", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("NULLID"), Array("SQLJ"), Array("SYSCAT"), Array("SYSFUN"), + Array("SYSIBM"), Array("SYSIBMADM"), Array("SYSIBMINTERNAL"), Array("SYSIBMTS"), + Array("SYSPROC"), Array("SYSPUBLIC"), Array("SYSSTAT"), Array("SYSTOOLS")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + builtinNamespaces ++ Array(namespace) + } + + override val supportsDropSchemaCascade: Boolean = false + + testListNamespaces() + testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala new file mode 100644 index 0000000000000..72edfc9f1bf1c --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/DockerJDBCIntegrationV2Suite.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import org.apache.spark.sql.jdbc.DockerJDBCIntegrationSuite + +abstract class DockerJDBCIntegrationV2Suite extends DockerJDBCIntegrationSuite { + + /** + * Prepare databases and tables for testing. + */ + override def dataPreparation(connection: Connection): Unit = { + tablePreparation(connection) + connection.prepareStatement("INSERT INTO employee VALUES (1, 'amy', 10000, 1000)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (2, 'alex', 12000, 1200)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (1, 'cathy', 9000, 1200)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (2, 'david', 10000, 1300)") + .executeUpdate() + connection.prepareStatement("INSERT INTO employee VALUES (6, 'jen', 12000, 1200)") + .executeUpdate() + } + + def tablePreparation(connection: Connection): Unit +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala index b9f5b774a5347..4df5f4525a0fa 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerIntegrationSuite.scala @@ -24,7 +24,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -36,7 +36,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mssql" @@ -57,10 +57,15 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBC override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mssql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.mssql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.mssql.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) - override def dataPreparation(conn: Connection): Unit = {} + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INT, name VARCHAR(32), salary NUMERIC(20, 2), bonus FLOAT)") + 
.executeUpdate() + } override def notSupportsTableComment: Boolean = true @@ -90,4 +95,13 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBC assert(msg.contains("UpdateColumnNullability is not supported")) } + + testVarPop() + testVarPop(true) + testVarSamp() + testVarSamp(true) + testStddevPop() + testStddevPop(true) + testStddevSamp() + testStddevSamp(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala new file mode 100644 index 0000000000000..aa8dac266380a --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MsSqlServerNamespaceSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., 2019-CU13-ubuntu-20.04): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 + * MSSQLSERVER_DOCKER_IMAGE_NAME=mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04 + * ./build/sbt -Pdocker-integration-tests "testOnly *v2.MsSqlServerNamespaceSuite" + * }}} + */ +@DockerTest +class MsSqlServerNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MSSQLSERVER_DOCKER_IMAGE_NAME", + "mcr.microsoft.com/mssql/server:2019-CU13-ubuntu-20.04") + override val env = Map( + "SA_PASSWORD" -> "Sapass123", + "ACCEPT_EULA" -> "Y" + ) + override val usesIpc = false + override val jdbcPort: Int = 1433 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:sqlserver://$ip:$port;user=sa;password=Sapass123;" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.microsoft.sqlserver.jdbc.SQLServerDriver").asJava) + + catalog.initialize("mssql", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("db_accessadmin"), Array("db_backupoperator"), Array("db_datareader"), + Array("db_datawriter"), Array("db_ddladmin"), Array("db_denydatareader"), + Array("db_denydatawriter"), Array("db_owner"), Array("db_securityadmin"), Array("dbo"), + Array("guest"), Array("INFORMATION_SCHEMA"), Array("sys")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + 
builtinNamespaces ++ Array(namespace) + } + + override val supportsSchemaComment: Boolean = false + + override val supportsDropSchemaCascade: Boolean = false + + testListNamespaces() + testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala index db626dfdf8c39..97f521a378eb7 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLIntegrationSuite.scala @@ -24,25 +24,22 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest /** - * - * To run this test suite for a specific version (e.g., mysql:5.7.31): + * To run this test suite for a specific version (e.g., mysql:5.7.36): * {{{ - * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.31 + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLIntegrationSuite" - * * }}} - * */ @DockerTest -class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class MySQLIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "mysql" override val db = new DatabaseOnDocker { - override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.31") + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") override val env = Map( "MYSQL_ROOT_PASSWORD" -> "rootpass" ) @@ -57,13 +54,17 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.mysql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.mysql.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.mysql.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) private var mySQLVersion = -1 - override def dataPreparation(conn: Connection): Unit = { - mySQLVersion = conn.getMetaData.getDatabaseMajorVersion + override def tablePreparation(connection: Connection): Unit = { + mySQLVersion = connection.getMetaData.getDatabaseMajorVersion + connection.prepareStatement( + "CREATE TABLE employee (dept INT, name VARCHAR(32), salary DECIMAL(20, 2)," + + " bonus DOUBLE)").executeUpdate() } override def testUpdateColumnType(tbl: String): Unit = { @@ -115,4 +116,13 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def supportsIndex: Boolean = true + + override def indexOptions: String = "KEY_BLOCK_SIZE=10" + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala new file 
mode 100644 index 0000000000000..d8dee61d70ea6 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/MySQLNamespaceSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.{Connection, SQLFeatureNotSupportedException} + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.NamespaceChange +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * To run this test suite for a specific version (e.g., mysql:5.7.36): + * {{{ + * ENABLE_DOCKER_INTEGRATION_TESTS=1 MYSQL_DOCKER_IMAGE_NAME=mysql:5.7.36 + * ./build/sbt -Pdocker-integration-tests "testOnly *v2*MySQLNamespaceSuite" + * }}} + */ +@DockerTest +class MySQLNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + override val imageName = sys.env.getOrElse("MYSQL_DOCKER_IMAGE_NAME", "mysql:5.7.36") + override val env = Map( + "MYSQL_ROOT_PASSWORD" -> "rootpass" + ) + override val usesIpc = false + override val jdbcPort: Int = 3306 + + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:mysql://$ip:$port/" + + s"mysql?user=root&password=rootpass&allowPublicKeyRetrieval=true&useSSL=false" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "com.mysql.jdbc.Driver").asJava) + + catalog.initialize("mysql", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("information_schema"), Array("mysql"), Array("performance_schema"), Array("sys")) + + override def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + Array(builtinNamespaces.head, namespace) ++ builtinNamespaces.tail + } + + override val supportsSchemaComment: Boolean = false + + override val supportsDropSchemaRestrict: Boolean = false + + testListNamespaces() + testDropNamespaces() + + test("Create or remove comment of namespace unsupported") { + val e1 = intercept[AnalysisException] { + catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) + } + assert(e1.getMessage.contains("Failed create name space: foo")) + assert(e1.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + assert(e1.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Create namespace comment is not supported")) + assert(catalog.namespaceExists(Array("foo")) === false) + catalog.createNamespace(Array("foo"), Map.empty[String, String].asJava) 
+ assert(catalog.namespaceExists(Array("foo")) === true) + val e2 = intercept[AnalysisException] { + catalog.alterNamespace(Array("foo"), NamespaceChange + .setProperty("comment", "comment for foo")) + } + assert(e2.getMessage.contains("Failed create comment on name space: foo")) + assert(e2.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + assert(e2.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Create namespace comment is not supported")) + val e3 = intercept[AnalysisException] { + catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) + } + assert(e3.getMessage.contains("Failed remove comment on name space: foo")) + assert(e3.getCause.isInstanceOf[SQLFeatureNotSupportedException]) + assert(e3.getCause.asInstanceOf[SQLFeatureNotSupportedException].getMessage + .contains("Remove namespace comment is not supported")) + catalog.dropNamespace(Array("foo"), cascade = true) + assert(catalog.namespaceExists(Array("foo")) === false) + } +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index 45d793aaa743e..b38f2675243e6 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.jdbc.v2 import java.sql.Connection +import java.util.Locale import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -53,8 +54,9 @@ import org.apache.spark.tags.DockerTest * It has been validated with 18.4.0 Express Edition. 
*/ @DockerTest -class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "oracle" + override val namespaceOpt: Option[String] = Some("SYSTEM") override val db = new DatabaseOnDocker { lazy override val imageName = sys.env("ORACLE_DOCKER_IMAGE_NAME") override val env = Map( @@ -69,9 +71,15 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.oracle", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.oracle.url", db.getJdbcUrl(dockerIp, externalPort)) + .set("spark.sql.catalog.oracle.pushDownAggregate", "true") override val connectionTimeout = timeout(7.minutes) - override def dataPreparation(conn: Connection): Unit = {} + + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept NUMBER(32), name VARCHAR2(32), salary NUMBER(20, 2)," + + " bonus BINARY_DOUBLE)").executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -89,4 +97,14 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest assert(msg1.contains( s"Cannot update $catalogName.alt_table field ID: string cannot be cast to int")) } + + override def caseConvert(tableName: String): String = tableName.toUpperCase(Locale.ROOT) + + testVarPop() + testVarSamp() + testStddevPop() + testStddevSamp() + testCovarPop() + testCovarSamp() + testCorr() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala new file mode 100644 index 0000000000000..31f26d2990666 --- /dev/null +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleNamespaceSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.jdbc.v2 + +import java.sql.Connection + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.tags.DockerTest + +/** + * The following are the steps to test this: + * + * 1. 
Choose to use a prebuilt image or build Oracle database in a container + * - The documentation on how to build Oracle RDBMS in a container is at + * https://github.com/oracle/docker-images/blob/master/OracleDatabase/SingleInstance/README.md + * - Official Oracle container images can be found at https://container-registry.oracle.com + * - A trustable and streamlined Oracle XE database image can be found on Docker Hub at + * https://hub.docker.com/r/gvenzl/oracle-xe see also https://github.com/gvenzl/oci-oracle-xe + * 2. Run: export ORACLE_DOCKER_IMAGE_NAME=image_you_want_to_use_for_testing + * - Example: export ORACLE_DOCKER_IMAGE_NAME=gvenzl/oracle-xe:latest + * 3. Run: export ENABLE_DOCKER_INTEGRATION_TESTS=1 + * 4. Start docker: sudo service docker start + * - Optionally, docker pull $ORACLE_DOCKER_IMAGE_NAME + * 5. Run Spark integration tests for Oracle with: ./build/sbt -Pdocker-integration-tests + * "testOnly org.apache.spark.sql.jdbc.v2.OracleNamespaceSuite" + * + * A sequence of commands to build the Oracle XE database container image: + * $ git clone https://github.com/oracle/docker-images.git + * $ cd docker-images/OracleDatabase/SingleInstance/dockerfiles + * $ ./buildContainerImage.sh -v 18.4.0 -x + * $ export ORACLE_DOCKER_IMAGE_NAME=oracle/database:18.4.0-xe + * + * This procedure has been validated with Oracle 18.4.0 Express Edition. + */ +@DockerTest +class OracleNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNamespaceTest { + override val db = new DatabaseOnDocker { + lazy override val imageName = + sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "gvenzl/oracle-xe:18.4.0") + val oracle_password = "Th1s1sThe0racle#Pass" + override val env = Map( + "ORACLE_PWD" -> oracle_password, // oracle images uses this + "ORACLE_PASSWORD" -> oracle_password // gvenzl/oracle-xe uses this + ) + override val usesIpc = false + override val jdbcPort: Int = 1521 + override def getJdbcUrl(ip: String, port: Int): String = + s"jdbc:oracle:thin:system/$oracle_password@//$ip:$port/xe" + } + + val map = new CaseInsensitiveStringMap( + Map("url" -> db.getJdbcUrl(dockerIp, externalPort), + "driver" -> "oracle.jdbc.OracleDriver").asJava) + + catalog.initialize("system", map) + + override def dataPreparation(conn: Connection): Unit = {} + + override def builtinNamespaces: Array[Array[String]] = + Array(Array("ANONYMOUS"), Array("APEX_030200"), Array("APEX_PUBLIC_USER"), Array("APPQOSSYS"), + Array("BI"), Array("DIP"), Array("FLOWS_FILES"), Array("HR"), Array("OE"), Array("PM"), + Array("SCOTT"), Array("SH"), Array("SPATIAL_CSW_ADMIN_USR"), Array("SPATIAL_WFS_ADMIN_USR"), + Array("XS$NULL")) + + // Cannot create schema dynamically + // TODO testListNamespaces() + // TODO testDropNamespaces() +} diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala index 932ddb90f6cb0..d76e13c1cd421 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresIntegrationSuite.scala @@ -22,7 +22,7 @@ import java.sql.Connection import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.jdbc.{DatabaseOnDocker, DockerJDBCIntegrationSuite} +import 
org.apache.spark.sql.jdbc.DatabaseOnDocker import org.apache.spark.sql.types._ import org.apache.spark.tags.DockerTest @@ -34,7 +34,7 @@ import org.apache.spark.tags.DockerTest * }}} */ @DockerTest -class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTest { +class PostgresIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTest { override val catalogName: String = "postgresql" override val db = new DatabaseOnDocker { override val imageName = sys.env.getOrElse("POSTGRES_DOCKER_IMAGE_NAME", "postgres:13.0-alpine") @@ -49,7 +49,15 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes override def sparkConf: SparkConf = super.sparkConf .set("spark.sql.catalog.postgresql", classOf[JDBCTableCatalog].getName) .set("spark.sql.catalog.postgresql.url", db.getJdbcUrl(dockerIp, externalPort)) - override def dataPreparation(conn: Connection): Unit = {} + .set("spark.sql.catalog.postgresql.pushDownTableSample", "true") + .set("spark.sql.catalog.postgresql.pushDownLimit", "true") + .set("spark.sql.catalog.postgresql.pushDownAggregate", "true") + + override def tablePreparation(connection: Connection): Unit = { + connection.prepareStatement( + "CREATE TABLE employee (dept INTEGER, name VARCHAR(32), salary NUMERIC(20, 2)," + + " bonus double precision)").executeUpdate() + } override def testUpdateColumnType(tbl: String): Unit = { sql(s"CREATE TABLE $tbl (ID INTEGER)") @@ -75,4 +83,25 @@ class PostgresIntegrationSuite extends DockerJDBCIntegrationSuite with V2JDBCTes val expectedSchema = new StructType().add("ID", IntegerType, true, defaultMetadata) assert(t.schema === expectedSchema) } + + override def supportsTableSample: Boolean = true + + override def supportsIndex: Boolean = true + + override def indexOptions: String = "FILLFACTOR=70" + + testVarPop() + testVarPop(true) + testVarSamp() + testVarSamp(true) + testStddevPop() + testStddevPop(true) + testStddevSamp() + testStddevSamp(true) + testCovarPop() + testCovarPop(true) + testCovarSamp() + testCovarSamp(true) + testCorr() + testCorr(true) } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala index b5cf3dfcb474d..4a615bddd7dfb 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/PostgresNamespaceSuite.scala @@ -53,7 +53,9 @@ class PostgresNamespaceSuite extends DockerJDBCIntegrationSuite with V2JDBCNames override def dataPreparation(conn: Connection): Unit = {} - override def builtinNamespaces: Array[Array[String]] = { + override def builtinNamespaces: Array[Array[String]] = Array(Array("information_schema"), Array("pg_catalog"), Array("public")) - } + + testListNamespaces() + testDropNamespaces() } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala index 95d59fec2fac6..bae0d7c361635 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCNamespaceTest.scala @@ -17,47 +17,117 @@ package org.apache.spark.sql.jdbc.v2 +import java.util +import 
java.util.Collections + import scala.collection.JavaConverters._ -import org.apache.log4j.Level +import org.apache.logging.log4j.Level import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.NamespaceChange +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException +import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} import org.apache.spark.tags.DockerTest @DockerTest private[v2] trait V2JDBCNamespaceTest extends SharedSparkSession with DockerIntegrationFunSuite { val catalog = new JDBCTableCatalog() + private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val schema: StructType = new StructType() + .add("id", IntegerType) + .add("data", StringType) + def builtinNamespaces: Array[Array[String]] - test("listNamespaces: basic behavior") { - catalog.createNamespace(Array("foo"), Map("comment" -> "test comment").asJava) - assert(catalog.listNamespaces() === Array(Array("foo")) ++ builtinNamespaces) - assert(catalog.listNamespaces(Array("foo")) === Array()) - assert(catalog.namespaceExists(Array("foo")) === true) - - val logAppender = new LogAppender("catalog comment") - withLogAppender(logAppender) { - catalog.alterNamespace(Array("foo"), NamespaceChange - .setProperty("comment", "comment for foo")) - catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) + def listNamespaces(namespace: Array[String]): Array[Array[String]] = { + Array(namespace) ++ builtinNamespaces + } + + def supportsSchemaComment: Boolean = true + + def supportsDropSchemaCascade: Boolean = true + + def supportsDropSchemaRestrict: Boolean = true + + def testListNamespaces(): Unit = { + test("listNamespaces: basic behavior") { + val commentMap = if (supportsSchemaComment) { + Map("comment" -> "test comment") + } else { + Map.empty[String, String] + } + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.listNamespaces() === listNamespaces(Array("foo"))) + assert(catalog.listNamespaces(Array("foo")) === Array()) + assert(catalog.namespaceExists(Array("foo")) === true) + + if (supportsSchemaComment) { + val logAppender = new LogAppender("catalog comment") + withLogAppender(logAppender) { + catalog.alterNamespace(Array("foo"), NamespaceChange + .setProperty("comment", "comment for foo")) + catalog.alterNamespace(Array("foo"), NamespaceChange.removeProperty("comment")) + } + val createCommentWarning = logAppender.loggingEvents + .filter(_.getLevel == Level.WARN) + .map(_.getMessage.getFormattedMessage) + .exists(_.contains("catalog comment")) + assert(createCommentWarning === false) + } + + if (supportsDropSchemaRestrict) { + catalog.dropNamespace(Array("foo"), cascade = false) + } else { + catalog.dropNamespace(Array("foo"), cascade = true) + } + assert(catalog.namespaceExists(Array("foo")) === false) + assert(catalog.listNamespaces() === builtinNamespaces) + val msg = intercept[AnalysisException] { + catalog.listNamespaces(Array("foo")) + }.getMessage + assert(msg.contains("Namespace 'foo' not found")) + } + } + + def testDropNamespaces(): Unit = { + test("Drop namespace") { + val ident1 = Identifier.of(Array("foo"), "tab") + // Drop empty namespace without cascade + val commentMap = if (supportsSchemaComment) 
{ + Map("comment" -> "test comment") + } else { + Map.empty[String, String] + } + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + if (supportsDropSchemaRestrict) { + catalog.dropNamespace(Array("foo"), cascade = false) + } else { + catalog.dropNamespace(Array("foo"), cascade = true) + } + assert(catalog.namespaceExists(Array("foo")) === false) + + // Drop non empty namespace without cascade + catalog.createNamespace(Array("foo"), commentMap.asJava) + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.createTable(ident1, schema, Array.empty, emptyProps) + if (supportsDropSchemaRestrict) { + intercept[NonEmptyNamespaceException] { + catalog.dropNamespace(Array("foo"), cascade = false) + } + } + + // Drop non empty namespace with cascade + if (supportsDropSchemaCascade) { + assert(catalog.namespaceExists(Array("foo")) === true) + catalog.dropNamespace(Array("foo"), cascade = true) + assert(catalog.namespaceExists(Array("foo")) === false) + } } - val createCommentWarning = logAppender.loggingEvents - .filter(_.getLevel == Level.WARN) - .map(_.getRenderedMessage) - .exists(_.contains("catalog comment")) - assert(createCommentWarning === false) - - catalog.dropNamespace(Array("foo")) - assert(catalog.namespaceExists(Array("foo")) === false) - assert(catalog.listNamespaces() === builtinNamespaces) - val msg = intercept[AnalysisException] { - catalog.listNamespaces(Array("foo")) - }.getMessage - assert(msg.contains("Namespace 'foo' not found")) } } diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala index 1afe26afe1a9f..7cab8cd77df66 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/V2JDBCTest.scala @@ -19,7 +19,13 @@ package org.apache.spark.sql.jdbc.v2 import org.apache.log4j.Level -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, DataFrame} +import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sample} +import org.apache.spark.sql.connector.catalog.{Catalogs, Identifier, TableCatalog} +import org.apache.spark.sql.connector.catalog.index.SupportsIndex +import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.jdbc.DockerIntegrationFunSuite import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -27,7 +33,15 @@ import org.apache.spark.tags.DockerTest @DockerTest private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFunSuite { + import testImplicits._ + val catalogName: String + + val namespaceOpt: Option[String] = None + + private def catalogAndNamespace = + namespaceOpt.map(namespace => s"$catalogName.$namespace").getOrElse(catalogName) + // dialect specific update column type test def testUpdateColumnType(tbl: String): Unit @@ -180,5 +194,313 @@ private[v2] trait V2JDBCTest extends SharedSparkSession with DockerIntegrationFu testCreateTableWithProperty(s"$catalogName.new_table") } } -} + def supportsIndex: Boolean = false + + def indexOptions: String = "" + + 
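As a usage note for the shared index test that follows, a minimal sketch of the DDL it exercises, assuming a SparkSession named spark with a V2 JDBC catalog "mysql" registered (as in MySQLIntegrationSuite above) and a table mysql.new_table already created; the catalog, table, column names, and index option are placeholders:

// CREATE/DROP INDEX statements routed through the V2 JDBC catalog; MySQL accepts
// the BTREE index type and the KEY_BLOCK_SIZE option used by the suite overrides.
spark.sql("CREATE INDEX i1 ON mysql.new_table USING BTREE (col1)")
spark.sql("CREATE INDEX i2 ON mysql.new_table (col2, col3) OPTIONS (KEY_BLOCK_SIZE=10)")
spark.sql("CREATE INDEX IF NOT EXISTS i1 ON mysql.new_table (col1)") // passes; index exists
spark.sql("DROP INDEX i2 ON mysql.new_table")
spark.sql("DROP INDEX IF EXISTS i1 ON mysql.new_table")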
test("SPARK-36895: Test INDEX Using SQL") { + if (supportsIndex) { + withTable(s"$catalogName.new_table") { + sql(s"CREATE TABLE $catalogName.new_table(col1 INT, col2 INT, col3 INT," + + " col4 INT, col5 INT)") + val loaded = Catalogs.load(catalogName, conf) + val jdbcTable = loaded.asInstanceOf[TableCatalog] + .loadTable(Identifier.of(Array.empty[String], "new_table")) + .asInstanceOf[SupportsIndex] + assert(jdbcTable.indexExists("i1") == false) + assert(jdbcTable.indexExists("i2") == false) + + val indexType = "DUMMY" + var m = intercept[UnsupportedOperationException] { + sql(s"CREATE index i1 ON $catalogName.new_table USING $indexType (col1)") + }.getMessage + assert(m.contains(s"Index Type $indexType is not supported." + + s" The supported Index Types are:")) + + sql(s"CREATE index i1 ON $catalogName.new_table USING BTREE (col1)") + sql(s"CREATE index i2 ON $catalogName.new_table (col2, col3, col5)" + + s" OPTIONS ($indexOptions)") + + assert(jdbcTable.indexExists("i1") == true) + assert(jdbcTable.indexExists("i2") == true) + + // This should pass without exception + sql(s"CREATE index IF NOT EXISTS i1 ON $catalogName.new_table (col1)") + + m = intercept[IndexAlreadyExistsException] { + sql(s"CREATE index i1 ON $catalogName.new_table (col1)") + }.getMessage + assert(m.contains("Failed to create index i1 in new_table")) + + sql(s"DROP index i1 ON $catalogName.new_table") + sql(s"DROP index i2 ON $catalogName.new_table") + + assert(jdbcTable.indexExists("i1") == false) + assert(jdbcTable.indexExists("i2") == false) + + // This should pass without exception + sql(s"DROP index IF EXISTS i1 ON $catalogName.new_table") + + m = intercept[NoSuchIndexException] { + sql(s"DROP index i1 ON $catalogName.new_table") + }.getMessage + assert(m.contains("Failed to drop index i1 in new_table")) + } + } + } + + def supportsTableSample: Boolean = false + + private def checkSamplePushed(df: DataFrame, pushed: Boolean = true): Unit = { + val sample = df.queryExecution.optimizedPlan.collect { + case s: Sample => s + } + if (pushed) { + assert(sample.isEmpty) + } else { + assert(sample.nonEmpty) + } + } + + private def checkFilterPushed(df: DataFrame, pushed: Boolean = true): Unit = { + val filter = df.queryExecution.optimizedPlan.collect { + case f: Filter => f + } + if (pushed) { + assert(filter.isEmpty) + } else { + assert(filter.nonEmpty) + } + } + + private def limitPushed(df: DataFrame, limit: Int): Boolean = { + df.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + return v1.pushedDownOperators.limit == Some(limit) + } + } + false + } + + private def checkColumnPruned(df: DataFrame, col: String): Unit = { + val scan = df.queryExecution.optimizedPlan.collectFirst { + case s: DataSourceV2ScanRelation => s + }.get + assert(scan.schema.names.sameElements(Seq(col))) + } + + test("SPARK-37038: Test TABLESAMPLE") { + if (supportsTableSample) { + withTable(s"$catalogName.new_table") { + sql(s"CREATE TABLE $catalogName.new_table (col1 INT, col2 INT)") + spark.range(10).select($"id" * 2, $"id" * 2 + 1).write.insertInto(s"$catalogName.new_table") + + // sample push down + column pruning + val df1 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + + " REPEATABLE (12345)") + checkSamplePushed(df1) + checkColumnPruned(df1, "col1") + assert(df1.collect().length < 10) + + // sample push down only + val df2 = sql(s"SELECT * FROM $catalogName.new_table TABLESAMPLE (50 PERCENT)" + + " REPEATABLE 
(12345)") + checkSamplePushed(df2) + assert(df2.collect().length < 10) + + // sample(BUCKET ... OUT OF) push down + limit push down + column pruning + val df3 = sql(s"SELECT col1 FROM $catalogName.new_table TABLESAMPLE (BUCKET 6 OUT OF 10)" + + " LIMIT 2") + checkSamplePushed(df3) + assert(limitPushed(df3, 2)) + checkColumnPruned(df3, "col1") + assert(df3.collect().length <= 2) + + // sample(... PERCENT) push down + limit push down + column pruning + val df4 = sql(s"SELECT col1 FROM $catalogName.new_table" + + " TABLESAMPLE (50 PERCENT) REPEATABLE (12345) LIMIT 2") + checkSamplePushed(df4) + assert(limitPushed(df4, 2)) + checkColumnPruned(df4, "col1") + assert(df4.collect().length <= 2) + + // sample push down + filter push down + limit push down + val df5 = sql(s"SELECT * FROM $catalogName.new_table" + + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") + checkSamplePushed(df5) + checkFilterPushed(df5) + assert(limitPushed(df5, 2)) + assert(df5.collect().length <= 2) + + // sample + filter + limit + column pruning + // sample pushed down, filer/limit not pushed down, column pruned + // Todo: push down filter/limit + val df6 = sql(s"SELECT col1 FROM $catalogName.new_table" + + " TABLESAMPLE (BUCKET 6 OUT OF 10) WHERE col1 > 0 LIMIT 2") + checkSamplePushed(df6) + checkFilterPushed(df6, false) + assert(!limitPushed(df6, 2)) + checkColumnPruned(df6, "col1") + assert(df6.collect().length <= 2) + + // sample + limit + // Push down order is sample -> filter -> limit + // only limit is pushed down because in this test sample is after limit + val df7 = spark.read.table(s"$catalogName.new_table").limit(2).sample(0.5) + checkSamplePushed(df7, false) + assert(limitPushed(df7, 2)) + + // sample + filter + // Push down order is sample -> filter -> limit + // only filter is pushed down because in this test sample is after filter + val df8 = spark.read.table(s"$catalogName.new_table").where($"col1" > 1).sample(0.5) + checkSamplePushed(df8, false) + checkFilterPushed(df8) + assert(df8.collect().length < 10) + } + } + } + + protected def checkAggregateRemoved(df: DataFrame): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg + } + assert(aggregates.isEmpty) + } + + private def checkAggregatePushed(df: DataFrame, funcName: String): Unit = { + df.queryExecution.optimizedPlan.collect { + case DataSourceV2ScanRelation(_, scan, _) => + assert(scan.isInstanceOf[V1ScanWrapper]) + val wrapper = scan.asInstanceOf[V1ScanWrapper] + assert(wrapper.pushedDownOperators.aggregation.isDefined) + val aggregationExpressions = + wrapper.pushedDownOperators.aggregation.get.aggregateExpressions() + assert(aggregationExpressions.length == 1) + assert(aggregationExpressions(0).isInstanceOf[GeneralAggregateFunc]) + assert(aggregationExpressions(0).asInstanceOf[GeneralAggregateFunc].name() == funcName) + } + } + + protected def caseConvert(tableName: String): String = tableName + + protected def testVarPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: VAR_POP with distinct: $isDistinct") { + val df = sql(s"SELECT VAR_POP(${distinct}bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "VAR_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 10000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testVarSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: VAR_SAMP with distinct: $isDistinct") { + val df = sql( + s"SELECT VAR_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "VAR_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000d) + assert(row(1).getDouble(0) === 5000d) + assert(row(2).isNullAt(0)) + } + } + + protected def testStddevPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: STDDEV_POP with distinct: $isDistinct") { + val df = sql( + s"SELECT STDDEV_POP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "STDDEV_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 100d) + assert(row(1).getDouble(0) === 50d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testStddevSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: STDDEV_SAMP with distinct: $isDistinct") { + val df = sql( + s"SELECT STDDEV_SAMP(${distinct}bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "STDDEV_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 141.4213562373095d) + assert(row(1).getDouble(0) === 70.71067811865476d) + assert(row(2).isNullAt(0)) + } + } + + protected def testCovarPop(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: COVAR_POP with distinct: $isDistinct") { + val df = sql( + s"SELECT COVAR_POP(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "COVAR_POP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 10000d) + assert(row(1).getDouble(0) === 2500d) + assert(row(2).getDouble(0) === 0d) + } + } + + protected def testCovarSamp(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: COVAR_SAMP with distinct: $isDistinct") { + val df = sql( + s"SELECT COVAR_SAMP(${distinct}bonus, bonus) FROM $catalogAndNamespace." 
+ + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "COVAR_SAMP") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 20000d) + assert(row(1).getDouble(0) === 5000d) + assert(row(2).isNullAt(0)) + } + } + + protected def testCorr(isDistinct: Boolean = false): Unit = { + val distinct = if (isDistinct) "DISTINCT " else "" + test(s"scan with aggregate push-down: CORR with distinct: $isDistinct") { + val df = sql( + s"SELECT CORR(${distinct}bonus, bonus) FROM $catalogAndNamespace." + + s"${caseConvert("employee")} WHERE dept > 0 GROUP BY dept ORDER BY dept") + checkFilterPushed(df) + checkAggregateRemoved(df) + checkAggregatePushed(df, "CORR") + val row = df.collect() + assert(row.length === 3) + assert(row(0).getDouble(0) === 1d) + assert(row(1).getDouble(0) === 1d) + assert(row(2).isNullAt(0)) + } + } +} diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 8ef608a021fe7..dea5a0a23c92a 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 649c7af28e92a..148a9625fed19 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kafka-0-10-token-provider/pom.xml b/external/kafka-0-10-token-provider/pom.xml index dfbaa18d4c698..882c194653c84 100644 --- a/external/kafka-0-10-token-provider/pom.xml +++ b/external/kafka-0-10-token-provider/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index eb5c9d4c9ca33..6774d254e80df 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 5d087dc9dd633..ab23f231fd978 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index e02048e28a2f7..55c892bdb7510 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index a8338509ad0e2..87259b3cd607e 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index da5bc3e8fcbce..5e2b40fd8d917 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index 318921a298493..1206506f214e6 100644 --- a/hadoop-cloud/pom.xml +++ 
b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index 8d66de29b51bc..b609eaf181019 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index d9c22bf33e8e3..aae11098ae944 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index fcd7ade1810aa..796d55e7d7785 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/pom.xml b/pom.xml index 09de5d6f45ff7..25243d6e1132d 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 pom Spark Project Parent POM http://spark.apache.org/ diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dba74ac9bb217..cc148d9e247f6 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -34,8 +34,75 @@ import com.typesafe.tools.mima.core.ProblemFilters._ */ object MimaExcludes { - // Exclude rules for 3.2.x - lazy val v32excludes = v31excludes ++ Seq( + // Exclude rules for 3.4.x + lazy val v34excludes = v33excludes ++ Seq( + ) + + // Exclude rules for 3.3.x from 3.2.0 + lazy val v33excludes = v32excludes ++ Seq( + // [SPARK-35672][CORE][YARN] Pass user classpath entries to executors using config instead of command line + // The following are necessary for Scala 2.13. 
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.CoarseGrainedExecutorBackend#Arguments.*"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.CoarseGrainedExecutorBackend#Arguments.*"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.CoarseGrainedExecutorBackend$Arguments$"), + + // [SPARK-37391][SQL] JdbcConnectionProvider tells if it modifies security context + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.jdbc.JdbcConnectionProvider.modifiesSecurityContext"), + + // [SPARK-37780][SQL] QueryExecutionListener support SQLConf as constructor parameter + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this"), + // [SPARK-37786][SQL] StreamingQueryListener support use SQLConf.get to get corresponding SessionState's SQLConf + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.this"), + // [SPARK-38432][SQL] Reactor framework so as JDBC dialect could compile filter by self way + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.toV2"), + + // [SPARK-37600][BUILD] Upgrade to Hadoop 3.3.2 + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4Compressor"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4Factory"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.hadoop.shaded.net.jpountz.lz4.LZ4SafeDecompressor") + ) + + // Exclude rules for 3.2.x from 3.1.1 + lazy val v32excludes = Seq( + // Spark Internals + ProblemFilters.exclude[Problem]("org.apache.spark.rpc.*"), + ProblemFilters.exclude[Problem]("org.spark-project.jetty.*"), + ProblemFilters.exclude[Problem]("org.spark_project.jetty.*"), + ProblemFilters.exclude[Problem]("org.sparkproject.jetty.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.internal.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.unused.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.unsafe.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.memory.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.util.collection.unsafe.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.errors.*"), + // DSv2 catalog and expression APIs are unstable yet. We should enable this back. + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.catalog.*"), + ProblemFilters.exclude[Problem]("org.apache.spark.sql.connector.expressions.*"), + // Avro source implementation is internal. 
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.v2.avro.*"), + + // [SPARK-34848][CORE] Add duration to TaskMetricDistributions + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this"), + + // [SPARK-34488][CORE] Support task Metrics Distributions and executor Metrics Distributions + // in the REST API call for a specified stage + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), + + // [SPARK-36173][CORE] Support getting CPU number in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.cpus"), + + // [SPARK-35896] Include more granular metrics for stateful operators in StreamingQueryProgress + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), + + (problem: Problem) => problem match { + case MissingClassProblem(cls) => !cls.fullName.startsWith("org.sparkproject.jpmml") && + !cls.fullName.startsWith("org.sparkproject.dmg.pmml") + case _ => true + }, + // [SPARK-33808][SQL] DataSource V2: Build logical writes in the optimizer ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.connector.write.V1WriteBuilder"), @@ -72,1722 +139,10 @@ object MimaExcludes { ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions") ) - // Exclude rules for 3.1.x - lazy val v31excludes = v30excludes ++ Seq( - // mima plugin update caused new incompatibilities to be detected - // core module - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.shuffle.sort.io.LocalDiskShuffleMapOutputWriter.commitAllPartitions"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.shuffle.api.ShuffleMapOutputWriter.commitAllPartitions"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.environmentDetails"), - // mllib module - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.totalIterations"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.$init$"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.labels"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.truePositiveRateByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.falsePositiveRateByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.precisionByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.recallByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.accuracy"), - 
ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedTruePositiveRate"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFalsePositiveRate"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.roc"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.areaUnderROC"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.pr"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.fMeasureByThreshold"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.precisionByThreshold"), - ProblemFilters.exclude[NewMixinForwarderProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.recallByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.FMClassifier.trainImpl"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.FMRegressor.trainImpl"), - // [SPARK-31077] Remove ChiSqSelector dependency on mllib.ChiSqSelectorModel - // private constructor - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.this"), - - // [SPARK-31127] Implement abstract Selector - // org.apache.spark.ml.feature.ChiSqSelectorModel type hierarchy change - // before: class ChiSqSelector extends Estimator with ChiSqSelectorParams - // after: class ChiSqSelector extends PSelector - // false positive, no binary incompatibility - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.ChiSqSelector"), - - // [SPARK-24634] Add a new metric regarding number of inputs later than watermark plus allowed delay - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.$default$4"), - - //[SPARK-31893] Add a generic ClassificationSummary trait - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionTrainingSummary.weightCol"), - 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.weightCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$_setter_$org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$binaryMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$BinaryClassificationSummary$$sparkSession"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$_setter_$org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics_="), - 
ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.org$apache$spark$ml$classification$ClassificationSummary$$multiclassMetrics"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.weightCol"), - - // [SPARK-32879] Pass SparkSession.Builder options explicitly to SparkSession - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SparkSession.this") - ) - - // Exclude rules for 3.0.x - lazy val v30excludes = v24excludes ++ Seq( - // [SPARK-23429][CORE] Add executor memory metrics to heartbeat and expose in executors REST API - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.this"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate$"), - - // [SPARK-29306] Add support for Stage level scheduling for executors - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.productPrefix"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages#RetrieveSparkAppConfig.toString"), - - // [SPARK-29399][core] Remove old ExecutorPlugin interface. 
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ExecutorPlugin"), - - // [SPARK-28980][SQL][CORE][MLLIB] Remove more old deprecated items in Spark 3 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.clustering.KMeans.train"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.classification.LogisticRegressionWithSGD$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.classification.LogisticRegressionWithSGD.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.isSorted"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.RidgeRegressionWithSGD$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.RidgeRegressionWithSGD.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.LassoWithSGD.this"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LassoWithSGD$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD$"), - - // [SPARK-28486][CORE][PYTHON] Map PythonBroadcast's data file to a BroadcastBlock to avoid delete by GC - ProblemFilters.exclude[InaccessibleMethodProblem]("java.lang.Object.finalize"), - - // [SPARK-27366][CORE] Support GPU Resources in Spark job scheduling - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.resources"), - - // [SPARK-29417][CORE] Resource Scheduling - add TaskContext.resource java api - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.resourcesJMap"), - - // [SPARK-27410][MLLIB] Remove deprecated / no-op mllib.KMeans getRuns, setRuns - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.getRuns"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeans.setRuns"), - - // [SPARK-26580][SQL][ML][FOLLOW-UP] Throw exception when use untyped UDF by default - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.UnaryTransformer.this"), - - // [SPARK-27090][CORE] Removing old LEGACY_DRIVER_IDENTIFIER ("") - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.LEGACY_DRIVER_IDENTIFIER"), - - // [SPARK-25838] Remove formatVersion from Saveable - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.LocalLDAModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.KMeansModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.PowerIterationClusteringModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.GaussianMixtureModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.recommendation.MatrixFactorizationModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.ChiSqSelectorModel.formatVersion"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.Word2VecModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.classification.SVMModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.classification.LogisticRegressionModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.classification.NaiveBayesModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.Saveable.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.FPGrowthModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.PrefixSpanModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.IsotonicRegressionModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.RidgeRegressionModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.LassoModel.formatVersion"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionModel.formatVersion"), - - // [SPARK-26132] Remove support for Scala 2.11 in Spark 3.0.0 - ProblemFilters.exclude[DirectAbstractMethodProblem]("scala.concurrent.Future.transformWith"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("scala.concurrent.Future.transform"), - - // [SPARK-26254][CORE] Extract Hive + Kafka dependencies from Core. - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.deploy.security.HiveDelegationTokenProvider"), - - // [SPARK-26329][CORE] Faster polling of executor memory metrics. 
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd.apply"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd.copy$default$6"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerTaskEnd.this"), - - // [SPARK-26311][CORE]New feature: apply custom log URL pattern for executor log URLs - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerApplicationStart.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerApplicationStart.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerApplicationStart.this"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerApplicationStart$"), - - // [SPARK-27630][CORE] Properly handle task end events from completed stages - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerSpeculativeTaskSubmitted.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerSpeculativeTaskSubmitted.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerSpeculativeTaskSubmitted.this"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerSpeculativeTaskSubmitted$"), - - // [SPARK-26632][Core] Separate Thread Configurations of Driver and Executor - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"), - - // [SPARK-16872][ML][PYSPARK] Impl Gaussian Naive Bayes Classifier - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.NaiveBayesModel.this"), - - // [SPARK-25765][ML] Add training cost to BisectingKMeans summary - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel.this"), - - // [SPARK-24243][CORE] Expose exceptions from InProcessAppHandle - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.launcher.SparkAppHandle.getError"), - - // [SPARK-25867] Remove KMeans computeCost - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), - - // [SPARK-26127] Remove deprecated setters from tree regression and classification models - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setImpurity"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.DecisionTreeClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxIter"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setStepSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMaxBins"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setNumTrees"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.DecisionTreeRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setSubsamplingRate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxIter"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setStepSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSeed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInfoGain"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setSubsamplingRate"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCacheNodeIds"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setCheckpointInterval"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxDepth"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setImpurity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxMemoryInMB"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMaxBins"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setMinInstancesPerNode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setNumTrees"), - - // [SPARK-26090] Resolve most miscellaneous deprecation and build warnings for Spark 3 - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.stat.test.BinarySampleBeanInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.mllib.regression.LabeledPointBeanInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.LabeledPointBeanInfo"), - - // [SPARK-28780][ML] Delete the incorrect setWeightCol method in LinearSVCModel - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LinearSVCModel.setWeightCol"), - - // [SPARK-29645][ML][PYSPARK] ML add param RelativeError - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.relativeError"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getRelativeError"), - - // [SPARK-28968][ML] Add HasNumFeatures in the scala side - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.FeatureHasher.getNumFeatures"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.FeatureHasher.numFeatures"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.HashingTF.getNumFeatures"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.HashingTF.numFeatures"), - - // [SPARK-25908][CORE][SQL] Remove old deprecated items in Spark 3 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.BarrierTaskContext.isRunningLocally"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskContext.isRunningLocally"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleBytesWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleWriteTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.shuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.fMeasure"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.recall"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.precision"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLWriter.context"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.MLReader.context"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.util.GeneralMLWriter.context"), - - // [SPARK-25737] Remove JavaSparkContextVarargsWorkaround - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.api.java.JavaSparkContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.union"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.union"), - - // [SPARK-16775] Remove deprecated accumulator v1 APIs - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulator"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Accumulator$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulableParam"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$FloatAccumulatorParam$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$DoubleAccumulatorParam$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$LongAccumulatorParam$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.AccumulatorParam$IntAccumulatorParam$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulable"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulableCollection"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.accumulator"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.LegacyAccumulatorWrapper"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.intAccumulator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulable"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.doubleAccumulator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.accumulator"), - - // [SPARK-24109] Remove class SnappyOutputStreamWrapper - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.SnappyCompressionCodec.version"), - - // [SPARK-19287] JavaPairRDD flatMapValues requires function returning Iterable, not Iterator - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.api.java.JavaPairRDD.flatMapValues"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaPairDStream.flatMapValues"), - - // [SPARK-25680] SQL execution listener shouldn't happen on execution thread - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.util.ExecutionListenerManager.clone"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.util.ExecutionListenerManager.this"), - - // [SPARK-25862][SQL] Remove rangeBetween APIs introduced in SPARK-21608 - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.unboundedFollowing"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.unboundedPreceding"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.currentRow"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.Window.rangeBetween"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.expressions.WindowSpec.rangeBetween"), - - // [SPARK-23781][CORE] Merge token renewer functionality into HadoopDelegationTokenManager - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.nextCredentialRenewalTime"), - - // [SPARK-26133][ML] Remove deprecated OneHotEncoder and rename OneHotEncoderEstimator to OneHotEncoder - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.OneHotEncoderEstimator"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.OneHotEncoder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.feature.OneHotEncoderEstimator$"), - - // [SPARK-30329][ML] add iterator/foreach methods for Vectors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.activeIterator"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.activeIterator"), - - // [SPARK-26141] Enable custom metrics implementation in shuffle write - // Following are Java private classes - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"), - - // [SPARK-26139] Implement shuffle write metrics in SQL - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ShuffleDependency.this"), - - // [SPARK-26362][CORE] Remove 'spark.driver.allowMultipleContexts' to disallow multiple creation of SparkContexts - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.setActiveContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.markPartiallyConstructed"), - - // [SPARK-26457] Show hadoop configurations in HistoryServer environment tab - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationEnvironmentInfo.this"), - - // [SPARK-30144][ML] Make MultilayerPerceptronClassificationModel extend MultilayerPerceptronParams - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.layers"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"), - - // [SPARK-30630][ML] Remove numTrees in GBT - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.numTrees"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.numTrees"), - - // Data Source V2 API changes - (problem: Problem) => problem match { - case MissingClassProblem(cls) => - !cls.fullName.startsWith("org.apache.spark.sql.sources.v2") - case _ => true - }, - - // [SPARK-27521][SQL] Move data source v2 to catalyst module - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarBatch"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ArrowColumnVector"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarRow"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarArray"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarMap"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnVector"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThanOrEqual"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringEndsWith"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThanOrEqual$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.In$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Not"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNotNull"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThan"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThanOrEqual"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualNullSafe$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThan$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.In"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.And"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringStartsWith$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualNullSafe"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringEndsWith$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThanOrEqual$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Not$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNull$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThan$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNotNull$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Or"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualTo$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThan"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringContains"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Filter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNull"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualTo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.And$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Or$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringStartsWith"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringContains$"), - - // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), - 
ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.inputTypes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullableTypes_="), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.dataType"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.f"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.this"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$2"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$1"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3"), - - // [SPARK-11215][ML] Add multiple columns support to StringIndexer - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), - - // [SPARK-26616][MLlib] Expose document frequency in IDFModel - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf"), - - // [SPARK-28199][SS] Remove deprecated ProcessingTime - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.ProcessingTime"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.ProcessingTime$"), - - // [SPARK-25382][SQL][PYSPARK] Remove ImageSchema.readImages in 3.0 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.image.ImageSchema.readImages"), - - // [SPARK-25341][CORE] Support rolling back a shuffle map stage and re-generate the shuffle files - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleBlockId.copy$default$2"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.ShuffleBlockId.apply"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleIndexBlockId.mapId"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleDataBlockId.mapId"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.storage.ShuffleBlockId.mapId"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.mapId"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.FetchFailed$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.apply"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.copy$default$5"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.copy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.FetchFailed.copy$default$3"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.FetchFailed.this"), - - // [SPARK-28957][SQL] Copy any "spark.hive.foo=bar" spark properties into hadoop conf as "hive.foo=bar" - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.appendS3AndSparkHadoopConfigurations"), - - // [SPARK-29348] Add observable metrics. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryProgress.this"), - - // [SPARK-30377][ML] Make AFTSurvivalRegression extend Regressor - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.setFeaturesCol"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.setPredictionCol"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setFeaturesCol"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setLabelCol"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.setPredictionCol"), - - // [SPARK-29543][SS][UI] Init structured streaming ui - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStartedEvent.this"), - - // [SPARK-30667][CORE] Add allGather method to BarrierTaskContext - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.RequestToSync") - ) - - // Exclude rules for 2.4.x - lazy val v24excludes = v23excludes ++ Seq( - // [SPARK-25248] add package private methods to TaskContext - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskFailed"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markInterrupted"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.fetchFailed"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.markTaskCompleted"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperties"), - - // [SPARK-10697][ML] Add lift to Association rules - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.fpm.FPGrowthModel.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.fpm.AssociationRules#Rule.this"), - - // [SPARK-24296][CORE] Replicate large blocks as a stream. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.this"), - // [SPARK-23528] Add numIter to ClusteringSummary - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.ClusteringSummary.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansSummary.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.BisectingKMeansSummary.this"), - // [SPARK-6237][NETWORK] Network-layer changes to allow stream upload - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockRpcServer.receive"), - - // [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"), - - // [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printVersionAndExit"), - - // [SPARK-23412][ML] Add cosine distance measure to BisectingKmeans - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.org$apache$spark$ml$param$shared$HasDistanceMeasure$_setter_$distanceMeasure_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.getDistanceMeasure"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasDistanceMeasure.distanceMeasure"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.BisectingKMeansModel#SaveLoadV1_0.load"), - - // [SPARK-20659] Remove StorageStatus, or make it private - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOffHeapStorageMemory"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOffHeapStorageMemory"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.usedOnHeapStorageMemory"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.SparkExecutorInfo.totalOnHeapStorageMemory"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.getExecutorStorageStatus"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numBlocks"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocks"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.containsBlock"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddBlocksById"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.numRddBlocksById"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.memUsedByRdd"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.cacheSize"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.rddStorageLevel"), - - // [SPARK-23455][ML] Default Params in ML should be saved separately in metadata - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.paramMap"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$paramMap_="), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.defaultParamMap"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.param.Params.org$apache$spark$ml$param$Params$_setter_$defaultParamMap_="), - - // [SPARK-7132][ML] Add fit with validation set to spark.ml GBT - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.getValidationIndicatorCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.org$apache$spark$ml$param$shared$HasValidationIndicatorCol$_setter_$validationIndicatorCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasValidationIndicatorCol.validationIndicatorCol"), - - // [SPARK-23042] Use OneHotEncoderModel to encode labels in MultilayerPerceptronClassifier - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.classification.LabelConverter"), - - // [SPARK-21842][MESOS] Support Kerberos ticket renewal and creation in Mesos - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getDateOfNextUpdate"), - - // [SPARK-23366] Improve hot reading path in ReadAheadInputStream - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.io.ReadAheadInputStream.this"), - - // [SPARK-22941][CORE] Do not exit JVM when submit fails with in-process launcher. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.addJarToClasspath"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.mergeFileLists"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment$default$2"), - - // Data Source V2 API changes - // TODO: they are unstable APIs and should not be tracked by mima. 
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.ReadSupportWithSchema"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.createDataReaderFactories"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.createBatchDataReaderFactories"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch.planBatchInputPartitions"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsScanUnsafeRow"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader.createDataReaderFactories"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.DataSourceReader.planInputPartitions"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.SupportsPushDownCatalystFilters"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataReader"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.getStatistics"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics.estimateStatistics"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.DataReaderFactory"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.ContinuousDataReader"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.SupportsWriteInternalRow"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createDataWriter"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createDataWriter"), - - // Changes to HasRawPredictionCol. 
- ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.rawPredictionCol"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.org$apache$spark$ml$param$shared$HasRawPredictionCol$_setter_$rawPredictionCol_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasRawPredictionCol.getRawPredictionCol"), - - // [SPARK-15526][ML][FOLLOWUP] Make JPMML provided scope to avoid including unshaded JARs - (problem: Problem) => problem match { - case MissingClassProblem(cls) => - !cls.fullName.startsWith("org.sparkproject.jpmml") && - !cls.fullName.startsWith("org.sparkproject.dmg.pmml") && - !cls.fullName.startsWith("org.spark_project.jpmml") && - !cls.fullName.startsWith("org.spark_project.dmg.pmml") - case _ => true - } - ) - - // Exclude rules for 2.3.x - lazy val v23excludes = v22excludes ++ Seq( - // [SPARK-22897] Expose stageAttemptId in TaskContext - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), - - // SPARK-22789: Map-only continuous processing execution - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$9"), - - // SPARK-22372: Make cluster submission use SparkApplication. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getSecretKeyFromUserCredentials"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.isYarnMode"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getCurrentUserCredentials"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.addSecretKeyToUserCredentials"), - - // SPARK-18085: Better History Server scalability for many / large applications - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.env.EnvironmentListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.exec.ExecutorsListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.storage.StorageListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.StorageStatusListener"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorStageSummary.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.JobData.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkStatusTracker.this"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ui.jobs.JobProgressListener"), - - // [SPARK-20495][SQL] Add StorageLevel to cacheTable API - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), - - // [SPARK-19937] Add remote bytes read to disk. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetricDistributions.this"), - - // [SPARK-21276] Update lz4-java to the latest (v1.4.0) - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.io.LZ4BlockInputStream"), - - // [SPARK-17139] Add model summary for MultinomialLogisticRegression - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictionCol"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.labels"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.truePositiveRateByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.falsePositiveRateByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.precisionByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.recallByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.fMeasureByLabel"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.accuracy"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedTruePositiveRate"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFalsePositiveRate"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedRecall"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedPrecision"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.weightedFMeasure"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.asBinary"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.org$apache$spark$ml$classification$LogisticRegressionSummary$_setter_$org$apache$spark$ml$classification$LogisticRegressionSummary$$multiclassMetrics_="), - - // [SPARK-14280] Support Scala 2.12 - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transformWith"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.FutureAction.transform"), - - // [SPARK-21087] CrossValidator, TrainValidationSplit expose sub models after fitting: Scala - 
ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.CrossValidatorModel$CrossValidatorModelWriter"), - ProblemFilters.exclude[FinalClassProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel$TrainValidationSplitModelWriter"), - - // [SPARK-21728][CORE] Allow SparkSubmit to use Logging - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFileList"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.downloadFile"), - - // [SPARK-21714][CORE][YARN] Avoiding re-uploading remote resources in yarn client mode - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.prepareSubmitEnvironment"), - - // [SPARK-22324][SQL][PYTHON] Upgrade Arrow to 0.8.0 - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.network.util.AbstractFileRegion.transfered"), - - // [SPARK-20643][CORE] Add listener implementation to collect app state - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$5"), - - // [SPARK-20648][CORE] Port JobsTab and StageTab to the new UI backend - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$12"), - - // [SPARK-21462][SS] Added batchId to StreamingQueryProgress.json - // [SPARK-21409][SS] Expose state store memory usage in SQL metrics and progress updates - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StateOperatorProgress.this"), - - // [SPARK-22278][SS] Expose current event time watermark and current processing time in GroupState - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentWatermarkMs"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.GroupState.getCurrentProcessingTimeMs"), - - // [SPARK-20542][ML][SQL] Add an API to Bucketizer that can bin multiple columns - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasOutputCols.org$apache$spark$ml$param$shared$HasOutputCols$_setter_$outputCols_="), - - // [SPARK-18619][ML] Make QuantileDiscretizer/Bucketizer/StringIndexer/RFormula inherit from HasHandleInvalid - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.Bucketizer.getHandleInvalid"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexer.getHandleInvalid"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.getHandleInvalid"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.getHandleInvalid") - ) - - // Exclude rules for 2.2.x - lazy val v22excludes = v21excludes ++ Seq( - // [SPARK-20355] Add per application spark version on the history server headerpage - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), - - // [SPARK-19652][UI] Do auth checks for REST API access. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.withSparkUI"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.status.api.v1.UIRootFromServletContext"), - - // [SPARK-18663][SQL] Simplify CountMinSketch aggregate implementation - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.util.sketch.CountMinSketch.toByteArray"), - - // [SPARK-18949] [SQL] Add repairTable API to Catalog - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.recoverPartitions"), - - // [SPARK-18537] Add a REST api to spark streaming - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.streaming.scheduler.StreamingListener.onStreamingStarted"), - - // [SPARK-19148][SQL] do not expose the external table concept in Catalog - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable"), - - // [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"), - - // [SPARK-19267] Fetch Failure handling robust to user error handling - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"), - - // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$10"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$11"), - - // [SPARK-17161] Removing Python-friendly constructors not needed - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.OneVsRestModel.this"), - - // [SPARK-19820] Allow reason to be specified to task kill - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.TaskKilled$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.countTowardsTaskFailures"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.toErrorString"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.TaskKilled.toString"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.killTaskIfInterrupted"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getKillReason"), - - // [SPARK-19876] Add one time trigger, and improve Trigger APIs - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.sql.streaming.Trigger"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.streaming.ProcessingTime"), - - // [SPARK-17471][ML] Add compressed method to ML matrices - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressed"), - 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedColMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.compressedRowMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isRowMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.isColMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSparseSizeInBytes"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDense"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparse"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseRowMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseRowMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseColMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getDenseSizeInBytes"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"), - - // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum") - ) ++ Seq( - // [SPARK-17019] Expose on-heap and off-heap memory usage in various places - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.this"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.StorageStatus.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.RDDDataDistribution.this") - ) - - // Exclude rules for 2.1.x - lazy val v21excludes = v20excludes ++ { - Seq( - // [SPARK-17671] Spark 2.0 history server summary page is slow even set spark.history.ui.maxApplications - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.deploy.history.HistoryServer.getApplicationList"), - // [SPARK-14743] Improve delegation token handling in secure cluster - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTimeFromNowToRenewal"), - // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"), - // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select"), - // [SPARK-16967] Move Mesos to Module - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkMasterRegex.MESOS_REGEX"), - // [SPARK-16240] ML persistence backward compatibility for LDA - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.clustering.LDA$"), - // [SPARK-17717] Add Find and Exists method to Catalog. - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getDatabase"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getTable"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.getFunction"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.databaseExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.tableExists"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.functionExists"), - - // [SPARK-17731][SQL][Streaming] Metrics for structured streaming - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceStatus.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.SourceStatus.offsetDesc"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.status"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SinkStatus.this"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryInfo"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryStarted.queryInfo"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryProgress.queryInfo"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.queryInfo"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryStarted"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryProgress"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryListener$QueryTerminated"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryStarted"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryStarted"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryProgress"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryProgress"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryTerminated"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener.onQueryTerminated"), - - // [SPARK-18516][SQL] Split state and progress in 
streaming - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.SourceStatus"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.SinkStatus"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.sinkStatus"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.sourceStatuses"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQuery.id"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.lastProgress"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.recentProgress"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.id"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.get"), - - // [SPARK-17338][SQL] add global temp view - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropGlobalTempView"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.dropTempView"), - - // [SPARK-18034] Upgrade to MiMa 0.1.11 to fix flakiness. - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.aggregationDepth"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.getAggregationDepth"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasAggregationDepth.org$apache$spark$ml$param$shared$HasAggregationDepth$_setter_$aggregationDepth_="), - - // [SPARK-18236] Reduce duplicate objects in Spark UI and HistoryServer - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.TaskInfo.accumulables"), - - // [SPARK-18657] Add StreamingQuery.runId - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQuery.runId"), - - // [SPARK-18694] Add StreamingQuery.explain and exception to Python and fix StreamingQueryException - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryException$"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.startOffset"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.endOffset"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryException.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryException.query") - ) - } - - // Exclude rules for 2.0.x - lazy val v20excludes = { - Seq( - ProblemFilters.exclude[Problem]("org.apache.spark.rpc.*"), - ProblemFilters.exclude[Problem]("org.spark-project.jetty.*"), - ProblemFilters.exclude[Problem]("org.spark_project.jetty.*"), - ProblemFilters.exclude[Problem]("org.sparkproject.jetty.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.internal.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.unused.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.unsafe.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.memory.*"), - 
ProblemFilters.exclude[Problem]("org.apache.spark.util.collection.unsafe.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.catalyst.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.execution.*"), - ProblemFilters.exclude[Problem]("org.apache.spark.sql.internal.*"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationAttemptInfo.this"), - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.status.api.v1.ApplicationAttemptInfo.$default$5"), - // SPARK-14042 Add custom coalescer support - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.coalesce"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.rdd.PartitionCoalescer$LocationIterator"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.rdd.PartitionCoalescer"), - // SPARK-15532 Remove isRootContext flag from SQLContext. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.isRootContext"), - // SPARK-12600 Remove SQL deprecated methods - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.applySchema"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.parquetFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jdbc"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.dialectClassName"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.getSQLDialect"), - // SPARK-13664 Replace HadoopFsRelation with FileFormat - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.LibSVMRelation"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelationProvider"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.HadoopFsRelation$FileStatusCache"), - // SPARK-15543 Rename DefaultSources to make them more self-describing - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.source.libsvm.DefaultSource") - ) ++ Seq( - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory"), - // SPARK-14358 SparkListener from trait to abstract class - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.addSparkListener"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.JavaSparkListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkFirehoseListener"), - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.scheduler.SparkListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.jobs.JobProgressListener"), - 
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.exec.ExecutorsListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.env.EnvironmentListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ui.storage.StorageListener"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.StorageStatusListener") - ) ++ - Seq( - // SPARK-3369 Fix Iterable/Iterator in Java API - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.DoubleFlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapFunction2.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapFunction2.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.PairFlatMapFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.PairFlatMapFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.CoGroupFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.CoGroupFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.MapPartitionsFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.MapPartitionsFunction.call"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.function.FlatMapGroupsFunction.call"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.function.FlatMapGroupsFunction.call") - ) ++ - Seq( - // [SPARK-6429] Implement hashCode and equals together - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Partition.org$apache$spark$Partition$$super=uals") - ) ++ - Seq( - // SPARK-4819 replace Guava Optional - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getSparkHome"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner") - ) ++ - Seq( - // SPARK-12481 Remove Hadoop 1.x - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil"), - // SPARK-12615 Remove deprecated APIs in core - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.$default$6"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.numericRDDToDoubleRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intToIntWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intWritableConverter"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.writableWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToAsyncRDDActions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.boolToBoolWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longToLongWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToOrderedRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.booleanWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringToText"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleToDoubleWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToSequenceFileRDDFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesToBytesWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringWritableConverter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatToFloatWritable"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addOnCompleteCallback"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.runningLocally"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.attemptId"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.defaultMinSplits"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.runJob"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.runJob"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.tachyonFolderName"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.initLocalProperties"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearJars"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearFiles"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith$default$2"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.toArray"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith$default$2"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.filterWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.foreachWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit$default$2"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.SequenceFileRDDFunctions.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.splits"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.externalBlockStoreFolderName"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore") - ) ++ Seq( - // SPARK-12149 Added new fields to ExecutorSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") - ) ++ - // SPARK-12665 Remove deprecated and unused classes - Seq( - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.graphx.GraphKryoRegistrator"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") - ) ++ Seq( - // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge") - ) ++ Seq( - // SPARK-12510 Refactor ActorReceiver to support Java - ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") - ) ++ Seq( - // SPARK-12895 Implement TaskMetrics using accumulators - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators") - ) ++ Seq( - // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Accumulator.this"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.Accumulator.initialValue") - ) ++ Seq( - // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":") - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log__="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=") - ) ++ Seq( - // SPARK-12689 Migrate DDL parsing to the newly absorbed parser - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser") - ) ++ Seq( - // SPARK-7799 Add "streaming-akka" project - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$6"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$5"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$3"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.streaming.zeromq.ZeroMQReceiver"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver$Supervisor") - ) ++ Seq( - // SPARK-12348 Remove deprecated Streaming APIs. 
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.dstream.DStream.foreach"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions$default$4"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.awaitTermination"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.networkStream"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.api.java.JavaStreamingContextFactory"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.awaitTermination"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.sc"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreach"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.getOrCreate") - ) ++ Seq( - // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus") - ) ++ Seq( - // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation") - ) ++ Seq( - // SPARK-6363 Make Scala 2.11 the default Scala version - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint") - ) ++ Seq( - // SPARK-7889 - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI"), - // SPARK-13296 - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$") - ) ++ Seq( - // SPARK-12995 Remove deprecated APIs in graphx - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.lib.SVDPlusPlus.runSVDPlusPlus"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"), - 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets") - ) ++ Seq( - // SPARK-13426 Remove the support of SIMR - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX") - ) ++ Seq( - // SPARK-13413 Remove SparkContext.metricsSystem/schedulerBackend_ setter - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metricsSystem"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.schedulerBackend_=") - ) ++ Seq( - // SPARK-13220 Deprecate yarn-client and yarn-cluster mode - ProblemFilters.exclude[MissingMethodProblem]( - "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler") - ) ++ Seq( - // SPARK-13465 TaskContext. - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener") - ) ++ Seq ( - // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this") - ) ++ Seq( - // SPARK-13526 Move SQLContext per-session states to new class - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.UDFRegistration.this") - ) ++ Seq( - // [SPARK-13486][SQL] Move SQLConf into an internal package - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$") - ) ++ Seq( - //SPARK-11011 UserDefinedType serialization should be strongly typed - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"), - // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition") - ) ++ Seq( - // [SPARK-13244][SQL] Migrates DataFrame to Dataset - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.tables"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.sql"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.baseRelationToDataFrame"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.table"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrame.apply"), - - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.localSeqToDataFrameHolder"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.stringRddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.rddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"), - - // [SPARK-14451][SQL] Move encoder definition into Aggregator interface - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"), - - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions") - ) ++ Seq( - // [SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `reqParam` to (Streaming)LinearRegressionWithSGD - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD.this") - ) ++ Seq( - // SPARK-15250 Remove deprecated json API in DataFrameReader - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameReader.json") - ) ++ Seq( - // SPARK-13920: MIMA checks should apply to @Experimental and @DeveloperAPI APIs - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineCombinersByKey"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineValuesByKey"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.run"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.runJob"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.actorSystem"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.cacheManager"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getConfigurationFromJobContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTaskAttemptIDFromTaskAttemptContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.newConfiguration"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback_="), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.copy"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.setBytesReadCallback"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.updateBytesRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decFetchWaitTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decLocalBlocksFetched"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRecordsRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBlocksFetched"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBytesRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleBytesWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleWriteTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleBytesWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleWriteTime"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.setShuffleRecordsWritten"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.PCAModel.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithContext"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.taskMetrics"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.TaskInfo.attempt"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.ExperimentalMethods.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUDF"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUdf"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.cumeDist"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.denseRank"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.inputFileName"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.isNaN"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.percentRank"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.rowNumber"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.sparkPartitionId"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.externalBlockStoreSize"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsedByRdd"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatusListener.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.scheduler.BatchInfo.streamIdToNumRecords"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.storageStatusList"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.storage.StorageListener.storageStatusList"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.apply"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.copy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.InputMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.OutputMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transformImpl"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.train"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.GBTClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayes.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRest.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRestModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.RandomForestClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeans.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logLikelihood"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logPerplexity"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.RegressionEvaluator.evaluate"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Binarizer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Bucketizer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelector.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.HashingTF.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDF.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDFModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IndexToString.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Interaction.transform"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCA.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCAModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormula.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormulaModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.SQLTransformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StopWordsRemover.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorAssembler.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexer.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexerModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorSlicer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2Vec.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALS.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.GBTRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.extractWeightedLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.extractWeightedLabeledPoints"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegression.train"), - 
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionTrainingSummary.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplit.fit"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.RegressionMetrics.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameWriter.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.broadcast"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.callUDF"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.InsertableRelation.insert"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.fMeasureByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.pr"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.precisionByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.predictions"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.recallByThreshold"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.roc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.describeTopics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.getVectors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.itemFactors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.userFactors"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.predictions"), - 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.residuals"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.name"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.value"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.drop"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.fill"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.replace"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.jdbc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.json"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.load"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.orc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.parquet"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.table"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.text"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.crosstab"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.freqItems"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.sampleBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.createExternalTable"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.emptyDataFrame"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.range"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.functions.udf"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.JobLogger"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorHelper"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics$"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.functions$"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Predictor.train"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), - 
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert") - ) ++ Seq( - // [SPARK-13926] Automatically use Kryo serializer when shuffling RDDs with simple types - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ShuffleDependency.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ShuffleDependency.serializer"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.Serializer$") - ) ++ Seq( - // SPARK-13927: add row/column iterator to local matrices - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"), - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter") - ) ++ Seq( - // SPARK-13948: MiMa Check should catch if the visibility change to `private` - // TODO(josh): Some of these may be legitimate incompatibilities; we should follow up before the 2.0.0 release - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.toDS"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.askTimeout"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.lookupTimeout"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.UnaryTransformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassifier.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressor.train"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.groupBy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.select"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.toDF"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Logging.initializeLogIfNecessary"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerEvent.logEvent"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance") - ) ++ Seq( - // [SPARK-14014] Replace existing analysis.Catalog with SessionCatalog - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.this") - ) ++ Seq( - // [SPARK-13928] Move org.apache.spark.Logging into org.apache.spark.internal.Logging - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Logging"), - (problem: Problem) => problem match { - case MissingTypesProblem(_, missing) - if missing.map(_.fullName).sameElements(Seq("org.apache.spark.Logging")) => false - case _ => true - } - ) ++ Seq( - // [SPARK-13990] Automatically pick serializer when caching RDDs - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.uploadBlock") - ) ++ Seq( - // [SPARK-14089][CORE][MLLIB] Remove methods that has been deprecated since 1.1, 1.2, 1.3, 1.4, and 1.5 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.getThreadLocal"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeReduce"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeAggregate"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.defaultStategy"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.saveLabeledData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol") - ) ++ Seq( - // [SPARK-14205][SQL] remove trait Queryable - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset") - ) ++ Seq( - // [SPARK-11262][ML] Unit test for gradient, loss layers, memory management - // for multilayer perceptron. - // This class is marked as `private`. - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction") - ) ++ Seq( - // [SPARK-13674][SQL] Add wholestage codegen support to Sample - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this") - ) ++ Seq( - // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary - ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this") - ) ++ Seq( - // [SPARK-14437][Core] Use the address that NettyBlockTransferService listens to create BlockManagerId - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.this") - ) ++ Seq( - // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this") - ) ++ Seq( - // [SPARK-14475] Propagate user-defined context from driver to executors - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"), - // [SPARK-14617] Remove deprecated APIs in TaskMetrics - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$"), - // [SPARK-14628] Simplify task metrics by always tracking read/write metrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.readMethod"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.writeMethod") - ) ++ Seq( - // 
SPARK-14628: Always track input/output/shuffle metrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.totalBlocksFetched"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.inputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.outputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleWriteMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleReadMetrics"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.inputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.outputMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleWriteMetrics"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleReadMetrics"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") - ) ++ Seq( - // SPARK-13643: Move functionality from SQLContext to SparkSession - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.getSchema") - ) ++ Seq( - // [SPARK-14407] Hides HadoopFsRelation related data source API into execution package - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriter"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriterFactory") - ) ++ Seq( - // SPARK-14734: Add conversions between mllib and ml Vector, Matrix types - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asML"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asML") - ) ++ Seq( - // SPARK-14704: Create accumulators in TaskMetrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.this") - ) ++ Seq( - // SPARK-14861: Replace internal usages of SQLContext with SparkSession - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.LocalLDAModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.DistributedLDAModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.ml.clustering.LDAModel.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]( - "org.apache.spark.ml.clustering.LDAModel.sqlContext"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.Dataset.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]( - "org.apache.spark.sql.DataFrameReader.this") - ) ++ Seq( - // SPARK-14542 configurable buffer size for pipe RDD - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.pipe"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.pipe") - ) ++ Seq( - // 
[SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory - ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable") - ) ++ Seq( - // [SPARK-14952][Core][ML] Remove methods deprecated in 1.6 - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.input.PortableDataStream.close"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.weights"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionModel.weights") - ) ++ Seq( - // [SPARK-10653] [Core] Remove unnecessary things from SparkEnv - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.sparkFilesDir"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.blockTransferService") - ) ++ Seq( - // SPARK-14654: New accumulator API - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.remoteBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.totalBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.localBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.remoteBlocksFetched"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.localBlocksFetched") - ) ++ Seq( - // [SPARK-14615][ML] Use the new ML Vector and Matrix in the ML pipeline based algorithms - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.getOldDocConcentration"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.estimatedDocConcentration"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.topicsMatrix"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.clusterCenters"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LabelConverter.decodeLabel"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LabelConverter.encodeLabeledPoint"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.weights"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.predict"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.predictRaw"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.raw2probabilityInPlace"), - 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.theta"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.pi"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.probability2prediction"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.predictRaw"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.raw2prediction"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.raw2probabilityInPlace"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.predict"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.coefficients"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.raw2prediction"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.predictRaw"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.ClassificationModel.predictRaw"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.ElementwiseProduct.getScalingVec"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ElementwiseProduct.setScalingVec"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.PCAModel.pc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.originalMax"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.originalMin"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.this"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.IDFModel.idf"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.mean"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.std"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.predict"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.coefficients"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.predictQuantiles"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.predictions"), - 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.boundaries"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.predict"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.coefficients"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.this") - ) ++ Seq( - // [SPARK-15290] Move annotations, like @Since / @DeveloperApi, into spark-tags - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.package$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.package"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Private"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.AlphaComponent"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Experimental"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.DeveloperApi") - ) ++ Seq( - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asBreeze"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asBreeze") - ) ++ Seq( - // [SPARK-15914] Binary compatibility is broken since consolidation of Dataset and DataFrame - // in Spark 2.0. However, source level compatibility is still maintained. - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.load"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jsonRDD"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jsonFile"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jdbc"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.parquetFile"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.applySchema") - ) ++ Seq( - // SPARK-17096: Improve exception string reported through the StreamingQueryListener - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this") - ) ++ Seq( - // SPARK-17406 limit timeline executor events - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorIdToData"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksActive"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksComplete"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputRecords"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleRead"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksFailed"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleWrite"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToDuration"), - 
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputBytes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToLogUrls"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputBytes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputRecords"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTotalCores"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksMax"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToJvmGCTime") - ) ++ Seq( - // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") - ) ++ Seq( - // [SPARK-17498] StringIndexer enhancement for handling unseen labels - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel") - ) ++ Seq( - // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") - ) ++ Seq( - // [SPARK-12221] Add CPU time to metrics - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") - ) ++ Seq( - // [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), - 
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"), - ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"), - ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy") - ) ++ Seq( - // [SPARK-21680][ML][MLLIB]optimize Vector compress - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.toSparseWithSize"), - ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.toSparseWithSize") - ) ++ Seq( - // [SPARK-3181][ML]Implement huber loss for LinearRegression. 
- ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.org$apache$spark$ml$param$shared$HasLoss$_setter_$loss_="), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.getLoss"), - ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.loss") - ) - } - def excludes(version: String) = version match { + case v if v.startsWith("3.4") => v34excludes + case v if v.startsWith("3.3") => v33excludes case v if v.startsWith("3.2") => v32excludes - case v if v.startsWith("3.1") => v31excludes - case v if v.startsWith("3.0") => v30excludes - case v if v.startsWith("2.4") => v24excludes - case v if v.startsWith("2.3") => v23excludes - case v if v.startsWith("2.2") => v22excludes - case v if v.startsWith("2.1") => v21excludes - case v if v.startsWith("2.0") => v20excludes case _ => Seq() } } diff --git a/repl/pom.xml b/repl/pom.xml index 36d9b0e5e43aa..714fdf9d0d8a5 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 77f4385e277a2..dcd4ceace7fad 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../../pom.xml diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index 1d12e2ebce1c7..95ea5e12c35bc 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 301462026b190..0c764d83c503a 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index db7e3e03107ec..d049e217637d3 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 6c089f9feb3e3..631edbd8eb3e4 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java index 34f07b12b3666..66e8a431458f9 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java @@ -20,10 +20,7 @@ import java.util.Map; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; -import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; -import 
org.apache.spark.sql.catalyst.analysis.NoSuchTableException; -import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.*; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -147,8 +144,10 @@ public void alterNamespace( } @Override - public boolean dropNamespace(String[] namespace) throws NoSuchNamespaceException { - return asNamespaceCatalog().dropNamespace(namespace); + public boolean dropNamespace( + String[] namespace, + boolean cascade) throws NoSuchNamespaceException, NonEmptyNamespaceException { + return asNamespaceCatalog().dropNamespace(namespace, cascade); } private TableCatalog asTableCatalog() { diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java index f70746b612e92..c1a4960068d24 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException; import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException; import java.util.Map; @@ -136,15 +137,20 @@ void alterNamespace( NamespaceChange... changes) throws NoSuchNamespaceException; /** - * Drop a namespace from the catalog, recursively dropping all objects within the namespace. + * Drop a namespace from the catalog with cascade mode, recursively dropping all objects + * within the namespace if cascade is true. *

* If the catalog implementation does not support this operation, it may throw * {@link UnsupportedOperationException}. * * @param namespace a multi-part namespace + * @param cascade When true, deletes all objects under the namespace * @return true if the namespace was dropped * @throws NoSuchNamespaceException If the namespace does not exist (optional) + * @throws NonEmptyNamespaceException If the namespace is non-empty and cascade is false * @throws UnsupportedOperationException If drop is not a supported operation */ - boolean dropNamespace(String[] namespace) throws NoSuchNamespaceException; + boolean dropNamespace( + String[] namespace, + boolean cascade) throws NoSuchNamespaceException, NonEmptyNamespaceException; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java new file mode 100644 index 0000000000000..1419e975f5695 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.index; + +import java.util.Map; +import java.util.Properties; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalyst.analysis.IndexAlreadyExistsException; +import org.apache.spark.sql.catalyst.analysis.NoSuchIndexException; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * Table methods for working with index + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsIndex extends Table { + + /** + * A reserved property to specify the index type. + */ + String PROP_TYPE = "type"; + + /** + * Creates an index. + * + * @param indexName the name of the index to be created + * @param columns the columns on which index to be created + * @param columnsProperties the properties of the columns on which index to be created + * @param properties the properties of the index to be created + * @throws IndexAlreadyExistsException If the index already exists. + */ + void createIndex(String indexName, + NamedReference[] columns, + Map> columnsProperties, + Map properties) + throws IndexAlreadyExistsException; + + /** + * Drops the index with the given name. + * + * @param indexName the name of the index to be dropped. + * @throws NoSuchIndexException If the index does not exist. + */ + void dropIndex(String indexName) throws NoSuchIndexException; + + /** + * Checks whether an index exists in this table. 
+ * + * @param indexName the name of the index + * @return true if the index exists, false otherwise + */ + boolean indexExists(String indexName); + + /** + * Lists all the indexes in this table. + */ + TableIndex[] listIndexes(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java new file mode 100644 index 0000000000000..977ed8d6c7528 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog.index; + +import java.util.Collections; +import java.util.Map; +import java.util.Properties; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.NamedReference; + +/** + * Index in a table + * + * @since 3.3.0 + */ +@Evolving +public final class TableIndex { + private String indexName; + private String indexType; + private NamedReference[] columns; + private Map columnProperties = Collections.emptyMap(); + private Properties properties; + + public TableIndex( + String indexName, + String indexType, + NamedReference[] columns, + Map columnProperties, + Properties properties) { + this.indexName = indexName; + this.indexType = indexType; + this.columns = columns; + this.columnProperties = columnProperties; + this.properties = properties; + } + + /** + * @return the Index name. + */ + public String indexName() { return indexName; } + + /** + * @return the indexType of this Index. + */ + public String indexType() { return indexType; } + + /** + * @return the column(s) this Index is on. Could be multi columns (a multi-column index). + */ + public NamedReference[] columns() { return columns; } + + /** + * @return the map of column and column property map. + */ + public Map columnProperties() { return columnProperties; } + + /** + * Returns the index properties. + */ + public Properties properties() { return properties; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java new file mode 100644 index 0000000000000..26b97b46fe2ef --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.types.DataType; + +/** + * Represents a cast expression in the public logical expression API. + * + * @since 3.3.0 + */ +@Evolving +public class Cast implements Expression, Serializable { + private Expression expression; + private DataType dataType; + + public Cast(Expression expression, DataType dataType) { + this.expression = expression; + this.dataType = dataType; + } + + public Expression expression() { return expression; } + public DataType dataType() { return dataType; } + + @Override + public Expression[] children() { return new Expression[]{ expression() }; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java index 6540c91597582..76dfe73f666cf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.connector.expressions; +import java.util.Arrays; + import org.apache.spark.annotation.Evolving; /** @@ -26,8 +28,23 @@ */ @Evolving public interface Expression { + Expression[] EMPTY_EXPRESSION = new Expression[0]; + /** * Format the expression as a human readable SQL-like string. */ - String describe(); + default String describe() { return this.toString(); } + + /** + * Returns an array of the children of this node. Children should not change. + */ + Expression[] children(); + + /** + * List of fields or columns that are referenced by this expression. + */ + default NamedReference[] references() { + return Arrays.stream(children()).map(e -> e.references()) + .flatMap(Arrays::stream).distinct().toArray(NamedReference[]::new); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java new file mode 100644 index 0000000000000..58082d5ee09c1 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.Objects; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder; + +/** + * The general representation of SQL scalar expressions, which contains the upper-cased + * expression name and all the children expressions. Please also see {@link Predicate} + * for the supported predicate expressions. + *

+ * The currently supported SQL scalar expressions (all available since version 3.3.0):
+ *   1. Name: +             SQL semantic: expr1 + expr2
+ *   2. Name: -             SQL semantic: expr1 - expr2 or - expr
+ *   3. Name: *             SQL semantic: expr1 * expr2
+ *   4. Name: /             SQL semantic: expr1 / expr2
+ *   5. Name: %             SQL semantic: expr1 % expr2
+ *   6. Name: &             SQL semantic: expr1 & expr2
+ *   7. Name: |             SQL semantic: expr1 | expr2
+ *   8. Name: ^             SQL semantic: expr1 ^ expr2
+ *   9. Name: ~             SQL semantic: ~ expr
+ *  10. Name: CASE_WHEN     SQL semantic: CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END
+ *  11. Name: ABS           SQL semantic: ABS(expr)
+ *  12. Name: COALESCE      SQL semantic: COALESCE(expr1, expr2)
+ *  13. Name: LN            SQL semantic: LN(expr)
+ *  14. Name: EXP           SQL semantic: EXP(expr)
+ *  15. Name: POWER         SQL semantic: POWER(expr, number)
+ *  16. Name: SQRT          SQL semantic: SQRT(expr)
+ *  17. Name: FLOOR         SQL semantic: FLOOR(expr)
+ *  18. Name: CEIL          SQL semantic: CEIL(expr)
+ *  19. Name: WIDTH_BUCKET  SQL semantic: WIDTH_BUCKET(expr)
+ * A construction sketch follows this list.
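For illustration, a minimal sketch of how a connector might build one of the scalar expressions listed above. Only GeneralScalarExpression and the new Expression defaults introduced in this patch are assumed; the column(...) helper implementing NamedReference is hypothetical.

import java.util.Arrays;

import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
import org.apache.spark.sql.connector.expressions.NamedReference;

class ScalarExpressionSketch {
  // Hypothetical helper: a NamedReference over a single top-level column.
  static NamedReference column(String name) {
    return new NamedReference() {
      @Override public String[] fieldNames() { return new String[] { name }; }
      @Override public String toString() { return name; }
    };
  }

  public static void main(String[] args) {
    // ABS(price) as a general scalar expression; expression names are upper-cased.
    Expression abs = new GeneralScalarExpression("ABS", new Expression[] { column("price") });
    // toString() goes through V2ExpressionSQLBuilder, so this should print an SQL-like
    // string along the lines of ABS(price).
    System.out.println(abs);
    // The default references() walks children() and collects the referenced columns.
    System.out.println(Arrays.toString(abs.references()));
  }
}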
+ * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off, + * including: add, subtract, multiply, divide, remainder, pmod. + * + * @since 3.3.0 + */ +@Evolving +public class GeneralScalarExpression implements Expression, Serializable { + private String name; + private Expression[] children; + + public GeneralScalarExpression(String name, Expression[] children) { + this.name = name; + this.children = children; + } + + public String name() { return name; } + public Expression[] children() { return children; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + GeneralScalarExpression that = (GeneralScalarExpression) o; + return Objects.equals(name, that.name) && Arrays.equals(children, that.children); + } + + @Override + public int hashCode() { + return Objects.hash(name, children); + } + + @Override + public String toString() { + V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder(); + try { + return builder.build(this); + } catch (Throwable e) { + return name + "(" + + Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java index df9e58fa319fd..5e8aeafe74515 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java @@ -40,4 +40,7 @@ public interface Literal extends Expression { * Returns the SQL data type of the literal. */ DataType dataType(); + + @Override + default Expression[] children() { return EMPTY_EXPRESSION; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java index 167432fa0e86a..8c0f029a35832 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java @@ -32,4 +32,10 @@ public interface NamedReference extends Expression { * Each string in the returned array represents a field name. */ String[] fieldNames(); + + @Override + default Expression[] children() { return EMPTY_EXPRESSION; } + + @Override + default NamedReference[] references() { return new NamedReference[]{ this }; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java index 72252457df26e..51401786ca5d7 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java @@ -40,4 +40,7 @@ public interface SortOrder extends Expression { * Returns the null ordering. 
*/ NullOrdering nullOrdering(); + + @Override + default Expression[] children() { return new Expression[]{ expression() }; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java index 297205825c6a4..e9ead7fc5fd2a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java @@ -34,13 +34,11 @@ public interface Transform extends Expression { */ String name(); - /** - * Returns all field references in the transform arguments. - */ - NamedReference[] references(); - /** * Returns the arguments passed to the transform function. */ Expression[] arguments(); + + @Override + default Expression[] children() { return arguments(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java new file mode 100644 index 0000000000000..d09e5f7ba28a3 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; + +/** + * An aggregate function that returns the mean of all the values in a group. 
+ * + * @since 3.3.0 + */ +@Evolving +public final class Avg implements AggregateFunc { + private final Expression input; + private final boolean isDistinct; + + public Avg(Expression column, boolean isDistinct) { + this.input = column; + this.isDistinct = isDistinct; + } + + public Expression column() { return input; } + public boolean isDistinct() { return isDistinct; } + + @Override + public Expression[] children() { return new Expression[]{ input }; } + + @Override + public String toString() { + if (isDistinct) { + return "AVG(DISTINCT " + input.describe() + ")"; + } else { + return "AVG(" + input.describe() + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java index 1273886e297bf..c840b29ad2546 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the number of the specific row in a group. @@ -27,26 +27,26 @@ */ @Evolving public final class Count implements AggregateFunc { - private final NamedReference column; + private final Expression input; private final boolean isDistinct; - public Count(NamedReference column, boolean isDistinct) { - this.column = column; + public Count(Expression column, boolean isDistinct) { + this.input = column; this.isDistinct = isDistinct; } - public NamedReference column() { return column; } + public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } + @Override + public Expression[] children() { return new Expression[]{ input }; } + @Override public String toString() { if (isDistinct) { - return "COUNT(DISTINCT " + column.describe() + ")"; + return "COUNT(DISTINCT " + input.describe() + ")"; } else { - return "COUNT(" + column.describe() + ")"; + return "COUNT(" + input.describe() + ")"; } } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java index f566ad164b8ef..ff8639cbd05a2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the number of rows in a group. 
@@ -31,8 +32,8 @@ public CountStar() { } @Override - public String toString() { return "COUNT(*)"; } + public Expression[] children() { return EMPTY_EXPRESSION; } @Override - public String describe() { return this.toString(); } + public String toString() { return "COUNT(*)"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java new file mode 100644 index 0000000000000..7016644543447 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.aggregate; + +import java.util.Arrays; +import java.util.stream.Collectors; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; + +/** + * The general implementation of {@link AggregateFunc}, which contains the upper-cased function + * name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial + * aggregate with this function to the source, but can only push down the entire aggregate. + *

+ * The currently supported SQL aggregate functions (all available since version 3.3.0):
+ *   1. VAR_POP(input1)
+ *   2. VAR_SAMP(input1)
+ *   3. STDDEV_POP(input1)
+ *   4. STDDEV_SAMP(input1)
+ *   5. COVAR_POP(input1, input2)
+ *   6. COVAR_SAMP(input1, input2)
+ *   7. CORR(input1, input2)
+ * A construction sketch follows this list.
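Analogously, a minimal sketch of constructing one of the aggregate functions listed above with GeneralAggregateFunc; the column(...) helper implementing NamedReference is hypothetical.

import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;

class AggregateFuncSketch {
  // Hypothetical helper: a NamedReference over a single top-level column.
  static NamedReference column(String name) {
    return new NamedReference() {
      @Override public String[] fieldNames() { return new String[] { name }; }
      @Override public String toString() { return name; }
    };
  }

  public static void main(String[] args) {
    // VAR_POP(value), no DISTINCT; the function name is passed upper-cased.
    GeneralAggregateFunc varPop =
        new GeneralAggregateFunc("VAR_POP", false, new Expression[] { column("value") });
    // toString() joins the inputs with describe(), so this should print VAR_POP(value).
    System.out.println(varPop);
  }
}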
+ * + * @since 3.3.0 + */ +@Evolving +public final class GeneralAggregateFunc implements AggregateFunc { + private final String name; + private final boolean isDistinct; + private final Expression[] children; + + public String name() { return name; } + public boolean isDistinct() { return isDistinct; } + + public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) { + this.name = name; + this.isDistinct = isDistinct; + this.children = children; + } + + @Override + public Expression[] children() { return children; } + + @Override + public String toString() { + String inputsString = Arrays.stream(children) + .map(Expression::describe) + .collect(Collectors.joining(", ")); + if (isDistinct) { + return name + "(DISTINCT " + inputsString + ")"; + } else { + return name + "(" + inputsString + ")"; + } + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java index ed07cc9e32187..089d2bd751763 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the maximum value in a group. @@ -27,15 +27,15 @@ */ @Evolving public final class Max implements AggregateFunc { - private final NamedReference column; + private final Expression input; - public Max(NamedReference column) { this.column = column; } + public Max(Expression column) { this.input = column; } - public NamedReference column() { return column; } + public Expression column() { return input; } @Override - public String toString() { return "MAX(" + column.describe() + ")"; } + public Expression[] children() { return new Expression[]{ input }; } @Override - public String describe() { return this.toString(); } + public String toString() { return "MAX(" + input.describe() + ")"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java index 2e761037746fb..253cdea41dd76 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the minimum value in a group. 
@@ -27,15 +27,15 @@ */ @Evolving public final class Min implements AggregateFunc { - private final NamedReference column; + private final Expression input; - public Min(NamedReference column) { this.column = column; } + public Min(Expression column) { this.input = column; } - public NamedReference column() { return column; } + public Expression column() { return input; } @Override - public String toString() { return "MIN(" + column.describe() + ")"; } + public Expression[] children() { return new Expression[]{ input }; } @Override - public String describe() { return this.toString(); } + public String toString() { return "MIN(" + input.describe() + ")"; } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java index 057ebd89f7a19..4e01b92d8c369 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java @@ -18,7 +18,7 @@ package org.apache.spark.sql.connector.expressions.aggregate; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.Expression; /** * An aggregate function that returns the summation of all the values in a group. @@ -27,26 +27,26 @@ */ @Evolving public final class Sum implements AggregateFunc { - private final NamedReference column; + private final Expression input; private final boolean isDistinct; - public Sum(NamedReference column, boolean isDistinct) { - this.column = column; + public Sum(Expression column, boolean isDistinct) { + this.input = column; this.isDistinct = isDistinct; } - public NamedReference column() { return column; } + public Expression column() { return input; } public boolean isDistinct() { return isDistinct; } + @Override + public Expression[] children() { return new Expression[]{ input }; } + @Override public String toString() { if (isDistinct) { - return "SUM(DISTINCT " + column.describe() + ")"; + return "SUM(DISTINCT " + input.describe() + ")"; } else { - return "SUM(" + column.describe() + ")"; + return "SUM(" + input.describe() + ")"; } } - - @Override - public String describe() { return this.toString(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java new file mode 100644 index 0000000000000..accdd1acd7d0e --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; + +/** + * A predicate that always evaluates to {@code false}. + * + * @since 3.3.0 + */ +@Evolving +public final class AlwaysFalse extends Predicate implements Literal { + + public AlwaysFalse() { + super("ALWAYS_FALSE", new Predicate[]{}); + } + + public Boolean value() { + return false; + } + + public DataType dataType() { + return DataTypes.BooleanType; + } + + public String toString() { return "FALSE"; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java new file mode 100644 index 0000000000000..5a14f64b9b7e2 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; + +/** + * A predicate that always evaluates to {@code true}. + * + * @since 3.3.0 + */ +@Evolving +public final class AlwaysTrue extends Predicate implements Literal { + + public AlwaysTrue() { + super("ALWAYS_TRUE", new Predicate[]{}); + } + + public Boolean value() { + return true; + } + + public DataType dataType() { + return DataTypes.BooleanType; + } + + public String toString() { return "TRUE"; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java new file mode 100644 index 0000000000000..179a4b3c6349d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; + +/** + * A predicate that evaluates to {@code true} iff both {@code left} and {@code right} evaluate to + * {@code true}. + * + * @since 3.3.0 + */ +@Evolving +public final class And extends Predicate { + + public And(Predicate left, Predicate right) { + super("AND", new Predicate[]{left, right}); + } + + public Predicate left() { return (Predicate) children()[0]; } + public Predicate right() { return (Predicate) children()[1]; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java new file mode 100644 index 0000000000000..d65c9f0b6c3d9 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; + +/** + * A predicate that evaluates to {@code true} iff {@code child} is evaluated to {@code false}. + * + * @since 3.3.0 + */ +@Evolving +public final class Not extends Predicate { + + public Not(Predicate child) { + super("NOT", new Predicate[]{child}); + } + + public Predicate child() { return (Predicate) children()[0]; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java new file mode 100644 index 0000000000000..7f1717cc7da58 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; + +/** + * A predicate that evaluates to {@code true} iff at least one of {@code left} or {@code right} + * evaluates to {@code true}. + * + * @since 3.3.0 + */ +@Evolving +public final class Or extends Predicate { + + public Or(Predicate left, Predicate right) { + super("OR", new Predicate[]{left, right}); + } + + public Predicate left() { return (Predicate) children()[0]; } + public Predicate right() { return (Predicate) children()[1]; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java new file mode 100644 index 0000000000000..e58cddc274c5f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.expressions.filter; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; + +/** + * The general representation of predicate expressions, which contains the upper-cased expression + * name and all the children expressions. You can also use these concrete subclasses for better + * type safety: {@link And}, {@link Or}, {@link Not}, {@link AlwaysTrue}, {@link AlwaysFalse}. + *

+ * The currently supported predicate expressions (all available since version 3.3.0):
+ *   1. Name: IS_NULL       SQL semantic: expr IS NULL
+ *   2. Name: IS_NOT_NULL   SQL semantic: expr IS NOT NULL
+ *   3. Name: STARTS_WITH   SQL semantic: expr1 LIKE 'expr2%'
+ *   4. Name: ENDS_WITH     SQL semantic: expr1 LIKE '%expr2'
+ *   5. Name: CONTAINS      SQL semantic: expr1 LIKE '%expr2%'
+ *   6. Name: IN            SQL semantic: expr IN (expr1, expr2, ...)
+ *   7. Name: =             SQL semantic: expr1 = expr2
+ *   8. Name: <>            SQL semantic: expr1 <> expr2
+ *   9. Name: <=>           SQL semantic: null-safe version of expr1 = expr2
+ *  10. Name: <             SQL semantic: expr1 < expr2
+ *  11. Name: <=            SQL semantic: expr1 <= expr2
+ *  12. Name: >             SQL semantic: expr1 > expr2
+ *  13. Name: >=            SQL semantic: expr1 >= expr2
+ *  14. Name: AND           SQL semantic: expr1 AND expr2
+ *  15. Name: OR            SQL semantic: expr1 OR expr2
+ *  16. Name: NOT           SQL semantic: NOT expr
+ *  17. Name: ALWAYS_TRUE   SQL semantic: TRUE
+ *  18. Name: ALWAYS_FALSE  SQL semantic: FALSE
+ * A composition sketch follows this list.
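A minimal sketch of composing the predicates listed above through the concrete And subclass; the column(...) helper implementing NamedReference is hypothetical, and the exact parenthesisation of the rendered SQL depends on V2ExpressionSQLBuilder.

import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.filter.And;
import org.apache.spark.sql.connector.expressions.filter.Predicate;

class PredicateSketch {
  // Hypothetical helper: a NamedReference over a single top-level column.
  static NamedReference column(String name) {
    return new NamedReference() {
      @Override public String[] fieldNames() { return new String[] { name }; }
      @Override public String toString() { return name; }
    };
  }

  public static void main(String[] args) {
    // city IS NOT NULL AND zip IS NOT NULL, built from the generic Predicate plus And.
    Predicate cityNotNull = new Predicate("IS_NOT_NULL", new Expression[] { column("city") });
    Predicate zipNotNull = new Predicate("IS_NOT_NULL", new Expression[] { column("zip") });
    And both = new And(cityNotNull, zipNotNull);
    // Predicate extends GeneralScalarExpression, so toString() renders SQL-like text,
    // e.g. (city IS NOT NULL) AND (zip IS NOT NULL).
    System.out.println(both);
    // And exposes its children through typed accessors.
    System.out.println(both.left());
  }
}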
+ * + * @since 3.3.0 + */ +@Evolving +public class Predicate extends GeneralScalarExpression { + + public Predicate(String name, Expression[] children) { + super(name, children); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java index b46f620d4fedb..27ee534d804ff 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java @@ -21,9 +21,9 @@ /** * An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ - * interfaces to do operator pushdown, and keep the operator pushdown result in the returned - * {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down - * aggregates or applies column pruning. + * interfaces to do operator push down, and keep the operator push down result in the returned + * {@link Scan}. When pushing down operators, the push down order is: + * sample -> filter -> aggregate -> limit -> column pruning. * * @since 3.0.0 */ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java index 3e643b5493310..4d88ec19c897b 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java @@ -22,18 +22,20 @@ /** * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to - * push down aggregates. Spark assumes that the data source can't fully complete the - * grouping work, and will group the data source output again. For queries like - * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate - * to the data source, the data source can still output data with duplicated keys, which is OK - * as Spark will do GROUP BY key again. The final query plan can be something like this: + * push down aggregates. + *

+ * If the data source can't fully complete the grouping work, then + * {@link #supportCompletePushDown(Aggregation)} should return false, and Spark will group the data + * source output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after + * pushing down the aggregate to the data source, the data source can still output data with + * duplicated keys, which is OK as Spark will do GROUP BY key again. The final query plan can be + * something like this: *

- *   Aggregate [key#1], [min(min(value)#2) AS m#3]
- *     +- RelationV2[key#1, min(value)#2]
+ *   Aggregate [key#1], [min(min_value#2) AS m#3]
+ *     +- RelationV2[key#1, min_value#2]
  * 
* Similarly, if there is no grouping expression, the data source can still output more than one * rows. - * *

* When pushing down operators, Spark pushes down filters to the data source first, then push down * aggregates or apply column pruning. Depends on data source implementation, aggregates may or @@ -45,11 +47,21 @@ @Evolving public interface SupportsPushDownAggregates extends ScanBuilder { + /** + * Whether the datasource support complete aggregation push-down. Spark will do grouping again + * if this method returns false. + * + * @param aggregation Aggregation in SQL statement. + * @return true if the aggregation can be pushed down to datasource completely, false otherwise. + */ + default boolean supportCompletePushDown(Aggregation aggregation) { return false; } + /** * Pushes down Aggregation to datasource. The order of the datasource scan output columns should * be: grouping columns, aggregate columns (in the same order as the aggregate functions in * the given Aggregation). * + * @param aggregation Aggregation in SQL statement. * @return true if the aggregation can be pushed down to datasource, false otherwise. */ boolean pushAggregation(Aggregation aggregation); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java new file mode 100644 index 0000000000000..fa6447bc068d5 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down LIMIT. Please note that the combination of LIMIT with other operations + * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownLimit extends ScanBuilder { + + /** + * Pushes down LIMIT to the data source. + */ + boolean pushLimit(int limit); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java new file mode 100644 index 0000000000000..3630feb4680ea --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; + +/** + * A mix-in interface for {@link Scan}. Data sources can implement this interface to + * push down SAMPLE. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownTableSample extends ScanBuilder { + + /** + * Pushes down SAMPLE to the data source. + */ + boolean pushTableSample( + double lowerBound, + double upperBound, + boolean withReplacement, + long seed); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java new file mode 100644 index 0000000000000..cba1592c4fa14 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.SortOrder; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down top N(query with ORDER BY ... LIMIT n). Please note that the combination of top N + * with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc. + * is NOT pushed down. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownTopN extends ScanBuilder { + + /** + * Pushes down top N to the data source. + */ + boolean pushTopN(SortOrder[] orders, int limit); + + /** + * Whether the top N is partially pushed or not. If it returns true, then Spark will do top N + * again. This method will only be called when {@link #pushTopN} returns true. 
+ */ + default boolean isPartiallyPushed() { return true; } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java new file mode 100644 index 0000000000000..1fec939aeb474 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.read; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.connector.expressions.filter.Predicate; + +/** + * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to + * push down V2 {@link Predicate} to the data source and reduce the size of the data to be read. + * Please Note that this interface is preferred over {@link SupportsPushDownFilters}, which uses + * V1 {@link org.apache.spark.sql.sources.Filter} and is less efficient due to the + * internal -> external data conversion. + * + * @since 3.3.0 + */ +@Evolving +public interface SupportsPushDownV2Filters extends ScanBuilder { + + /** + * Pushes down predicates, and returns predicates that need to be evaluated after scanning. + *

+ * Rows should be returned from the data source if and only if all of the predicates match. + * That is, predicates must be interpreted as ANDed together. + */ + Predicate[] pushPredicates(Predicate[] predicates); + + /** + * Returns the predicates that are pushed to the data source via + * {@link #pushPredicates(Predicate[])}. + *

+ * There are 3 kinds of predicates: + *

    + *
  1. pushable predicates which don't need to be evaluated again after scanning.
  2. + *
  3. pushable predicates which still need to be evaluated after scanning, e.g. parquet row + * group predicate.
  4. + *
  5. non-pushable predicates.
  6. + *
+ *

+ * Both case 1 and 2 should be considered as pushed predicates and should be returned + * by this method. + *

+ * It's possible that there is no predicates in the query and + * {@link #pushPredicates(Predicate[])} is never called, + * empty array should be returned for this case. + */ + Predicate[] pushedPredicates(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java new file mode 100644 index 0000000000000..c9dfa2003e3c1 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java @@ -0,0 +1,231 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.util; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +import org.apache.spark.sql.connector.expressions.Cast; +import org.apache.spark.sql.connector.expressions.Expression; +import org.apache.spark.sql.connector.expressions.NamedReference; +import org.apache.spark.sql.connector.expressions.GeneralScalarExpression; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.types.DataType; + +/** + * The builder to generate SQL from V2 expressions. 
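The three new mix-in interfaces above (SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters) only define the contract between Spark and a connector's ScanBuilder. As a rough sketch of how a connector might wire two of them together, assuming a hypothetical class name and helper methods (`isSupported` and `buildScan` are illustrative, not part of this patch):

```scala
import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownTopN, SupportsPushDownV2Filters}

class ExampleScanBuilder extends ScanBuilder
  with SupportsPushDownV2Filters with SupportsPushDownTopN {

  private var pushed: Array[Predicate] = Array.empty
  private var topN: Option[(Array[SortOrder], Int)] = None

  // Keep the predicates this source understands; hand the rest back to Spark to be
  // evaluated after the scan.
  override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
    val (supported, unsupported) = predicates.partition(isSupported)
    pushed = supported
    unsupported
  }

  override def pushedPredicates(): Array[Predicate] = pushed

  override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
    topN = Some((orders, limit))
    true
  }

  // Returning true tells Spark to re-apply the Top N on whatever the scan returns.
  override def isPartiallyPushed(): Boolean = true

  override def build(): Scan = buildScan(pushed, topN)

  // Hypothetical helpers, standing in for connector-specific logic.
  private def isSupported(p: Predicate): Boolean = p.name() != "CASE_WHEN"
  private def buildScan(ps: Array[Predicate], t: Option[(Array[SortOrder], Int)]): Scan = ???
}
```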
+ */ +public class V2ExpressionSQLBuilder { + + public String build(Expression expr) { + if (expr instanceof Literal) { + return visitLiteral((Literal) expr); + } else if (expr instanceof NamedReference) { + return visitNamedReference((NamedReference) expr); + } else if (expr instanceof Cast) { + Cast cast = (Cast) expr; + return visitCast(build(cast.expression()), cast.dataType()); + } else if (expr instanceof GeneralScalarExpression) { + GeneralScalarExpression e = (GeneralScalarExpression) expr; + String name = e.name(); + switch (name) { + case "IN": { + List children = + Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList()); + return visitIn(children.get(0), children.subList(1, children.size())); + } + case "IS_NULL": + return visitIsNull(build(e.children()[0])); + case "IS_NOT_NULL": + return visitIsNotNull(build(e.children()[0])); + case "STARTS_WITH": + return visitStartsWith(build(e.children()[0]), build(e.children()[1])); + case "ENDS_WITH": + return visitEndsWith(build(e.children()[0]), build(e.children()[1])); + case "CONTAINS": + return visitContains(build(e.children()[0]), build(e.children()[1])); + case "=": + case "<>": + case "<=>": + case "<": + case "<=": + case ">": + case ">=": + return visitBinaryComparison( + name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); + case "+": + case "*": + case "/": + case "%": + case "&": + case "|": + case "^": + return visitBinaryArithmetic( + name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); + case "-": + if (e.children().length == 1) { + return visitUnaryArithmetic(name, inputToSQL(e.children()[0])); + } else { + return visitBinaryArithmetic( + name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1])); + } + case "AND": + return visitAnd(name, build(e.children()[0]), build(e.children()[1])); + case "OR": + return visitOr(name, build(e.children()[0]), build(e.children()[1])); + case "NOT": + return visitNot(build(e.children()[0])); + case "~": + return visitUnaryArithmetic(name, inputToSQL(e.children()[0])); + case "ABS": + case "COALESCE": + case "LN": + case "EXP": + case "POWER": + case "SQRT": + case "FLOOR": + case "CEIL": + case "WIDTH_BUCKET": + return visitSQLFunction(name, + Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new)); + case "CASE_WHEN": { + List children = + Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList()); + return visitCaseWhen(children.toArray(new String[e.children().length])); + } + // TODO supports other expressions + default: + return visitUnexpectedExpr(expr); + } + } else { + return visitUnexpectedExpr(expr); + } + } + + protected String visitLiteral(Literal literal) { + return literal.toString(); + } + + protected String visitNamedReference(NamedReference namedRef) { + return namedRef.toString(); + } + + protected String visitIn(String v, List list) { + if (list.isEmpty()) { + return "CASE WHEN " + v + " IS NULL THEN NULL ELSE FALSE END"; + } + return v + " IN (" + list.stream().collect(Collectors.joining(", ")) + ")"; + } + + protected String visitIsNull(String v) { + return v + " IS NULL"; + } + + protected String visitIsNotNull(String v) { + return v + " IS NOT NULL"; + } + + protected String visitStartsWith(String l, String r) { + // Remove quotes at the beginning and end. + // e.g. converts "'str'" to "str". + String value = r.substring(1, r.length() - 1); + return l + " LIKE '" + value + "%'"; + } + + protected String visitEndsWith(String l, String r) { + // Remove quotes at the beginning and end. 
+ // e.g. converts "'str'" to "str". + String value = r.substring(1, r.length() - 1); + return l + " LIKE '%" + value + "'"; + } + + protected String visitContains(String l, String r) { + // Remove quotes at the beginning and end. + // e.g. converts "'str'" to "str". + String value = r.substring(1, r.length() - 1); + return l + " LIKE '%" + value + "%'"; + } + + private String inputToSQL(Expression input) { + if (input.children().length > 1) { + return "(" + build(input) + ")"; + } else { + return build(input); + } + } + + protected String visitBinaryComparison(String name, String l, String r) { + switch (name) { + case "<=>": + return "(" + l + " = " + r + ") OR (" + l + " IS NULL AND " + r + " IS NULL)"; + default: + return l + " " + name + " " + r; + } + } + + protected String visitBinaryArithmetic(String name, String l, String r) { + return l + " " + name + " " + r; + } + + protected String visitCast(String l, DataType dataType) { + return "CAST(" + l + " AS " + dataType.typeName() + ")"; + } + + protected String visitAnd(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitOr(String name, String l, String r) { + return "(" + l + ") " + name + " (" + r + ")"; + } + + protected String visitNot(String v) { + return "NOT (" + v + ")"; + } + + protected String visitUnaryArithmetic(String name, String v) { return name + v; } + + protected String visitCaseWhen(String[] children) { + StringBuilder sb = new StringBuilder("CASE"); + for (int i = 0; i < children.length; i += 2) { + String c = children[i]; + int j = i + 1; + if (j < children.length) { + String v = children[j]; + sb.append(" WHEN "); + sb.append(c); + sb.append(" THEN "); + sb.append(v); + } else { + sb.append(" ELSE "); + sb.append(c); + } + } + sb.append(" END"); + return sb.toString(); + } + + protected String visitSQLFunction(String funcName, String[] inputs) { + return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")"; + } + + protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException { + throw new IllegalArgumentException("Unexpected V2 expression: " + expr); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 70f821d5f8af0..fb177251a7306 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -78,3 +78,6 @@ class PartitionsAlreadyExistException(message: String) extends AnalysisException class FunctionAlreadyExistsException(db: String, func: String) extends AnalysisException(s"Function '$func' already exists in database '$db'") + +class IndexAlreadyExistsException(message: String, cause: Option[Throwable] = None) + extends AnalysisException(message, cause = cause) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index ba5a9c618c650..8b0710b2c1f19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -95,3 +95,6 @@ class NoSuchPartitionsException(message: String) extends AnalysisException(messa 
class NoSuchTempFunctionException(func: String) extends AnalysisException(s"Temporary function '$func' not found") + +class NoSuchIndexException(message: String, cause: Option[Throwable] = None) + extends AnalysisException(message, cause = cause) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala new file mode 100644 index 0000000000000..f3ff28f74fcc3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ + + +/** + * Thrown by a catalog when an item already exists. The analyzer will rethrow the exception + * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. + */ +case class NonEmptyNamespaceException( + override val message: String, + override val cause: Option[Throwable] = None) + extends AnalysisException(message, cause = cause) { + + def this(namespace: Array[String]) = { + this(s"Namespace '${namespace.quoted}' is non empty.") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala index 0007d3868eda2..dea7ea0f144bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.MultiAlias import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project} +import org.apache.spark.sql.types.Metadata /** * Helper methods for collecting and replacing aliases. 
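The NonEmptyNamespaceException introduced above underpins the RESTRICT-versus-CASCADE drop semantics that the catalog changes later in this patch adopt. A minimal sketch of that contract, assuming an illustrative in-memory metadata map (DropNamespaceSketch is not a Spark API):

```scala
import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException

object DropNamespaceSketch {
  // Stand-in for a catalog's metadata store: namespace -> table names.
  private val tables = mutable.Map.empty[Seq[String], Set[String]]

  def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = {
    val contents = tables.getOrElse(namespace.toSeq, Set.empty[String])
    if (!cascade && contents.nonEmpty) {
      // RESTRICT (cascade = false): refuse to drop a namespace that still has contents.
      throw new NonEmptyNamespaceException(namespace)
    }
    // CASCADE: drop the contents together with the namespace itself.
    tables.remove(namespace.toSeq).isDefined
  }
}
```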
@@ -86,10 +87,15 @@ trait AliasHelper { protected def trimNonTopLevelAliases[T <: Expression](e: T): T = { val res = e match { case a: Alias => + val metadata = if (a.metadata == Metadata.empty) { + None + } else { + Some(a.metadata) + } a.copy(child = trimAliases(a.child))( exprId = a.exprId, qualifier = a.qualifier, - explicitMetadata = Some(a.metadata), + explicitMetadata = metadata, nonInheritableMetadataKeys = a.nonInheritableMetadataKeys) case a: MultiAlias => a.copy(child = trimAliases(a.child)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 9714a096a69a2..533f7f20b2530 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -69,15 +69,15 @@ case class Average( case _ => DoubleType } - private lazy val sumDataType = child.dataType match { + lazy val sumDataType = child.dataType match { case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s) case _: YearMonthIntervalType => YearMonthIntervalType() case _: DayTimeIntervalType => DayTimeIntervalType() case _ => DoubleType } - private lazy val sum = AttributeReference("sum", sumDataType)() - private lazy val count = AttributeReference("count", LongType)() + lazy val sum = AttributeReference("sum", sumDataType)() + lazy val count = AttributeReference("count", LongType)() override lazy val aggBufferAttributes = sum :: count :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index eb040e23290c9..1a57ee83fa3ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -912,56 +912,69 @@ object ColumnPruning extends Rule[LogicalPlan] { */ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( - _.containsPattern(PROJECT), ruleId) { - case p1 @ Project(_, p2: Project) => - if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) { - p1 - } else { + def apply(plan: LogicalPlan): LogicalPlan = { + val alwaysInline = conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) + plan.transformUpWithPruning(_.containsPattern(PROJECT), ruleId) { + case p1 @ Project(_, p2: Project) + if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) => p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList)) - } - case p @ Project(_, agg: Aggregate) => - if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions) || - !canCollapseAggregate(p, agg)) { - p - } else { + case p @ Project(_, agg: Aggregate) + if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) => agg.copy(aggregateExpressions = buildCleanedProjectList( p.projectList, agg.aggregateExpressions)) - } - case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) + case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _)))) if isRenaming(l1, l2) => - val newProjectList = buildCleanedProjectList(l1, l2) - g.copy(child = limit.copy(child = p2.copy(projectList = newProjectList))) - case Project(l1, limit 
@ LocalLimit(_, p2 @ Project(l2, _))) if isRenaming(l1, l2) => - val newProjectList = buildCleanedProjectList(l1, l2) - limit.copy(child = p2.copy(projectList = newProjectList)) - case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) => - r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, p.projectList))) - case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) => - s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList))) - } - - private def haveCommonNonDeterministicOutput( - upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = { - val aliases = getAliasMap(lower) + val newProjectList = buildCleanedProjectList(l1, l2) + g.copy(child = limit.copy(child = p2.copy(projectList = newProjectList))) + case Project(l1, limit @ LocalLimit(_, p2 @ Project(l2, _))) if isRenaming(l1, l2) => + val newProjectList = buildCleanedProjectList(l1, l2) + limit.copy(child = p2.copy(projectList = newProjectList)) + case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) => + r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, p.projectList))) + case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) => + s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList))) + } + } - // Collapse upper and lower Projects if and only if their overlapped expressions are all - // deterministic. - upper.exists(_.collect { - case a: Attribute if aliases.contains(a) => aliases(a).child - }.exists(!_.deterministic)) + /** + * Check if we can collapse expressions safely. + */ + def canCollapseExpressions( + consumers: Seq[Expression], + producers: Seq[NamedExpression], + alwaysInline: Boolean): Boolean = { + canCollapseExpressions(consumers, getAliasMap(producers), alwaysInline) } /** - * A project cannot be collapsed with an aggregate when there are correlated scalar - * subqueries in the project list, because currently we only allow correlated subqueries - * in aggregate if they are also part of the grouping expressions. Otherwise the plan - * after subquery rewrite will not be valid. + * Check if we can collapse expressions safely. */ - private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = { - p.projectList.forall(_.collect { - case s: ScalarSubquery if s.outerAttrs.nonEmpty => s - }.isEmpty) + def canCollapseExpressions( + consumers: Seq[Expression], + producerMap: Map[Attribute, Expression], + alwaysInline: Boolean = false): Boolean = { + // We can only collapse expressions if all input expressions meet the following criteria: + // - The input is deterministic. + // - The input is only consumed once OR the underlying input expression is cheap. 
+ consumers.flatMap(collectReferences) + .groupBy(identity) + .mapValues(_.size) + .forall { + case (reference, count) => + val producer = producerMap.getOrElse(reference, reference) + producer.deterministic && (count == 1 || alwaysInline || { + val relatedConsumers = consumers.filter(_.references.contains(reference)) + val extractOnly = relatedConsumers.forall(isExtractOnly(_, reference)) + shouldInline(producer, extractOnly) + }) + } + } + + private def isExtractOnly(expr: Expression, ref: Attribute): Boolean = expr match { + case a: Alias => isExtractOnly(a.child, ref) + case e: ExtractValue => isExtractOnly(e.children.head, ref) + case a: Attribute => a.semanticEquals(ref) + case _ => false } private def buildCleanedProjectList( @@ -971,6 +984,34 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { upper.map(replaceAliasButKeepName(_, aliases)) } + /** + * Check if the given expression is cheap that we can inline it. + */ + private def shouldInline(e: Expression, extractOnlyConsumer: Boolean): Boolean = e match { + case _: Attribute | _: OuterReference => true + case _ if e.foldable => true + // PythonUDF is handled by the rule ExtractPythonUDFs + case _: PythonUDF => true + // Alias and ExtractValue are very cheap. + case _: Alias | _: ExtractValue => e.children.forall(shouldInline(_, extractOnlyConsumer)) + // These collection create functions are not cheap, but we have optimizer rules that can + // optimize them out if they are only consumed by ExtractValue, so we need to allow to inline + // them to avoid perf regression. As an example: + // Project(s.a, s.b, Project(create_struct(a, b, c) as s, child)) + // We should collapse these two projects and eventually get Project(a, b, child) + case _: CreateNamedStruct | _: CreateArray | _: CreateMap | _: UpdateFields => + extractOnlyConsumer + case _ => false + } + + /** + * Return all the references of the given expression without deduplication, which is different + * from `Expression.references`. + */ + private def collectReferences(e: Expression): Seq[Attribute] = e.collect { + case a: Attribute => a + } + private def isRenaming(list1: Seq[NamedExpression], list2: Seq[NamedExpression]): Boolean = { list1.length == list2.length && list1.zip(list2).forall { case (e1, e2) if e1.semanticEquals(e2) => true diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index fc12f48ec2a11..f33d137ffd607 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -26,46 +26,32 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf -trait OperationHelper { - type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan) - - protected def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] = - AttributeMap(fields.collect { - case a: Alias => (a.toAttribute, a.child) - }) - - protected def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = { - // use transformUp instead of transformDown to avoid dead loop - // in case of there's Alias whose exprId is the same as its child attribute. 
- expr.transformUp { - case a @ Alias(ref: AttributeReference, name) => - aliases.get(ref) - .map(Alias(_, name)(a.exprId, a.qualifier)) - .getOrElse(a) - - case a: AttributeReference => - aliases.get(a) - .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a) - } - } -} +trait OperationHelper extends AliasHelper with PredicateHelper { + import org.apache.spark.sql.catalyst.optimizer.CollapseProject.canCollapseExpressions -/** - * A pattern that matches any number of project or filter operations on top of another relational - * operator. All filter operators are collected and their conditions are broken up and returned - * together with the top project operator. - * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if - * necessary. - */ -object PhysicalOperation extends OperationHelper with PredicateHelper { + type ReturnType = + (Seq[NamedExpression], Seq[Expression], LogicalPlan) + type IntermediateType = + (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Alias]) def unapply(plan: LogicalPlan): Option[ReturnType] = { - val (fields, filters, child, _) = collectProjectsAndFilters(plan) + val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE) + val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline) Some((fields.getOrElse(child.output), filters, child)) } /** - * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary. + * This legacy mode is for PhysicalOperation which has been there for years and we want to be + * extremely safe to not change its behavior. There are two differences when legacy mode is off: + * 1. We postpone the deterministic check to the very end (calling `canCollapseExpressions`), + * so that it's more likely to collect more projects and filters. + * 2. We follow CollapseProject and only collect adjacent projects if they don't produce + * repeated expensive expressions. + */ + protected def legacyMode: Boolean + + /** + * Collects all adjacent projects and filters, in-lining/substituting aliases if necessary. * Here are two examples for alias in-lining/substitution. 
* Before: * {{{ @@ -78,25 +64,60 @@ object PhysicalOperation extends OperationHelper with PredicateHelper { * SELECT key AS c2 FROM t1 WHERE key > 10 * }}} */ - private def collectProjectsAndFilters(plan: LogicalPlan): - (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Expression]) = + private def collectProjectsAndFilters( + plan: LogicalPlan, + alwaysInline: Boolean): IntermediateType = { + def empty: IntermediateType = (None, Nil, plan, AttributeMap.empty) + plan match { - case Project(fields, child) if fields.forall(_.deterministic) => - val (_, filters, other, aliases) = collectProjectsAndFilters(child) - val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] - (Some(substitutedFields), filters, other, collectAliases(substitutedFields)) + case Project(fields, child) if !legacyMode || fields.forall(_.deterministic) => + val (_, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline) + if (legacyMode || canCollapseExpressions(fields, aliases, alwaysInline)) { + val replaced = fields.map(replaceAliasButKeepName(_, aliases)) + (Some(replaced), filters, other, getAliasMap(replaced)) + } else { + empty + } - case Filter(condition, child) if condition.deterministic => - val (fields, filters, other, aliases) = collectProjectsAndFilters(child) - val substitutedCondition = substitute(aliases)(condition) - (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + case Filter(condition, child) if !legacyMode || condition.deterministic => + val (fields, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline) + val canIncludeThisFilter = if (legacyMode) { + true + } else { + // When collecting projects and filters, we effectively push down filters through + // projects. We need to meet the following conditions to do so: + // 1) no Project collected so far or the collected Projects are all deterministic + // 2) the collected filters and this filter are all deterministic, or this is the + // first collected filter. + // 3) this filter does not repeat any expensive expressions from the collected + // projects. + fields.forall(_.forall(_.deterministic)) && { + filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic) + } && canCollapseExpressions(Seq(condition), aliases, alwaysInline) + } + if (canIncludeThisFilter) { + val replaced = replaceAlias(condition, aliases) + (fields, filters ++ splitConjunctivePredicates(replaced), other, aliases) + } else { + empty + } - case h: ResolvedHint => - collectProjectsAndFilters(h.child) + case h: ResolvedHint => collectProjectsAndFilters(h.child, alwaysInline) - case other => - (None, Nil, other, AttributeMap(Seq())) + case _ => empty } + } +} + +/** + * A pattern that matches any number of project or filter operations on top of another relational + * operator. All filter operators are collected and their conditions are broken up and returned + * together with the top project operator. + * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if + * necessary. + */ +object PhysicalOperation extends OperationHelper with PredicateHelper { + override protected def legacyMode: Boolean = true } /** @@ -105,70 +126,7 @@ object PhysicalOperation extends OperationHelper with PredicateHelper { * requirement of CollapseProject and CombineFilters. 
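The collapse criteria above are easiest to see on a concrete plan: the second projection below references `a_plus_1` twice, so inlining it would duplicate the `a + 1` computation. A hedged sketch of how the new default and the escape-hatch config interact (local session, made-up column names):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

def plan() = Seq(1, 2, 3).toDF("a")
  .select((col("a") + 1).as("a_plus_1"))
  .select((col("a_plus_1") + col("a_plus_1")).as("a_2_plus_2"))

// Default (false): the two projects are kept separate, so a + 1 is computed once per row.
plan().explain(true)

// Opting back into the old always-collapse behaviour, at the cost of repeating a + 1.
spark.conf.set("spark.sql.optimizer.collapseProjectAlwaysInline", "true")
plan().explain(true)
```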
*/ object ScanOperation extends OperationHelper with PredicateHelper { - type ScanReturnType = Option[(Option[Seq[NamedExpression]], - Seq[Expression], LogicalPlan, AttributeMap[Expression])] - - def unapply(plan: LogicalPlan): Option[ReturnType] = { - collectProjectsAndFilters(plan) match { - case Some((fields, filters, child, _)) => - Some((fields.getOrElse(child.output), filters, child)) - case None => None - } - } - - private def hasCommonNonDeterministic( - expr: Seq[Expression], - aliases: AttributeMap[Expression]): Boolean = { - expr.exists(_.collect { - case a: AttributeReference if aliases.contains(a) => aliases(a) - }.exists(!_.deterministic)) - } - - private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = { - plan match { - case Project(fields, child) => - collectProjectsAndFilters(child) match { - case Some((_, filters, other, aliases)) => - // Follow CollapseProject and only keep going if the collected Projects - // do not have common non-deterministic expressions. - if (!hasCommonNonDeterministic(fields, aliases)) { - val substitutedFields = - fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]] - Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields))) - } else { - None - } - case None => None - } - - case Filter(condition, child) => - collectProjectsAndFilters(child) match { - case Some((fields, filters, other, aliases)) => - // When collecting projects and filters, we effectively push down filters through - // projects. We need to meet the following conditions to do so: - // 1) no Project collected so far or the collected Projects are all deterministic - // 2) the collected filters and this filter are all deterministic, or this is the - // first collected filter. - val canCombineFilters = fields.forall(_.forall(_.deterministic)) && { - filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic) - } - val substitutedCondition = substitute(aliases)(condition) - if (canCombineFilters && !hasCommonNonDeterministic(Seq(condition), aliases)) { - Some((fields, filters ++ splitConjunctivePredicates(substitutedCondition), - other, aliases)) - } else { - None - } - case None => None - } - - case h: ResolvedHint => - collectProjectsAndFilters(h.child) - - case other => - Some((None, Nil, other, AttributeMap(Seq()))) - } - } + override protected def legacyMode: Boolean = false } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala index 39642fd541706..185a1a2644e2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala @@ -38,12 +38,13 @@ private[sql] object CatalogV2Implicits { implicit class BucketSpecHelper(spec: BucketSpec) { def asTransform: BucketTransform = { + val references = spec.bucketColumnNames.map(col => reference(Seq(col))) if (spec.sortColumnNames.nonEmpty) { - throw QueryCompilationErrors.cannotConvertBucketWithSortColumnsToTransformError(spec) + val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col))) + bucket(spec.numBuckets, references.toArray, sortedCol.toArray) + } else { + bucket(spec.numBuckets, references.toArray) } - - val references = spec.bucketColumnNames.map(col => reference(Seq(col))) - bucket(spec.numBuckets, references.toArray) } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala index 2863d94d198b2..e3eab6f6730f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala @@ -45,6 +45,12 @@ private[sql] object LogicalExpressions { def bucket(numBuckets: Int, references: Array[NamedReference]): BucketTransform = BucketTransform(literal(numBuckets, IntegerType), references) + def bucket( + numBuckets: Int, + references: Array[NamedReference], + sortedCols: Array[NamedReference]): BucketTransform = + BucketTransform(literal(numBuckets, IntegerType), references, sortedCols) + def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference) def years(reference: NamedReference): YearsTransform = YearsTransform(reference) @@ -82,9 +88,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R override def arguments: Array[Expression] = Array(ref) - override def describe: String = name + "(" + reference.describe + ")" - - override def toString: String = describe + override def toString: String = name + "(" + reference.describe + ")" protected def withNewRef(ref: NamedReference): Transform @@ -97,7 +101,8 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R private[sql] final case class BucketTransform( numBuckets: Literal[Int], - columns: Seq[NamedReference]) extends RewritableTransform { + columns: Seq[NamedReference], + sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform { override val name: String = "bucket" @@ -107,9 +112,13 @@ private[sql] final case class BucketTransform( override def arguments: Array[Expression] = numBuckets +: columns.toArray - override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})" - - override def toString: String = describe + override def toString: String = + if (sortedColumns.nonEmpty) { + s"bucket(${arguments.map(_.describe).mkString(", ")}," + + s" ${sortedColumns.map(_.describe).mkString(", ")})" + } else { + s"bucket(${arguments.map(_.describe).mkString(", ")})" + } override def withReferences(newReferences: Seq[NamedReference]): Transform = { this.copy(columns = newReferences) @@ -117,11 +126,12 @@ private[sql] final case class BucketTransform( } private[sql] object BucketTransform { - def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match { + def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] = + expr match { case transform: Transform => transform match { - case BucketTransform(n, FieldReference(parts)) => - Some((n, FieldReference(parts))) + case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) => + Some((n, FieldReference(parts), FieldReference(sortCols))) case _ => None } @@ -129,11 +139,17 @@ private[sql] object BucketTransform { None } - def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match { + def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] = + transform match { + case NamedTransform("bucket", Seq( + Lit(value: Int, IntegerType), + Ref(partCols: Seq[String]), + Ref(sortCols: Seq[String]))) => + Some((value, FieldReference(partCols), FieldReference(sortCols))) case NamedTransform("bucket", Seq( Lit(value: Int, IntegerType), - Ref(seq: 
Seq[String]))) => - Some((value, FieldReference(seq))) + Ref(partCols: Seq[String]))) => + Some((value, FieldReference(partCols), FieldReference(Seq.empty[String]))) case _ => None } @@ -149,9 +165,7 @@ private[sql] final case class ApplyTransform( arguments.collect { case named: NamedReference => named } } - override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})" - - override def toString: String = describe + override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})" } /** @@ -318,21 +332,19 @@ private[sql] object HoursTransform { } private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { - override def describe: String = { + override def toString: String = { if (dataType.isInstanceOf[StringType]) { s"'$value'" } else { s"$value" } } - override def toString: String = describe } private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference { import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper override def fieldNames: Array[String] = parts.toArray - override def describe: String = parts.quoted - override def toString: String = describe + override def toString: String = parts.quoted } private[sql] object FieldReference { @@ -346,7 +358,7 @@ private[sql] final case class SortValue( direction: SortDirection, nullOrdering: NullOrdering) extends SortOrder { - override def describe(): String = s"$expression $direction $nullOrdering" + override def toString(): String = s"$expression $direction $nullOrdering" } private[sql] object SortValue { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index e7af006ad7023..0c7a1030fd434 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, ResolvedNamespace, ResolvedTable, ResolvedView, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, InvalidUDFClassException} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, InvalidUDFClassException} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.JoinType @@ -555,6 +555,11 @@ object QueryCompilationErrors { new AnalysisException(s"Database $db is not empty. One or more $details exist.") } + def cannotDropNonemptyNamespaceError(namespace: Seq[String]): Throwable = { + new AnalysisException(s"Cannot drop a non-empty namespace: ${namespace.quoted}. 
" + + "Use CASCADE option to drop a non-empty namespace.") + } + def invalidNameForTableOrDatabaseError(name: String): Throwable = { new AnalysisException(s"`$name` is not a valid name for tables/databases. " + "Valid names only contain alphabet characters, numbers and _.") @@ -1371,11 +1376,6 @@ object QueryCompilationErrors { new AnalysisException("Cannot use interval type in the table schema.") } - def cannotConvertBucketWithSortColumnsToTransformError(spec: BucketSpec): Throwable = { - new AnalysisException( - s"Cannot convert bucketing with sort columns to a transform: $spec") - } - def cannotConvertTransformsToPartitionColumnsError(nonIdTransforms: Seq[Transform]): Throwable = { new AnalysisException("Transforms cannot be converted to partition columns: " + nonIdTransforms.map(_.describe).mkString(", ")) @@ -2371,4 +2371,8 @@ object QueryCompilationErrors { messageParameters = Array(fieldName.quoted, path.quoted), origin = context) } + + def noSuchFunctionError(database: String, funcInfo: String): Throwable = { + new AnalysisException(s"$database does not support function: $funcInfo") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index 7f77243af8a88..88ab9e530a1a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1804,4 +1804,16 @@ object QueryExecutionErrors { def pivotNotAfterGroupByUnsupportedError(): Throwable = { new UnsupportedOperationException("pivot is only supported after a groupBy") } + + def unsupportedCreateNamespaceCommentError(): Throwable = { + new SQLFeatureNotSupportedException("Create namespace comment is not supported") + } + + def unsupportedRemoveNamespaceCommentError(): Throwable = { + new SQLFeatureNotSupportedException("Remove namespace comment is not supported") + } + + def unsupportedDropNamespaceRestrictError(): Throwable = { + new SQLFeatureNotSupportedException("Drop namespace restrict is not supported") + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 15927a9ffdfbf..96ca754cad220 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -851,6 +851,14 @@ object SQLConf { .checkValue(threshold => threshold >= 0, "The threshold must not be negative.") .createWithDefault(10) + val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown") + .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" + + " down to Parquet for optimization. MAX/MIN/COUNT for complex types and timestamp" + + " can't be pushed down") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat") .doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " + "values will be written in Apache Parquet's fixed-length byte array format, which other " + @@ -942,6 +950,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ORC_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.aggregatePushdown") + .doc("If true, aggregates will be pushed down to ORC for optimization. 
Support MIN, MAX and " + + "COUNT as aggregate expression. For MIN/MAX, support boolean, integer, float and date " + + "type. For COUNT, support all data types.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val ORC_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.orc.mergeSchema") .doc("When true, the Orc data source merges schemas collected from all data files, " + "otherwise the schema is picked from a random data file.") @@ -1852,6 +1868,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val COLLAPSE_PROJECT_ALWAYS_INLINE = buildConf("spark.sql.optimizer.collapseProjectAlwaysInline") + .doc("Whether to always collapse two adjacent projections and inline expressions even if " + + "it causes extra duplication.") + .version("3.3.0") + .booleanConf + .createWithDefault(false) + val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion") .internal() .doc("Whether to delete the expired log files in file stream sink.") @@ -3679,8 +3702,12 @@ class SQLConf extends Serializable with Logging { def parquetFilterPushDownInFilterThreshold: Int = getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD) + def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED) + def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED) + def orcAggregatePushDown: Boolean = getConf(ORC_AGGREGATE_PUSHDOWN_ENABLED) + def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED) def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala new file mode 100644 index 0000000000000..9c2a4ac78a24a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.internal.connector + +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.sources.Filter + +/** + * A mix-in interface for {@link FileScanBuilder}. File sources can implement this interface to + * push down filters to the file source. The pushed down filters will be separated into partition + * filters and data filters. Partition filters are used for partition pruning and data filters are + * used to reduce the size of the data to be read. + */ +trait SupportsPushDownCatalystFilters { + + /** + * Pushes down catalyst Expression filters (which will be separated into partition filters and + * data filters), and returns data filters that need to be evaluated after scanning. 
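The two new aggregatePushdown flags are opt-in. A quick sketch of enabling them (the path is made up, and whether push-down actually kicks in also depends on the query shape and on the DS V2 file source being used):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{count, max, min}

val spark = SparkSession.builder().master("local[*]").getOrCreate()

spark.conf.set("spark.sql.parquet.aggregatePushdown", "true")
spark.conf.set("spark.sql.orc.aggregatePushdown", "true")

// With no filter and no grouping, MIN/MAX/COUNT over supported types can be answered from
// file-footer statistics instead of scanning the data files.
spark.read.parquet("/tmp/events")
  .agg(min("value"), max("value"), count("value"))
  .explain()
```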
+ */ + def pushFilters(filters: Seq[Expression]): Seq[Expression] + + /** + * Returns the data filters that are pushed to the data source via + * {@link #pushFilters(Expression[])}. + */ + def pushedFilters: Array[Filter] +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala index 80658f7cec2e3..e358ff0cb6677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -18,7 +18,12 @@ package org.apache.spark.sql.sources import org.apache.spark.annotation.{Evolving, Stable} +import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath +import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, Predicate} +import org.apache.spark.sql.types.StringType +import org.apache.spark.unsafe.types.UTF8String //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -64,6 +69,11 @@ sealed abstract class Filter { private[sql] def containsNestedColumn: Boolean = { this.v2references.exists(_.length > 1) } + + /** + * Converts V1 filter to V2 filter + */ + private[sql] def toV2: Predicate } /** @@ -78,6 +88,11 @@ sealed abstract class Filter { @Stable case class EqualTo(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("=", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -93,6 +108,11 @@ case class EqualTo(attribute: String, value: Any) extends Filter { @Stable case class EqualNullSafe(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("<=>", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -107,6 +127,11 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter { @Stable case class GreaterThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate(">", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -121,6 +146,11 @@ case class GreaterThan(attribute: String, value: Any) extends Filter { @Stable case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate(">=", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -135,6 +165,11 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter { @Stable case class LessThan(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = 
Literal(value) + new Predicate("<", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -149,6 +184,11 @@ case class LessThan(attribute: String, value: Any) extends Filter { @Stable case class LessThanOrEqual(attribute: String, value: Any) extends Filter { override def references: Array[String] = Array(attribute) ++ findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("<=", + Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType))) + } } /** @@ -185,6 +225,13 @@ case class In(attribute: String, values: Array[Any]) extends Filter { } override def references: Array[String] = Array(attribute) ++ values.flatMap(findReferences) + override def toV2: Predicate = { + val literals = values.map { value => + val literal = Literal(value) + LiteralValue(literal.value, literal.dataType) + } + new Predicate("IN", FieldReference(attribute) +: literals) + } } /** @@ -198,6 +245,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter { @Stable case class IsNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("IS_NULL", Array(FieldReference(attribute))) } /** @@ -211,6 +259,7 @@ case class IsNull(attribute: String) extends Filter { @Stable case class IsNotNull(attribute: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("IS_NOT_NULL", Array(FieldReference(attribute))) } /** @@ -221,6 +270,7 @@ case class IsNotNull(attribute: String) extends Filter { @Stable case class And(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references + override def toV2: Predicate = new Predicate("AND", Seq(left, right).map(_.toV2).toArray) } /** @@ -231,6 +281,7 @@ case class And(left: Filter, right: Filter) extends Filter { @Stable case class Or(left: Filter, right: Filter) extends Filter { override def references: Array[String] = left.references ++ right.references + override def toV2: Predicate = new Predicate("OR", Seq(left, right).map(_.toV2).toArray) } /** @@ -241,6 +292,7 @@ case class Or(left: Filter, right: Filter) extends Filter { @Stable case class Not(child: Filter) extends Filter { override def references: Array[String] = child.references + override def toV2: Predicate = new Predicate("NOT", Array(child.toV2)) } /** @@ -255,6 +307,8 @@ case class Not(child: Filter) extends Filter { @Stable case class StringStartsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("STARTS_WITH", + Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType))) } /** @@ -269,6 +323,8 @@ case class StringStartsWith(attribute: String, value: String) extends Filter { @Stable case class StringEndsWith(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new Predicate("ENDS_WITH", + Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType))) } /** @@ -283,6 +339,8 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { @Stable case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) + override def toV2: Predicate = new 
Predicate("CONTAINS", + Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType))) } /** @@ -293,6 +351,7 @@ case class StringContains(attribute: String, value: String) extends Filter { @Evolving case class AlwaysTrue() extends Filter { override def references: Array[String] = Array.empty + override def toV2: Predicate = new V2AlwaysTrue() } @Evolving @@ -307,6 +366,7 @@ object AlwaysTrue extends AlwaysTrue { @Evolving case class AlwaysFalse() extends Filter { override def references: Array[String] = Array.empty + override def toV2: Predicate = new V2AlwaysFalse() } @Evolving @@ -316,4 +376,9 @@ object AlwaysFalse extends AlwaysFalse { @Evolving case class Trivial(value: Boolean) extends Filter { override def references: Array[String] = findReferences(value) + override def toV2: Predicate = { + val literal = Literal(value) + new Predicate("TRIVIAL", + Array(LiteralValue(literal.value, literal.dataType))) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala index 1e7f9b0edd91c..c1d13d14b05f7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala @@ -121,6 +121,16 @@ class CollapseProjectSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("SPARK-36718: do not collapse project if non-cheap expressions will be repeated") { + val query = testRelation + .select(('a + 1).as('a_plus_1)) + .select(('a_plus_1 + 'a_plus_1).as('a_2_plus_2)) + .analyze + + val optimized = Optimize.execute(query) + comparePlans(optimized, query) + } + test("preserve top-level alias metadata while collapsing projects") { def hasMetadata(logicalPlan: LogicalPlan): Boolean = { logicalPlan.asInstanceOf[Project].projectList.exists(_.metadata.contains("key")) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala index b1baeccbe94b9..eb3899c9187db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala @@ -57,7 +57,14 @@ class ScanOperationSuite extends SparkFunSuite { test("Project which has the same non-deterministic expression with its child Project") { val project3 = Project(Seq(colA, colR), Project(Seq(colA, aliasR), relation)) - assert(ScanOperation.unapply(project3).isEmpty) + project3 match { + case ScanOperation(projects, filters, _: Project) => + assert(projects.size === 2) + assert(projects(0) === colA) + assert(projects(1) === colR) + assert(filters.isEmpty) + case _ => assert(false) + } } test("Project which has different non-deterministic expressions with its child Project") { @@ -73,13 +80,18 @@ class ScanOperationSuite extends SparkFunSuite { test("Filter with non-deterministic Project") { val filter1 = Filter(EqualTo(colA, Literal(1)), Project(Seq(colA, aliasR), relation)) - assert(ScanOperation.unapply(filter1).isEmpty) + filter1 match { + case ScanOperation(projects, filters, _: Filter) => + assert(projects.size === 2) + assert(filters.isEmpty) + case _ => assert(false) + } } test("Non-deterministic Filter with deterministic Project") { - val filter3 = 
Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)), + val filter2 = Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)), Project(Seq(colA, colB), relation)) - filter3 match { + filter2 match { case ScanOperation(projects, filters, _: LocalRelation) => assert(projects.size === 2) assert(projects(0) === colA) @@ -91,7 +103,11 @@ class ScanOperationSuite extends SparkFunSuite { test("Deterministic filter which has a non-deterministic child Filter") { - val filter4 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation)) - assert(ScanOperation.unapply(filter4).isEmpty) + val filter3 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation)) + filter3 match { + case ScanOperation(projects, filters, _: Filter) => + assert(filters.isEmpty) + case _ => assert(false) + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index 0cca1cc9bebf2..d00bc31e07f19 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -820,7 +820,7 @@ class CatalogSuite extends SparkFunSuite { assert(catalog.namespaceExists(testNs) === false) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === false) } @@ -833,7 +833,7 @@ class CatalogSuite extends SparkFunSuite { assert(catalog.namespaceExists(testNs) === true) assert(catalog.loadNamespaceMetadata(testNs).asScala === Map("property" -> "value")) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === true) assert(catalog.namespaceExists(testNs) === false) @@ -845,7 +845,7 @@ class CatalogSuite extends SparkFunSuite { catalog.createNamespace(testNs, Map("property" -> "value").asJava) catalog.createTable(testIdent, schema, Array.empty, emptyProps) - assert(catalog.dropNamespace(testNs)) + assert(catalog.dropNamespace(testNs, cascade = true)) assert(!catalog.namespaceExists(testNs)) intercept[NoSuchNamespaceException](catalog.listTables(testNs)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala index 2f3c5a38538c8..e0604576a94bc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala @@ -161,7 +161,7 @@ class InMemoryTable( case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } - case BucketTransform(numBuckets, ref) => + case BucketTransform(numBuckets, ref, _) => val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) val valueHashCode = if (value == null) 0 else value.hashCode ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index 0c403baca2113..41063a41b9719 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala 
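The toV2 bridge added to filters.scala above maps each V1 Filter onto a V2 Predicate with the corresponding name ("=", "AND", "STARTS_WITH", ...). Because toV2 is private[sql], the sketch below only compiles inside the org.apache.spark.sql package (for example in Spark's own tests); the attribute names and values are arbitrary:

```scala
package org.apache.spark.sql.sources

object ToV2Sketch {
  def main(args: Array[String]): Unit = {
    val v1 = And(IsNotNull("name"), And(EqualTo("age", 30), StringStartsWith("name", "Sp")))
    // Nested Predicate("AND", ...) wrapping "IS_NOT_NULL", "=" and "STARTS_WITH" children.
    val v2 = v1.toV2
    println(v2)
  }
}
```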
@@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ -import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NonEmptyNamespaceException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.connector.distributions.{Distribution, Distributions} import org.apache.spark.sql.connector.expressions.{SortOrder, Transform} import org.apache.spark.sql.types.StructType @@ -193,10 +193,16 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp namespaces.put(namespace.toList, CatalogV2Util.applyNamespaceChanges(metadata, changes)) } - override def dropNamespace(namespace: Array[String]): Boolean = { - listNamespaces(namespace).foreach(dropNamespace) + override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = { try { - listTables(namespace).foreach(dropTable) + if (!cascade) { + if (listTables(namespace).nonEmpty || listNamespaces(namespace).nonEmpty) { + throw new NonEmptyNamespaceException(namespace) + } + } else { + listNamespaces(namespace).foreach(namespace => dropNamespace(namespace, cascade)) + listTables(namespace).foreach(dropTable) + } } catch { case _: NoSuchNamespaceException => } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala index fbd6a886d011b..4a50e063bee68 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala @@ -28,7 +28,7 @@ class TransformExtractorSuite extends SparkFunSuite { private def lit[T](literal: T): Literal[T] = new Literal[T] { override def value: T = literal override def dataType: DataType = catalyst.expressions.Literal(literal).dataType - override def describe: String = literal.toString + override def toString: String = literal.toString } /** @@ -36,7 +36,7 @@ class TransformExtractorSuite extends SparkFunSuite { */ private def ref(names: String*): NamedReference = new NamedReference { override def fieldNames: Array[String] = names.toArray - override def describe: String = names.mkString(".") + override def toString: String = names.mkString(".") } /** @@ -44,9 +44,8 @@ class TransformExtractorSuite extends SparkFunSuite { */ private def transform(func: String, ref: NamedReference): Transform = new Transform { override def name: String = func - override def references: Array[NamedReference] = Array(ref) override def arguments: Array[Expression] = Array(ref) - override def describe: String = ref.describe + override def toString: String = ref.describe } test("Identity extractor") { @@ -135,11 +134,11 @@ class TransformExtractorSuite extends SparkFunSuite { override def name: String = "bucket" override def references: Array[NamedReference] = Array(col) override def arguments: Array[Expression] = Array(lit(16), col) - override def describe: String = s"bucket(16, ${col.describe})" + override def toString: String = s"bucket(16, ${col.describe})" } bucketTransform match { - case BucketTransform(numBuckets, FieldReference(seq)) => + case BucketTransform(numBuckets, FieldReference(seq), _) => assert(numBuckets === 
16) assert(seq === Seq("a", "b")) case _ => @@ -147,7 +146,7 @@ class TransformExtractorSuite extends SparkFunSuite { } transform("unknown", ref("a", "b")) match { - case BucketTransform(_, _) => + case BucketTransform(_, _, _) => fail("Matched unknown transform") case _ => // expected diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 85bb234cf9a97..998de75018d4e 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml @@ -136,7 +136,7 @@ com.h2database h2 - 1.4.195 + 2.0.204 test diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java new file mode 100644 index 0000000000000..8adb9e8ca20be --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; + +import java.util.ArrayList; +import java.util.List; + +/** + * Column statistics interface wrapping ORC {@link ColumnStatistics}s. + * + * Because ORC {@link ColumnStatistics}s are stored as a flattened array in the ORC file footer, + * this class is used to convert ORC {@link ColumnStatistics}s from that array into a nested tree structure, + * according to data types. The flattened array stores all data types (including nested types) in + * tree pre-ordering. This is used for aggregate push down in ORC. + * + * For nested data types (array, map and struct), the sub-field statistics are stored recursively + * inside the parent column's children field.
Here is an example of {@link OrcColumnStatistics}: + * + * Data schema: + * c1: int + * c2: struct + * c3: map + * c4: array + * + * OrcColumnStatistics + * | (children) + * --------------------------------------------- + * / | \ \ + * c1 c2 c3 c4 + * (integer) (struct) (map) (array) +* (min:1, | (children) | (children) | (children) + * max:10) ----- ----- element + * / \ / \ (integer) + * c2.f1 c2.f2 key value + * (integer) (float) (integer) (string) + * (min:0.1, (min:"a", + * max:100.5) max:"zzz") + */ +public class OrcColumnStatistics { + private final ColumnStatistics statistics; + private final List children; + + public OrcColumnStatistics(ColumnStatistics statistics) { + this.statistics = statistics; + this.children = new ArrayList<>(); + } + + public ColumnStatistics getStatistics() { + return statistics; + } + + public OrcColumnStatistics get(int ordinal) { + if (ordinal < 0 || ordinal >= children.size()) { + throw new IndexOutOfBoundsException( + String.format("Ordinal %d out of bounds of statistics size %d", ordinal, children.size())); + } + return children.get(ordinal); + } + + public void add(OrcColumnStatistics newChild) { + children.add(newChild); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java new file mode 100644 index 0000000000000..546b048648844 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc; + +import org.apache.orc.ColumnStatistics; +import org.apache.orc.Reader; +import org.apache.orc.TypeDescription; +import org.apache.spark.sql.types.*; + +import java.util.Arrays; +import java.util.LinkedList; +import java.util.Queue; + +/** + * {@link OrcFooterReader} is a util class which encapsulates the helper + * methods of reading ORC file footer. + */ +public class OrcFooterReader { + + /** + * Read the columns statistics from ORC file footer. + * + * @param orcReader the reader to read ORC file footer. + * @return Statistics for all columns in the file. + */ + public static OrcColumnStatistics readStatistics(Reader orcReader) { + TypeDescription orcSchema = orcReader.getSchema(); + ColumnStatistics[] orcStatistics = orcReader.getStatistics(); + StructType sparkSchema = OrcUtils.toCatalystSchema(orcSchema); + return convertStatistics(sparkSchema, new LinkedList<>(Arrays.asList(orcStatistics))); + } + + /** + * Convert a queue of ORC {@link ColumnStatistics}s into Spark {@link OrcColumnStatistics}. + * The queue of ORC {@link ColumnStatistics}s are assumed to be ordered as tree pre-order. 
+ */ + private static OrcColumnStatistics convertStatistics( + DataType sparkSchema, Queue orcStatistics) { + OrcColumnStatistics statistics = new OrcColumnStatistics(orcStatistics.remove()); + if (sparkSchema instanceof StructType) { + for (StructField field : ((StructType) sparkSchema).fields()) { + statistics.add(convertStatistics(field.dataType(), orcStatistics)); + } + } else if (sparkSchema instanceof MapType) { + statistics.add(convertStatistics(((MapType) sparkSchema).keyType(), orcStatistics)); + statistics.add(convertStatistics(((MapType) sparkSchema).valueType(), orcStatistics)); + } else if (sparkSchema instanceof ArrayType) { + statistics.add(convertStatistics(((ArrayType) sparkSchema).elementType(), orcStatistics)); + } + return statistics; + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala new file mode 100644 index 0000000000000..b9847d48b2e17 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket} +import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} +import org.apache.spark.sql.execution.datasources.PushableColumn +import org.apache.spark.sql.types.BooleanType + +/** + * The builder to generate V2 expressions from catalyst expressions. 
+ */ +class V2ExpressionBuilder( + e: Expression, nestedPredicatePushdownEnabled: Boolean = false, isPredicate: Boolean = false) { + + val pushableColumn = PushableColumn(nestedPredicatePushdownEnabled) + + def build(): Option[V2Expression] = generateExpression(e, isPredicate) + + private def canTranslate(b: BinaryOperator) = b match { + case _: And | _: Or => true + case _: BinaryComparison => true + case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true + case add: Add => add.failOnError + case sub: Subtract => sub.failOnError + case mul: Multiply => mul.failOnError + case div: Divide => div.failOnError + case r: Remainder => r.failOnError + case _ => false + } + + private def generateExpression( + expr: Expression, isPredicate: Boolean = false): Option[V2Expression] = expr match { + case Literal(true, BooleanType) => Some(new AlwaysTrue()) + case Literal(false, BooleanType) => Some(new AlwaysFalse()) + case Literal(value, dataType) => Some(LiteralValue(value, dataType)) + case col @ pushableColumn(name) if nestedPredicatePushdownEnabled => + if (isPredicate && col.dataType.isInstanceOf[BooleanType]) { + Some(new V2Predicate("=", Array(FieldReference(name), LiteralValue(true, BooleanType)))) + } else { + Some(FieldReference(name)) + } + case pushableColumn(name) if !nestedPredicatePushdownEnabled => + Some(FieldReference(name)) + case in @ InSet(child, hset) => + generateExpression(child).map { v => + val children = + (v +: hset.toSeq.map(elem => LiteralValue(elem, in.dataType))).toArray[V2Expression] + new V2Predicate("IN", children) + } + // Because we only convert In to InSet in Optimizer when there are more than certain + // items. So it is possible we still get an In expression here that needs to be pushed + // down. + case In(value, list) => + val v = generateExpression(value) + val listExpressions = list.flatMap(generateExpression(_)) + if (v.isDefined && list.length == listExpressions.length) { + val children = (v.get +: listExpressions).toArray[V2Expression] + // The children looks like [expr, value1, ..., valueN] + Some(new V2Predicate("IN", children)) + } else { + None + } + case IsNull(col) => generateExpression(col) + .map(c => new V2Predicate("IS_NULL", Array[V2Expression](c))) + case IsNotNull(col) => generateExpression(col) + .map(c => new V2Predicate("IS_NOT_NULL", Array[V2Expression](c))) + case p: StringPredicate => + val left = generateExpression(p.left) + val right = generateExpression(p.right) + if (left.isDefined && right.isDefined) { + val name = p match { + case _: StartsWith => "STARTS_WITH" + case _: EndsWith => "ENDS_WITH" + case _: Contains => "CONTAINS" + } + Some(new V2Predicate(name, Array[V2Expression](left.get, right.get))) + } else { + None + } + case Cast(child, dataType, _, true) => + generateExpression(child).map(v => new V2Cast(v, dataType)) + case Abs(child, true) => generateExpression(child) + .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v))) + case Coalesce(children) => + val childrenExpressions = children.flatMap(generateExpression(_)) + if (children.length == childrenExpressions.length) { + Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression])) + } else { + None + } + case Log(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v))) + case Exp(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v))) + case Pow(left, right) => + val l = generateExpression(left) + val r = 
generateExpression(right) + if (l.isDefined && r.isDefined) { + Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get))) + } else { + None + } + case Sqrt(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v))) + case Floor(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v))) + case Ceil(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v))) + case wb: WidthBucket => + val childrenExpressions = wb.children.flatMap(generateExpression(_)) + if (childrenExpressions.length == wb.children.length) { + Some(new GeneralScalarExpression("WIDTH_BUCKET", + childrenExpressions.toArray[V2Expression])) + } else { + None + } + case and: And => + // AND expects predicate + val l = generateExpression(and.left, true) + val r = generateExpression(and.right, true) + if (l.isDefined && r.isDefined) { + assert(l.get.isInstanceOf[V2Predicate] && r.get.isInstanceOf[V2Predicate]) + Some(new V2And(l.get.asInstanceOf[V2Predicate], r.get.asInstanceOf[V2Predicate])) + } else { + None + } + case or: Or => + // OR expects predicate + val l = generateExpression(or.left, true) + val r = generateExpression(or.right, true) + if (l.isDefined && r.isDefined) { + assert(l.get.isInstanceOf[V2Predicate] && r.get.isInstanceOf[V2Predicate]) + Some(new V2Or(l.get.asInstanceOf[V2Predicate], r.get.asInstanceOf[V2Predicate])) + } else { + None + } + case b: BinaryOperator if canTranslate(b) => + val l = generateExpression(b.left) + val r = generateExpression(b.right) + if (l.isDefined && r.isDefined) { + b match { + case _: Predicate => + Some(new V2Predicate(b.sqlOperator, Array[V2Expression](l.get, r.get))) + case _ => + Some(new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](l.get, r.get))) + } + } else { + None + } + case Not(eq: EqualTo) => + val left = generateExpression(eq.left) + val right = generateExpression(eq.right) + if (left.isDefined && right.isDefined) { + Some(new V2Predicate("<>", Array[V2Expression](left.get, right.get))) + } else { + None + } + case Not(child) => generateExpression(child, true) // NOT expects predicate + .map { v => + assert(v.isInstanceOf[V2Predicate]) + new V2Not(v.asInstanceOf[V2Predicate]) + } + case UnaryMinus(child, true) => generateExpression(child) + .map(v => new GeneralScalarExpression("-", Array[V2Expression](v))) + case BitwiseNot(child) => generateExpression(child) + .map(v => new GeneralScalarExpression("~", Array[V2Expression](v))) + case CaseWhen(branches, elseValue) => + val conditions = branches.map(_._1).flatMap(generateExpression(_, true)) + val values = branches.map(_._2).flatMap(generateExpression(_, true)) + if (conditions.length == branches.length && values.length == branches.length) { + val branchExpressions = conditions.zip(values).flatMap { case (c, v) => + Seq[V2Expression](c, v) + } + if (elseValue.isDefined) { + elseValue.flatMap(generateExpression(_)).map { v => + val children = (branchExpressions :+ v).toArray[V2Expression] + // The children looks like [condition1, value1, ..., conditionN, valueN, elseValue] + new V2Predicate("CASE_WHEN", children) + } + } else { + // The children looks like [condition1, value1, ..., conditionN, valueN] + Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression])) + } + } else { + None + } + // TODO supports other expressions + case _ => None + } +} diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala index efc459c8241fa..432775c9045ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource} +import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, Filter} @@ -103,7 +103,7 @@ case class RowDataSourceScanExec( requiredSchema: StructType, filters: Set[Filter], handledFilters: Set[Filter], - aggregation: Option[Aggregation], + pushedDownOperators: PushedDownOperators, rdd: RDD[InternalRow], @transient relation: BaseRelation, tableIdentifier: Option[TableIdentifier]) @@ -134,13 +134,6 @@ case class RowDataSourceScanExec( def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]") - val (aggString, groupByString) = if (aggregation.nonEmpty) { - (seqToString(aggregation.get.aggregateExpressions), - seqToString(aggregation.get.groupByColumns)) - } else { - ("[]", "[]") - } - val markedFilters = if (filters.nonEmpty) { for (filter <- filters) yield { if (handledFilters.contains(filter)) s"*$filter" else s"$filter" @@ -149,11 +142,31 @@ case class RowDataSourceScanExec( handledFilters } - Map( - "ReadSchema" -> requiredSchema.catalogString, - "PushedFilters" -> seqToString(markedFilters.toSeq), - "PushedAggregates" -> aggString, - "PushedGroupby" -> groupByString) + val topNOrLimitInfo = + if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) { + val pushedTopN = + s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" + + s" LIMIT ${pushedDownOperators.limit.get}" + Some("PushedTopN" -> pushedTopN) + } else { + pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value") + } + + val pushedFilters = if (pushedDownOperators.pushedPredicates.nonEmpty) { + seqToString(pushedDownOperators.pushedPredicates.map(_.describe())) + } else { + seqToString(markedFilters.toSeq) + } + + Map("ReadSchema" -> requiredSchema.catalogString, + "PushedFilters" -> pushedFilters) ++ + pushedDownOperators.aggregation.fold(Map[String, String]()) { v => + Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())), + "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++ + topNOrLimitInfo ++ + pushedDownOperators.sample.map(v => "PushedSample" -> + s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})" + ) } // Don't care about `rdd` and `tableIdentifier` when canonicalizing. 
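With the metadata rewrite above, a scan node now surfaces every pushed-down operator in EXPLAIN output: pushed predicates, aggregates and group-by columns, a pushed LIMIT (or ORDER BY ... LIMIT reported as PushedTopN), and a pushed TABLESAMPLE. A hedged illustration of how this might look for a Top-N query against a V2 JDBC source; the catalog name, table, and exact metadata text below are assumptions for the example, not output reproduced from this patch:

    // Sketch only: assumes an H2 catalog registered in the session config as
    //   spark.sql.catalog.h2 = org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
    //   spark.sql.catalog.h2.url = jdbc:h2:mem:testdb
    //   spark.sql.catalog.h2.pushDownLimit = true
    // and a table h2.test.employee -- all of these names are illustrative.
    val df = spark.sql("SELECT name, salary FROM h2.test.employee ORDER BY salary LIMIT 3")
    df.explain()
    // The scan metadata is expected to read along the lines of:
    //   PushedFilters: [], ReadSchema: ...,
    //   PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 3
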
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala new file mode 100644 index 0000000000000..6d8cae544f23e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils +import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} +import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType} +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} + +/** + * Utility class for aggregate push down to Parquet and ORC. + */ +object AggregatePushDownUtils { + + /** + * Get the data schema for aggregate to be pushed down. + */ + def getSchemaForPushedAggregation( + aggregation: Aggregation, + schema: StructType, + partitionNames: Set[String], + dataFilters: Seq[Expression]): Option[StructType] = { + + var finalSchema = new StructType() + + def getStructFieldForCol(colName: String): StructField = { + schema.apply(colName) + } + + def isPartitionCol(colName: String) = { + partitionNames.contains(colName) + } + + def processMinOrMax(agg: AggregateFunc): Boolean = { + val (columnName, aggType) = agg match { + case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined => + (V2ColumnUtils.extractV2Column(max.column).get, "max") + case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined => + (V2ColumnUtils.extractV2Column(min.column).get, "min") + case _ => return false + } + + if (isPartitionCol(columnName)) { + // don't push down partition column, footer doesn't have max/min for partition column + return false + } + val structField = getStructFieldForCol(columnName) + + structField.dataType match { + // not push down complex type + // not push down Timestamp because INT96 sort order is undefined, + // Parquet doesn't return statistics for INT96 + // not push down Parquet Binary because min/max could be truncated + // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary + // could be Spark StringType, BinaryType or DecimalType. + // not push down for ORC with same reason. 
+ case BooleanType | ByteType | ShortType | IntegerType + | LongType | FloatType | DoubleType | DateType => + finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")")) + true + case _ => + false + } + } + + if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) { + // Parquet/ORC footer has max/min/count for columns + // e.g. SELECT COUNT(col1) FROM t + // but footer doesn't have max/min/count for a column if max/min/count + // are combined with filter or group by + // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8 + // SELECT COUNT(col1) FROM t GROUP BY col2 + // However, if the filter is on partition column, max/min/count can still be pushed down + // Todo: add support if groupby column is partition col + // (https://issues.apache.org/jira/browse/SPARK-36646) + return None + } + aggregation.groupByColumns.foreach { col => + // don't push down if the group by columns are not the same as the partition columns (order + // doesn't matter because reordering can be done at the data source layer) + if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None + finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head)) + } + + aggregation.aggregateExpressions.foreach { + case max: Max => + if (!processMinOrMax(max)) return None + case min: Min => + if (!processMinOrMax(min)) return None + case count: Count + if V2ColumnUtils.extractV2Column(count.column).isDefined && !count.isDistinct => + val columnName = V2ColumnUtils.extractV2Column(count.column).get + finalSchema = finalSchema.add(StructField(s"count($columnName)", LongType)) + case _: CountStar => + finalSchema = finalSchema.add(StructField("count(*)", LongType)) + case _ => + return None + } + + Some(finalSchema) + } + + /** + * Check whether two Aggregations `a` and `b` are equal. + */ + def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = { + a.aggregateExpressions.sortBy(_.hashCode()) + .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) && + a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode())) + } + + /** + * Convert the aggregates result from `InternalRow` to `ColumnarBatch`. + * This is used for columnar reader.
+ */ + def convertAggregatesRowToBatch( + aggregatesAsRow: InternalRow, + aggregatesSchema: StructType, + offHeap: Boolean): ColumnarBatch = { + val converter = new RowToColumnConverter(aggregatesSchema) + val columnVectors = if (offHeap) { + OffHeapColumnVector.allocateColumns(1, aggregatesSchema) + } else { + OnHeapColumnVector.allocateColumns(1, aggregatesSchema) + } + converter.convert(aggregatesAsRow, columnVectors.toArray) + new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index a53665fe2f0e4..408da524cbb04 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -38,13 +38,15 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 +import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.FieldReference -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators import org.apache.spark.sql.execution.streaming.StreamingRelation import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.sources._ @@ -335,7 +337,7 @@ object DataSourceStrategy l.output.toStructType, Set.empty, Set.empty, - None, + PushedDownOperators(None, None, None, Seq.empty, Seq.empty), toCatalystRDD(l, baseRelation.buildScan()), baseRelation, None) :: Nil @@ -409,7 +411,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - None, + PushedDownOperators(None, None, None, Seq.empty, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -432,7 +434,7 @@ object DataSourceStrategy requestedColumns.toStructType, pushedFilters.toSet, handledFilters, - None, + PushedDownOperators(None, None, None, Seq.empty, Seq.empty), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, relation.catalogTable.map(_.identifier)) @@ -698,23 +700,44 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } - protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = { - if (aggregates.filter.isEmpty) { - aggregates.aggregateFunction match { - case 
aggregate.Min(PushableColumnWithoutNestedColumn(name)) => - Some(new Min(FieldReference(name))) - case aggregate.Max(PushableColumnWithoutNestedColumn(name)) => - Some(new Max(FieldReference(name))) + protected[sql] def translateAggregate(agg: AggregateExpression): Option[AggregateFunc] = { + if (agg.filter.isEmpty) { + agg.aggregateFunction match { + case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) + case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) case count: aggregate.Count if count.children.length == 1 => count.children.head match { - // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table + // COUNT(any literal) is the same as COUNT(*) case Literal(_, _) => Some(new CountStar()) - case PushableColumnWithoutNestedColumn(name) => - Some(new Count(FieldReference(name), aggregates.isDistinct)) + case PushableExpression(expr) => Some(new Count(expr, agg.isDistinct)) case _ => None } - case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) => - Some(new Sum(FieldReference(name), aggregates.isDistinct)) + case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, agg.isDistinct)) + case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, agg.isDistinct)) + case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc( + "VAR_POP", agg.isDistinct, Array(FieldReference(name)))) + case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc( + "VAR_SAMP", agg.isDistinct, Array(FieldReference(name)))) + case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc( + "STDDEV_POP", agg.isDistinct, Array(FieldReference(name)))) + case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) => + Some(new GeneralAggregateFunc( + "STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name)))) + case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left), + PushableColumnWithoutNestedColumn(right), _) => + Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct, + Array(FieldReference(left), FieldReference(right)))) + case aggregate.CovSample(PushableColumnWithoutNestedColumn(left), + PushableColumnWithoutNestedColumn(right), _) => + Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct, + Array(FieldReference(left), FieldReference(right)))) + case aggregate.Corr(PushableColumnWithoutNestedColumn(left), + PushableColumnWithoutNestedColumn(right), _) => + Some(new GeneralAggregateFunc("CORR", agg.isDistinct, + Array(FieldReference(left), FieldReference(right)))) case _ => None } } else { @@ -722,6 +745,49 @@ object DataSourceStrategy } } + /** + * Translate aggregate expressions and group by expressions. + * + * @return translated aggregation. 
+ */ + protected[sql] def translateAggregation( + aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { + + def columnAsString(e: Expression): Option[FieldReference] = e match { + case PushableColumnWithoutNestedColumn(name) => + Some(FieldReference(name).asInstanceOf[FieldReference]) + case _ => None + } + + val translatedAggregates = aggregates.flatMap(translateAggregate) + val translatedGroupBys = groupBy.flatMap(columnAsString) + + if (translatedAggregates.length != aggregates.length || + translatedGroupBys.length != groupBy.length) { + return None + } + + Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)) + } + + protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = { + def translateOortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match { + case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) => + val directionV2 = directionV1 match { + case Ascending => SortDirection.ASCENDING + case Descending => SortDirection.DESCENDING + } + val nullOrderingV2 = nullOrderingV1 match { + case NullsFirst => NullOrdering.NULLS_FIRST + case NullsLast => NullOrdering.NULLS_LAST + } + Some(SortValue(FieldReference(name), directionV2, nullOrderingV2)) + case _ => None + } + + sortOrders.flatMap(translateOortOrder) + } + /** * Convert RDD of Row into RDD of InternalRow with objects in catalyst types */ @@ -787,3 +853,10 @@ object PushableColumnAndNestedColumn extends PushableColumnBase { object PushableColumnWithoutNestedColumn extends PushableColumnBase { override val nestedPredicatePushdownEnabled = false } + +/** + * Get the expression of DS V2 to represent catalyst expression that can be pushed down. + */ +object PushableExpression { + def unapply(e: Expression): Option[V2Expression] = new V2ExpressionBuilder(e).build() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala index fcd95a27bf8ca..67d03998a2a24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala @@ -28,6 +28,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.SparkUpgradeException import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_LEGACY_INT96, SPARK_VERSION_METADATA_KEY} import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper} import org.apache.spark.sql.catalyst.util.RebaseDateTime import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions @@ -39,7 +40,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils -object DataSourceUtils { +object DataSourceUtils extends PredicateHelper { /** * The key to use for storing partitionBy columns as options. 
*/ @@ -242,4 +243,22 @@ object DataSourceUtils { options } } + + def getPartitionFiltersAndDataFilters( + partitionSchema: StructType, + normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = { + val partitionColumns = normalizedFilters.flatMap { expr => + expr.collect { + case attr: AttributeReference if partitionSchema.names.contains(attr.name) => + attr + } + } + val partitionSet = AttributeSet(partitionColumns) + val (partitionFilters, dataFilters) = normalizedFilters.partition(f => + f.references.subsetOf(partitionSet) + ) + val extraPartitionFilter = + dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) + (ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala index 0927027bee0bc..2e8e5426d47be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala @@ -17,52 +17,24 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogStatistics import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan} -import org.apache.spark.sql.types.StructType /** * Prune the partitions of file source based table using partition filters. Currently, this rule - * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]] - * with [[FileScan]]. + * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]]. * * For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding * statistics will be updated. And the partition filters will be kept in the filters of returned * logical plan. - * - * For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to - * its underlying [[FileScan]]. And the partition filters will be removed in the filters of - * returned logical plan. 
*/ private[sql] object PruneFileSourcePartitions extends Rule[LogicalPlan] with PredicateHelper { - private def getPartitionKeyFiltersAndDataFilters( - sparkSession: SparkSession, - relation: LeafNode, - partitionSchema: StructType, - filters: Seq[Expression], - output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = { - val normalizedFilters = DataSourceStrategy.normalizeExprs( - filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output) - val partitionColumns = - relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver) - val partitionSet = AttributeSet(partitionColumns) - val (partitionFilters, dataFilters) = normalizedFilters.partition(f => - f.references.subsetOf(partitionSet) - ) - val extraPartitionFilter = - dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet)) - - (ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters) - } - private def rebuildPhysicalOperation( projects: Seq[NamedExpression], filters: Seq[Expression], @@ -91,12 +63,14 @@ private[sql] object PruneFileSourcePartitions _, _)) if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined => - val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters( - fsRelation.sparkSession, logicalRelation, partitionSchema, filters, + val normalizedFilters = DataSourceStrategy.normalizeExprs( + filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), logicalRelation.output) + val (partitionKeyFilters, _) = DataSourceUtils + .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters) if (partitionKeyFilters.nonEmpty) { - val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq) + val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters) val prunedFsRelation = fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession) // Change table stats based on the sizeInBytes of pruned files @@ -117,23 +91,5 @@ private[sql] object PruneFileSourcePartitions } else { op } - - case op @ PhysicalOperation(projects, filters, - v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output)) - if filters.nonEmpty => - val (partitionKeyFilters, dataFilters) = - getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation, - scan.readPartitionSchema, filters, output) - // The dataFilters are pushed down only once - if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) { - val prunedV2Relation = - v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters)) - // The pushed down partition filters don't need to be reevaluated. 
- val afterScanFilters = - ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty) - rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation) - } else { - op - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 8b2ae2beb6d4a..8e047d7f7c7d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -191,6 +191,14 @@ class JDBCOptions( // An option to allow/disallow pushing down aggregate into JDBC data source val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean + // An option to allow/disallow pushing down LIMIT into V2 JDBC data source + // This only applies to Data Source V2 JDBC + val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean + + // An option to allow/disallow pushing down TABLESAMPLE into JDBC data source + // This only applies to Data Source V2 JDBC + val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean + // The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either // by --files option of spark-submit or manually val keytab = { @@ -263,6 +271,8 @@ object JDBCOptions { val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement") val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate") val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate") + val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit") + val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample") val JDBC_KEYTAB = newOption("keytab") val JDBC_PRINCIPAL = newOption("principal") val JDBC_TABLE_COMMENT = newOption("tableComment") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index e024e4bb02102..b30b460ac67db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -25,9 +25,10 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects} -import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.CompletionIterator @@ -59,7 +60,7 @@ object JDBCRDD extends Logging { def getQueryOutputSchema( query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = { - val conn: Connection = JdbcUtils.createConnectionFactory(options)() + val conn: Connection = dialect.createConnectionFactory(options)(-1) try { val statement = conn.prepareStatement(query) try { @@ -91,106 +92,38 @@ object JDBCRDD extends Logging { new StructType(columns.map(name => fieldMap(name))) } - /** - * Turns a single Filter into a String 
representing a SQL expression. - * Returns None for an unhandled filter. - */ - def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = { - def quote(colName: String): String = dialect.quoteIdentifier(colName) - - Option(f match { - case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}" - case EqualNullSafe(attr, value) => - val col = quote(attr) - s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " + - s"${dialect.compileValue(value)} IS NULL) OR " + - s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))" - case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}" - case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}" - case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}" - case IsNull(attr) => s"${quote(attr)} IS NULL" - case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL" - case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'" - case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'" - case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'" - case In(attr, value) if value.isEmpty => - s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END" - case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})" - case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null) - case Or(f1, f2) => - // We can't compile Or filter unless both sub-filters are compiled successfully. - // It applies too for the following And filter. - // If we can make sure compileFilter supports all filters, we can remove this check. - val or = Seq(f1, f2).flatMap(compileFilter(_, dialect)) - if (or.size == 2) { - or.map(p => s"($p)").mkString(" OR ") - } else { - null - } - case And(f1, f2) => - val and = Seq(f1, f2).flatMap(compileFilter(_, dialect)) - if (and.size == 2) { - and.map(p => s"($p)").mkString(" AND ") - } else { - null - } - case _ => null - }) - } - - def compileAggregates( - aggregates: Seq[AggregateFunc], - dialect: JdbcDialect): Option[Seq[String]] = { - def quote(colName: String): String = dialect.quoteIdentifier(colName) - - Some(aggregates.map { - case min: Min => - if (min.column.fieldNames.length != 1) return None - s"MIN(${quote(min.column.fieldNames.head)})" - case max: Max => - if (max.column.fieldNames.length != 1) return None - s"MAX(${quote(max.column.fieldNames.head)})" - case count: Count => - if (count.column.fieldNames.length != 1) return None - val distinct = if (count.isDistinct) "DISTINCT " else "" - val column = quote(count.column.fieldNames.head) - s"COUNT($distinct$column)" - case sum: Sum => - if (sum.column.fieldNames.length != 1) return None - val distinct = if (sum.isDistinct) "DISTINCT " else "" - val column = quote(sum.column.fieldNames.head) - s"SUM($distinct$column)" - case _: CountStar => - s"COUNT(*)" - case _ => return None - }) - } - /** * Build and return JDBCRDD from the given information. * * @param sc - Your SparkContext. * @param schema - The Catalyst schema of the underlying database table. * @param requiredColumns - The names of the columns or aggregate columns to SELECT. - * @param filters - The filters to include in all WHERE clauses. + * @param predicates - The predicates to include in all WHERE clauses. * @param parts - An array of JDBCPartitions specifying partition ids and * per-partition WHERE clauses. 
* @param options - JDBC options that contains url, table and other information. * @param outputSchema - The schema of the columns or aggregate columns to SELECT. * @param groupByColumns - The pushed down group by columns. + * @param sample - The pushed down tableSample. + * @param limit - The pushed down limit. If the value is 0, it means no limit or limit + * is not pushed down. + * @param sortOrders - The sort orders cooperates with limit to realize top N. * * @return An RDD representing "SELECT requiredColumns FROM fqTable". */ + // scalastyle:off argcount def scanTable( sc: SparkContext, schema: StructType, requiredColumns: Array[String], - filters: Array[Filter], + predicates: Array[Predicate], parts: Array[Partition], options: JDBCOptions, outputSchema: Option[StructType] = None, - groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = { + groupByColumns: Option[Array[String]] = None, + sample: Option[TableSampleInfo] = None, + limit: Int = 0, + sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = { val url = options.url val dialect = JdbcDialects.get(url) val quotedColumns = if (groupByColumns.isEmpty) { @@ -201,15 +134,19 @@ object JDBCRDD extends Logging { } new JDBCRDD( sc, - JdbcUtils.createConnectionFactory(options), + dialect.createConnectionFactory(options), outputSchema.getOrElse(pruneSchema(schema, requiredColumns)), quotedColumns, - filters, + predicates, parts, url, options, - groupByColumns) + groupByColumns, + sample, + limit, + sortOrders) } + // scalastyle:on argcount } /** @@ -219,14 +156,17 @@ object JDBCRDD extends Logging { */ private[jdbc] class JDBCRDD( sc: SparkContext, - getConnection: () => Connection, + getConnection: Int => Connection, schema: StructType, columns: Array[String], - filters: Array[Filter], + predicates: Array[Predicate], partitions: Array[Partition], url: String, options: JDBCOptions, - groupByColumns: Option[Array[String]]) + groupByColumns: Option[Array[String]], + sample: Option[TableSampleInfo], + limit: Int, + sortOrders: Array[SortOrder]) extends RDD[InternalRow](sc, Nil) { /** @@ -242,10 +182,10 @@ private[jdbc] class JDBCRDD( /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = - filters - .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url))) - .map(p => s"($p)").mkString(" AND ") + private val filterWhereClause: String = { + val dialect = JdbcDialects.get(url) + predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ") + } /** * A WHERE clause representing both `filters`, if any, and the current partition. @@ -274,6 +214,14 @@ private[jdbc] class JDBCRDD( } } + private def getOrderByClause: String = { + if (sortOrders.nonEmpty) { + s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}" + } else { + "" + } + } + /** * Runs the SQL query against the JDBC driver. 
* @@ -322,7 +270,7 @@ private[jdbc] class JDBCRDD( val inputMetrics = context.taskMetrics().inputMetrics val part = thePart.asInstanceOf[JDBCPartition] - conn = getConnection() + conn = getConnection(part.idx) val dialect = JdbcDialects.get(url) import scala.collection.JavaConverters._ dialect.beforeFetch(conn, options.asProperties.asScala.toMap) @@ -349,8 +297,16 @@ private[jdbc] class JDBCRDD( val myWhereClause = getWhereClause(part) - val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" + - s" $getGroupByClause" + val myTableSampleClause: String = if (sample.nonEmpty) { + JdbcDialects.get(url).getTableSample(sample.get) + } else { + "" + } + + val myLimitClause: String = dialect.getLimitClause(limit) + + val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" + + s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause" stmt = conn.prepareStatement(sqlText, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) stmt.setFetchSize(options.fetchSize) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 8098fa0b83a95..0f1a1b6dc667b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -27,7 +27,10 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} +import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ @@ -268,10 +271,11 @@ private[sql] case class JDBCRelation( override val needConversion: Boolean = false - // Check if JDBCRDD.compileFilter can accept input filters + // Check if JdbcDialect can compile input filters override def unhandledFilters(filters: Array[Filter]): Array[Filter] = { if (jdbcOptions.pushDownPredicate) { - filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty) + val dialect = JdbcDialects.get(jdbcOptions.url) + filters.filter(f => dialect.compileExpression(f.toV2).isEmpty) } else { filters } @@ -279,17 +283,17 @@ private[sql] case class JDBCRelation( override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { // When pushDownPredicate is false, all Filters that need to be pushed down should be ignored - val pushedFilters = if (jdbcOptions.pushDownPredicate) { - filters + val pushedPredicates = if (jdbcOptions.pushDownPredicate) { + filters.map(_.toV2) } else { - Array.empty[Filter] + Array.empty[Predicate] } // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, schema, requiredColumns, - pushedFilters, + pushedPredicates, parts, jdbcOptions).asInstanceOf[RDD[Row]] } @@ -297,18 +301,24 @@ private[sql] case class JDBCRelation( def buildScan( requiredColumns: Array[String], finalSchema: 
StructType, - filters: Array[Filter], - groupByColumns: Option[Array[String]]): RDD[Row] = { + predicates: Array[Predicate], + groupByColumns: Option[Array[String]], + tableSample: Option[TableSampleInfo], + limit: Int, + sortOrders: Array[SortOrder]): RDD[Row] = { // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sparkSession.sparkContext, schema, requiredColumns, - filters, + predicates, parts, jdbcOptions, Some(finalSchema), - groupByColumns).asInstanceOf[RDD[Row]] + groupByColumns, + tableSample, + limit, + sortOrders).asInstanceOf[RDD[Row]] } override def insert(data: DataFrame, overwrite: Boolean): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala index d953ba45cc2fb..2760c7ac3019c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._ +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider} class JdbcRelationProvider extends CreatableRelationProvider @@ -45,8 +46,8 @@ class JdbcRelationProvider extends CreatableRelationProvider df: DataFrame): BaseRelation = { val options = new JdbcOptionsInWrite(parameters) val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis - - val conn = JdbcUtils.createConnectionFactory(options)() + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) try { val tableExists = JdbcUtils.tableExists(conn, options) if (tableExists) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 60fcaf94e1986..2d0cbcff8ecc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Driver, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} +import java.sql.{Connection, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException} import java.time.{Instant, LocalDate} +import java.util import java.util.Locale import java.util.concurrent.TimeUnit +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try import scala.util.control.NonFatal @@ -37,8 +40,9 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp} import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex} +import org.apache.spark.sql.connector.expressions.NamedReference 
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} -import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} import org.apache.spark.sql.types._ @@ -50,23 +54,6 @@ import org.apache.spark.util.NextIterator * Util functions for JDBC tables. */ object JdbcUtils extends Logging { - /** - * Returns a factory for creating connections to the given JDBC URL. - * - * @param options - JDBC options that contains url, table and other information. - * @throws IllegalArgumentException if the driver could not open a JDBC connection. - */ - def createConnectionFactory(options: JDBCOptions): () => Connection = { - val driverClass: String = options.driverClass - () => { - DriverRegistry.register(driverClass) - val driver: Driver = DriverRegistry.get(driverClass) - val connection = ConnectionProvider.create(driver, options.parameters) - require(connection != null, - s"The driver could not open a JDBC connection. Check the URL: ${options.url}") - connection - } - } /** * Returns true if the table already exists in the JDBC database. @@ -651,7 +638,6 @@ object JdbcUtils extends Logging { * updated even with error if it doesn't support transaction, as there're dirty outputs. */ def savePartition( - getConnection: () => Connection, table: String, iterator: Iterator[Row], rddSchema: StructType, @@ -662,7 +648,7 @@ object JdbcUtils extends Logging { options: JDBCOptions): Unit = { val outMetrics = TaskContext.get().taskMetrics().outputMetrics - val conn = getConnection() + val conn = dialect.createConnectionFactory(options)(-1) var committed = false var finalIsolationLevel = Connection.TRANSACTION_NONE @@ -874,7 +860,6 @@ object JdbcUtils extends Logging { val table = options.table val dialect = JdbcDialects.get(url) val rddSchema = df.schema - val getConnection: () => Connection = createConnectionFactory(options) val batchSize = options.batchSize val isolationLevel = options.isolationLevel @@ -886,8 +871,7 @@ object JdbcUtils extends Logging { case _ => df } repartitionedDF.rdd.foreachPartition { iterator => savePartition( - getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, - options) + table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options) } } @@ -971,52 +955,108 @@ object JdbcUtils extends Logging { } /** - * Creates a namespace. + * Creates a schema. 
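// Editor's sketch (not part of this patch): the connection factory now lives on JdbcDialect
// and takes a partition index (Int => Connection) instead of JdbcUtils' old zero-argument
// factory; -1 is used on non-partitioned paths and part.idx inside JDBCRDD.compute, as seen
// above. This assumes access to Spark's internal JDBCOptions; `options` is a hypothetical,
// already-built instance.
import java.sql.Connection
import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialects

object ConnectionFactorySketch {
  def openConnection(options: JDBCOptions, partitionIdx: Int = -1): Connection = {
    val factory: Int => Connection =
      JdbcDialects.get(options.url).createConnectionFactory(options)
    factory(partitionIdx) // factory(part.idx) per partition; factory(-1) on the driver side
  }
}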
*/ - def createNamespace( + def createSchema( conn: Connection, options: JDBCOptions, - namespace: String, + schema: String, comment: String): Unit = { + val statement = conn.createStatement + try { + statement.setQueryTimeout(options.queryTimeout) + val dialect = JdbcDialects.get(options.url) + dialect.createSchema(statement, schema, comment) + } finally { + statement.close() + } + } + + def schemaExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { val dialect = JdbcDialects.get(options.url) - executeStatement(conn, options, s"CREATE SCHEMA ${dialect.quoteIdentifier(namespace)}") - if (!comment.isEmpty) createNamespaceComment(conn, options, namespace, comment) + dialect.schemasExists(conn, options, schema) } - def createNamespaceComment( + def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { + val dialect = JdbcDialects.get(options.url) + dialect.listSchemas(conn, options) + } + + def alterSchemaComment( conn: Connection, options: JDBCOptions, - namespace: String, + schema: String, comment: String): Unit = { val dialect = JdbcDialects.get(options.url) - try { - executeStatement( - conn, options, dialect.getSchemaCommentQuery(namespace, comment)) - } catch { - case e: Exception => - logWarning("Cannot create JDBC catalog comment. The catalog comment will be ignored.") - } + executeStatement(conn, options, dialect.getSchemaCommentQuery(schema, comment)) } - def removeNamespaceComment( + def removeSchemaComment( conn: Connection, options: JDBCOptions, - namespace: String): Unit = { + schema: String): Unit = { val dialect = JdbcDialects.get(options.url) - try { - executeStatement(conn, options, dialect.removeSchemaCommentQuery(namespace)) - } catch { - case e: Exception => - logWarning("Cannot drop JDBC catalog comment.") - } + executeStatement(conn, options, dialect.removeSchemaCommentQuery(schema)) + } + + /** + * Drops a schema from the JDBC database. + */ + def dropSchema( + conn: Connection, options: JDBCOptions, schema: String, cascade: Boolean): Unit = { + val dialect = JdbcDialects.get(options.url) + executeStatement(conn, options, dialect.dropSchema(schema, cascade)) + } + + /** + * Create an index. + */ + def createIndex( + conn: Connection, + indexName: String, + tableName: String, + columns: Array[NamedReference], + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String], + options: JDBCOptions): Unit = { + val dialect = JdbcDialects.get(options.url) + executeStatement(conn, options, + dialect.createIndex(indexName, tableName, columns, columnsProperties, properties)) + } + + /** + * Check if an index exists + */ + def indexExists( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Boolean = { + val dialect = JdbcDialects.get(options.url) + dialect.indexExists(conn, indexName, tableName, options) } /** - * Drops a namespace from the JDBC database. + * Drop an index. */ - def dropNamespace(conn: Connection, options: JDBCOptions, namespace: String): Unit = { + def dropIndex( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Unit = { + val dialect = JdbcDialects.get(options.url) + executeStatement(conn, options, dialect.dropIndex(indexName, tableName)) + } + + /** + * List all the indexes in a table. 
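// Editor's sketch (not part of this patch): one way the new index helpers above might be
// driven by a dialect-aware caller. The index, table and column names are hypothetical, and
// JdbcUtils.withConnection is the helper introduced further down in this file.
import java.util.Collections
import org.apache.spark.sql.connector.expressions.{Expressions, NamedReference}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}

object IndexHelperSketch {
  // `options` is assumed to be an already-constructed JDBCOptions for the target database.
  def ensureIndex(options: JDBCOptions): Unit = {
    JdbcUtils.withConnection(options) { conn =>
      val columns: Array[NamedReference] = Array(Expressions.column("customer_id"))
      if (!JdbcUtils.indexExists(conn, "idx_customer_id", "orders", options)) {
        JdbcUtils.createIndex(
          conn, "idx_customer_id", "orders", columns,
          Collections.emptyMap[NamedReference, java.util.Map[String, String]](),
          Collections.emptyMap[String, String](),
          options)
      }
    }
  }
}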
+ */ + def listIndexes( + conn: Connection, + tableName: String, + options: JDBCOptions): Array[TableIndex] = { val dialect = JdbcDialects.get(options.url) - executeStatement(conn, options, s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)}") + dialect.listIndexes(conn, tableName, options) } private def executeStatement(conn: Connection, options: JDBCOptions, sql: String): Unit = { @@ -1028,4 +1068,105 @@ object JdbcUtils extends Logging { statement.close() } } + + /** + * Check if index exists in a table + */ + def checkIfIndexExists( + conn: Connection, + sql: String, + options: JDBCOptions): Boolean = { + val statement = conn.createStatement + try { + statement.setQueryTimeout(options.queryTimeout) + val rs = statement.executeQuery(sql) + rs.next + } catch { + case _: Exception => + logWarning("Cannot retrieved index info.") + false + } finally { + statement.close() + } + } + + /** + * Process index properties and return tuple of indexType and list of the other index properties. + */ + def processIndexProperties( + properties: util.Map[String, String], + catalogName: String): (String, Array[String]) = { + var indexType = "" + val indexPropertyList: ArrayBuffer[String] = ArrayBuffer[String]() + val supportedIndexTypeList = getSupportedIndexTypeList(catalogName) + + if (!properties.isEmpty) { + properties.asScala.foreach { case (k, v) => + if (k.equals(SupportsIndex.PROP_TYPE)) { + if (containsIndexTypeIgnoreCase(supportedIndexTypeList, v)) { + indexType = s"USING $v" + } else { + throw new UnsupportedOperationException(s"Index Type $v is not supported." + + s" The supported Index Types are: ${supportedIndexTypeList.mkString(" AND ")}") + } + } else { + indexPropertyList.append(s"$k = $v") + } + } + } + (indexType, indexPropertyList.toArray) + } + + def containsIndexTypeIgnoreCase(supportedIndexTypeList: Array[String], value: String): Boolean = { + if (supportedIndexTypeList.isEmpty) { + throw new UnsupportedOperationException( + "Cannot specify 'USING index_type' in 'CREATE INDEX'") + } + for (indexType <- supportedIndexTypeList) { + if (value.equalsIgnoreCase(indexType)) return true + } + false + } + + def getSupportedIndexTypeList(catalogName: String): Array[String] = { + catalogName match { + case "mysql" => Array("BTREE", "HASH") + case "postgresql" => Array("BTREE", "HASH", "BRIN") + case _ => Array.empty + } + } + + def executeQuery(conn: Connection, options: JDBCOptions, sql: String)( + f: ResultSet => Unit): Unit = { + val statement = conn.createStatement + try { + statement.setQueryTimeout(options.queryTimeout) + val rs = statement.executeQuery(sql) + try { + f(rs) + } finally { + rs.close() + } + } finally { + statement.close() + } + } + + def classifyException[T](message: String, dialect: JdbcDialect)(f: => T): T = { + try { + f + } catch { + case e: Throwable => throw dialect.classifyException(message, e) + } + } + + def withConnection[T](options: JDBCOptions)(f: Connection => T): T = { + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) + try { + f(conn) + } finally { + conn.close() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala index fbc69704f1479..ed8398f265848 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.jdbc.JdbcConnectionProvider import org.apache.spark.util.Utils -private[jdbc] object ConnectionProvider extends Logging { +protected abstract class ConnectionProviderBase extends Logging { private val providers = loadProviders() def loadProviders(): Seq[JdbcConnectionProvider] = { @@ -73,3 +73,5 @@ private[jdbc] object ConnectionProvider extends Logging { } } } + +private[sql] object ConnectionProvider extends ConnectionProviderBase diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index fa8977f239164..59a52b318622b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -68,6 +68,22 @@ class OrcDeserializer( resultRow } + def deserializeFromValues(orcValues: Seq[WritableComparable[_]]): InternalRow = { + var targetColumnIndex = 0 + while (targetColumnIndex < fieldWriters.length) { + if (fieldWriters(targetColumnIndex) != null) { + val value = orcValues(requestedColIds(targetColumnIndex)) + if (value == null) { + resultRow.setNullAt(targetColumnIndex) + } else { + fieldWriters(targetColumnIndex)(value) + } + } + targetColumnIndex += 1 + } + resultRow + } + /** * Creates a writer to write ORC values to Catalyst data structure at the given ordinal. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala index a8647726fe022..7758d6a515b51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala @@ -24,17 +24,22 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} -import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription, Writer} +import org.apache.hadoop.hive.serde2.io.DateWritable +import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable} +import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer} -import org.apache.spark.SPARK_VERSION_SHORT +import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.SchemaMergeUtils +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils import 
org.apache.spark.sql.types._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -84,7 +89,7 @@ object OrcUtils extends Logging { } } - private def toCatalystSchema(schema: TypeDescription): StructType = { + def toCatalystSchema(schema: TypeDescription): StructType = { // The Spark query engine has not completely supported CHAR/VARCHAR type yet, and here we // replace the orc CHAR/VARCHAR with STRING type. CharVarcharUtils.replaceCharVarcharWithStringInSchema( @@ -259,4 +264,139 @@ object OrcUtils extends Logging { OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString) resultSchemaString } + + /** + * Checks if `dataType` supports columnar reads. + * + * @param dataType Data type of the orc files. + * @param nestedColumnEnabled True if columnar reads is enabled for nested column types. + * @return Returns true if data type supports columnar reads. + */ + def supportColumnarReads( + dataType: DataType, + nestedColumnEnabled: Boolean): Boolean = { + dataType match { + case _: AtomicType => true + case st: StructType if nestedColumnEnabled => + st.forall(f => supportColumnarReads(f.dataType, nestedColumnEnabled)) + case ArrayType(elementType, _) if nestedColumnEnabled => + supportColumnarReads(elementType, nestedColumnEnabled) + case MapType(keyType, valueType, _) if nestedColumnEnabled => + supportColumnarReads(keyType, nestedColumnEnabled) && + supportColumnarReads(valueType, nestedColumnEnabled) + case _ => false + } + } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data + * from ORC and aggregate at Spark layer. Instead we want to get the partial aggregates + * (Max/Min/Count) result using the statistics information from ORC file footer, and then + * construct an InternalRow from these aggregate results. + * + * @return Aggregate results in the format of InternalRow + */ + def createAggInternalRowFromFooter( + reader: Reader, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType): InternalRow = { + require(aggregation.groupByColumns.length == 0, + s"aggregate $aggregation with group-by column shouldn't be pushed down") + var columnsStatistics: OrcColumnStatistics = null + try { + columnsStatistics = OrcFooterReader.readStatistics(reader) + } catch { case e: Exception => + throw new SparkException( + s"Cannot read columns statistics in file: $filePath. Please consider disabling " + + s"ORC aggregate push down by setting 'spark.sql.orc.aggregatePushdown' to false.", e) + } + + // Get column statistics with column name. + def getColumnStatistics(columnName: String): ColumnStatistics = { + val columnIndex = dataSchema.fieldNames.indexOf(columnName) + columnsStatistics.get(columnIndex).getStatistics + } + + // Get Min/Max statistics and store as ORC `WritableComparable` format. + // Return null if number of non-null values is zero. 
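// Editor's sketch (not part of this patch): the idea described above reduced to plain ORC
// APIs -- MIN/MAX/COUNT can be answered from file-footer column statistics without reading
// any rows. The file path and the assumption that column 1 is an integer column are
// hypothetical.
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.orc.{IntegerColumnStatistics, OrcFile}

object OrcFooterStatsSketch {
  def main(args: Array[String]): Unit = {
    val reader = OrcFile.createReader(
      new Path("/tmp/data.orc"), OrcFile.readerOptions(new Configuration()))
    val stats = reader.getStatistics
    // Statistics index 0 is the root struct; its value count is the total row count.
    println(s"count(*) = ${stats(0).getNumberOfValues}")
    stats(1) match {
      case s: IntegerColumnStatistics =>
        println(s"max = ${s.getMaximum}, min = ${s.getMinimum}, non-null = ${s.getNumberOfValues}")
      case other =>
        println(s"column 1 has non-integer statistics: $other")
    }
    reader.close()
  }
}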
+ def getMinMaxFromColumnStatistics( + statistics: ColumnStatistics, + dataType: DataType, + isMax: Boolean): WritableComparable[_] = { + if (statistics.getNumberOfValues == 0) { + return null + } + + statistics match { + case s: BooleanColumnStatistics => + val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0) + new BooleanWritable(value) + case s: IntegerColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case ByteType => new ByteWritable(value.toByte) + case ShortType => new ShortWritable(value.toShort) + case IntegerType => new IntWritable(value.toInt) + case LongType => new LongWritable(value) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take type $dataType " + + "for IntegerColumnStatistics") + } + case s: DoubleColumnStatistics => + val value = if (isMax) s.getMaximum else s.getMinimum + dataType match { + case FloatType => new FloatWritable(value.toFloat) + case DoubleType => new DoubleWritable(value) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take type $dataType " + + "for DoubleColumnStatistics") + } + case s: DateColumnStatistics => + new DateWritable( + if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt) + case _ => throw new IllegalArgumentException( + s"getMinMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " + + s"$statistics as the ORC column statistics") + } + } + + val aggORCValues: Seq[WritableComparable[_]] = + aggregation.aggregateExpressions.zipWithIndex.map { + case (max: Max, index) if V2ColumnUtils.extractV2Column(max.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(max.column).get + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema(index).dataType + getMinMaxFromColumnStatistics(statistics, dataType, isMax = true) + case (min: Min, index) if V2ColumnUtils.extractV2Column(min.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(min.column).get + val statistics = getColumnStatistics(columnName) + val dataType = aggSchema.apply(index).dataType + getMinMaxFromColumnStatistics(statistics, dataType, isMax = false) + case (count: Count, _) if V2ColumnUtils.extractV2Column(count.column).isDefined => + val columnName = V2ColumnUtils.extractV2Column(count.column).get + val isPartitionColumn = partitionSchema.fields.map(_.name).contains(columnName) + // NOTE: Count(columnName) doesn't include null values. + // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values + // for ColumnStatistics of individual column. In addition to this, ORC also stores number + // of all values (null and non-null) separately. + val nonNullRowsCount = if (isPartitionColumn) { + columnsStatistics.getStatistics.getNumberOfValues + } else { + getColumnStatistics(columnName).getNumberOfValues + } + new LongWritable(nonNullRowsCount) + case (_: CountStar, _) => + // Count(*) includes both null and non-null values. 
+ new LongWritable(columnsStatistics.getStatistics.getNumberOfValues) + case (x, _) => + throw new IllegalArgumentException( + s"createAggInternalRowFromFooter should not take $x as the aggregate expression") + } + + val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray) + orcValuesDeserializer.deserializeFromValues(aggORCValues) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala index b91d75c55c513..f3836ab8b5ae4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala @@ -16,10 +16,24 @@ */ package org.apache.spark.sql.execution.datasources.parquet +import java.util + +import scala.collection.mutable +import scala.language.existentials + import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.parquet.hadoop.ParquetFileWriter +import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata} +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{PrimitiveType, Types} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName +import org.apache.spark.SparkException import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min} +import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils +import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED} import org.apache.spark.sql.types.StructType object ParquetUtils { @@ -127,4 +141,176 @@ object ParquetUtils { file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE || file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + + /** + * When the partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to + * createRowBaseReader to read data from Parquet and aggregate at Spark layer. Instead we want + * to get the partial aggregates (Max/Min/Count) result using the statistics information + * from Parquet footer file, and then construct an InternalRow from these aggregate results. 
+ * + * @return Aggregate results in the format of InternalRow + */ + private[sql] def createAggInternalRowFromFooter( + footer: ParquetMetadata, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + aggSchema: StructType, + datetimeRebaseMode: LegacyBehaviorPolicy.Value, + isCaseSensitive: Boolean): InternalRow = { + val (primitiveTypes, values) = getPushedDownAggResult( + footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive) + + val builder = Types.buildMessage + primitiveTypes.foreach(t => builder.addField(t)) + val parquetSchema = builder.named("root") + + val schemaConverter = new ParquetToSparkSchemaConverter + val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema, + None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater) + val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName) + primitiveTypeNames.zipWithIndex.foreach { + case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) => + val v = values(i).asInstanceOf[Boolean] + converter.getConverter(i).asPrimitiveConverter.addBoolean(v) + case (PrimitiveType.PrimitiveTypeName.INT32, i) => + val v = values(i).asInstanceOf[Integer] + converter.getConverter(i).asPrimitiveConverter.addInt(v) + case (PrimitiveType.PrimitiveTypeName.INT64, i) => + val v = values(i).asInstanceOf[Long] + converter.getConverter(i).asPrimitiveConverter.addLong(v) + case (PrimitiveType.PrimitiveTypeName.FLOAT, i) => + val v = values(i).asInstanceOf[Float] + converter.getConverter(i).asPrimitiveConverter.addFloat(v) + case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) => + val v = values(i).asInstanceOf[Double] + converter.getConverter(i).asPrimitiveConverter.addDouble(v) + case (PrimitiveType.PrimitiveTypeName.BINARY, i) => + val v = values(i).asInstanceOf[Binary] + converter.getConverter(i).asPrimitiveConverter.addBinary(v) + case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) => + val v = values(i).asInstanceOf[Binary] + converter.getConverter(i).asPrimitiveConverter.addBinary(v) + case (_, i) => + throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i)) + } + converter.currentRecord + } + + /** + * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics + * information from Parquet footer file. + * + * @return A tuple of `Array[PrimitiveType]` and Array[Any]. + * The first element is the Parquet PrimitiveType of the aggregate column, + * and the second element is the aggregated value. 
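// Editor's sketch (not part of this patch): the same footer-statistics idea with plain
// parquet-hadoop APIs -- per-row-group MAX/MIN/null counts come from ColumnChunkMetaData and
// the row count from BlockMetaData, so no data pages are read. The file path and the choice
// of column index 0 are hypothetical.
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.hadoop.util.HadoopInputFile

object ParquetFooterStatsSketch {
  def main(args: Array[String]): Unit = {
    val in = HadoopInputFile.fromPath(new Path("/tmp/data.parquet"), new Configuration())
    val reader = ParquetFileReader.open(in)
    try {
      var rowCount = 0L
      reader.getFooter.getBlocks.forEach { block =>
        rowCount += block.getRowCount
        val stats = block.getColumns.get(0).getStatistics // first column of this row group
        if (stats.hasNonNullValue) {
          println(s"row group: max=${stats.genericGetMax}, min=${stats.genericGetMin}, " +
            s"nulls=${stats.getNumNulls}")
        }
      }
      println(s"count(*) = $rowCount")
    } finally {
      reader.close()
    }
  }
}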
+ */ + private[sql] def getPushedDownAggResult( + footer: ParquetMetadata, + filePath: String, + dataSchema: StructType, + partitionSchema: StructType, + aggregation: Aggregation, + isCaseSensitive: Boolean) + : (Array[PrimitiveType], Array[Any]) = { + val footerFileMetaData = footer.getFileMetaData + val fields = footerFileMetaData.getSchema.getFields + val blocks = footer.getBlocks + val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType] + val valuesBuilder = mutable.ArrayBuilder.make[Any] + + assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down") + aggregation.aggregateExpressions.foreach { agg => + var value: Any = None + var rowCount = 0L + var isCount = false + var index = 0 + var schemaName = "" + blocks.forEach { block => + val blockMetaData = block.getColumns + agg match { + case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(max.column).get + index = dataSchema.fieldNames.toList.indexOf(colName) + schemaName = "max(" + colName + ")" + val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true) + if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) { + value = currentMax + } + case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(min.column).get + index = dataSchema.fieldNames.toList.indexOf(colName) + schemaName = "min(" + colName + ")" + val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false) + if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) { + value = currentMin + } + case count: Count if V2ColumnUtils.extractV2Column(count.column).isDefined => + val colName = V2ColumnUtils.extractV2Column(count.column).get + schemaName = "count(" + colName + ")" + rowCount += block.getRowCount + var isPartitionCol = false + if (partitionSchema.fields.map(_.name).toSet.contains(colName)) { + isPartitionCol = true + } + isCount = true + if (!isPartitionCol) { + index = dataSchema.fieldNames.toList.indexOf(colName) + // Count(*) includes the null values, but Count(colName) doesn't. + rowCount -= getNumNulls(filePath, blockMetaData, index) + } + case _: CountStar => + schemaName = "count(*)" + rowCount += block.getRowCount + isCount = true + case _ => + } + } + if (isCount) { + valuesBuilder += rowCount + primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName); + } else { + valuesBuilder += value + val field = fields.get(index) + primitiveTypeBuilder += Types.required(field.asPrimitiveType.getPrimitiveTypeName) + .as(field.getLogicalTypeAnnotation) + .length(field.asPrimitiveType.getTypeLength) + .named(schemaName) + } + } + (primitiveTypeBuilder.result, valuesBuilder.result) + } + + /** + * Get the Max or Min value for ith column in the current block + * + * @return the Max or Min value + */ + private def getCurrentBlockMaxOrMin( + filePath: String, + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int, + isMax: Boolean): Any = { + val statistics = columnChunkMetaData.get(i).getStatistics + if (!statistics.hasNonNullValue) { + throw new UnsupportedOperationException(s"No min/max found for Parquet file $filePath. 
" + + s"Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again") + } else { + if (isMax) statistics.genericGetMax else statistics.genericGetMin + } + } + + private def getNumNulls( + filePath: String, + columnChunkMetaData: util.List[ColumnChunkMetaData], + i: Int): Long = { + val statistics = columnChunkMetaData.get(i).getStatistics + if (!statistics.isNumNullsSet) { + throw new UnsupportedOperationException(s"Number of nulls not set for Parquet file" + + s" $filePath. Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute" + + s" again") + } + statistics.getNumNulls; + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 1a50c320ea3e3..f267a03cbe218 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,14 +18,17 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable} -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, PredicateHelper, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.catalyst.util.{toPrettySQL, V2ExpressionBuilder} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog} +import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate} import org.apache.spark.sql.connector.read.LocalScan import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.connector.write.V1Write @@ -86,8 +89,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(project, filters, - DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) => + case PhysicalOperation(project, filters, DataSourceV2ScanRelation( + _, V1ScanWrapper(scan, pushed, pushedDownOperators), output)) => val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext) if (v1Relation.schema != scan.readSchema()) { throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError( @@ -95,12 +98,13 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } val rdd = v1Relation.buildScan() val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd) + val dsScan = RowDataSourceScanExec( output, output.toStructType, Set.empty, pushed.toSet, - aggregate, + pushedDownOperators, unsafeRowRDD, v1Relation, tableIdentifier = None) @@ -427,3 +431,112 @@ 
class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat case _ => Nil } } + +private[sql] object DataSourceV2Strategy { + + private def translateLeafNodeFilterV2( + predicate: Expression, + supportNestedPredicatePushdown: Boolean): Option[Predicate] = { + val pushablePredicate = PushablePredicate(supportNestedPredicatePushdown) + predicate match { + case pushablePredicate(expr) => Some(expr) + case _ => None + } + } + + /** + * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. + * + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + */ + protected[sql] def translateFilterV2( + predicate: Expression, + supportNestedPredicatePushdown: Boolean): Option[Predicate] = { + translateFilterV2WithMapping(predicate, None, supportNestedPredicatePushdown) + } + + /** + * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. + * + * @param predicate The input [[Expression]] to be translated as [[Filter]] + * @param translatedFilterToExpr An optional map from leaf node filter expressions to its + * translated [[Filter]]. The map is used for rebuilding + * [[Expression]] from [[Filter]]. + * @return a `Some[Filter]` if the input [[Expression]] is convertible, otherwise a `None`. + */ + protected[sql] def translateFilterV2WithMapping( + predicate: Expression, + translatedFilterToExpr: Option[mutable.HashMap[Predicate, Expression]], + nestedPredicatePushdownEnabled: Boolean) + : Option[Predicate] = { + predicate match { + case And(left, right) => + // See SPARK-12218 for detailed discussion + // It is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0) + // and we do not understand how to convert trim(b) = 'blah'. + // If we only convert a = 2, we will end up with + // (a = 2) OR (c > 0), which will generate wrong results. + // Pushing one leg of AND down is only safe to do at the top level. + // You can see ParquetFilters' createFilter for more details. 
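// Editor's note (not part of this patch): a concrete instance of the comment above, with
// hypothetical values. For a row where a = 2, trim(b) = "foo", c = -1:
//   full predicate:      (a = 2 AND trim(b) = 'blah') OR (c > 0)  => false (row must be dropped)
//   partially converted: (a = 2) OR (c > 0)                       => true  (row wrongly kept)
// At the top level a dropped conjunct simply stays in the post-scan filters, but inside an OR
// a weakened leg can be reported as fully handled and never re-checked, so the translation
// must be all-or-nothing -- which is what the for-comprehensions below enforce.
object PartialAndPushdownSketch {
  def main(args: Array[String]): Unit = {
    val (a, b, c) = (2, "foo", -1)
    val full = (a == 2 && b.trim == "blah") || c > 0
    val partiallyConverted = (a == 2) || c > 0
    assert(full != partiallyConverted) // demonstrates the divergence described above
  }
}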
+ for { + leftFilter <- translateFilterV2WithMapping( + left, translatedFilterToExpr, nestedPredicatePushdownEnabled) + rightFilter <- translateFilterV2WithMapping( + right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + } yield new V2And(leftFilter, rightFilter) + + case Or(left, right) => + for { + leftFilter <- translateFilterV2WithMapping( + left, translatedFilterToExpr, nestedPredicatePushdownEnabled) + rightFilter <- translateFilterV2WithMapping( + right, translatedFilterToExpr, nestedPredicatePushdownEnabled) + } yield new V2Or(leftFilter, rightFilter) + + case Not(child) => + translateFilterV2WithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled) + .map(new V2Not(_)) + + case other => + val filter = translateLeafNodeFilterV2(other, nestedPredicatePushdownEnabled) + if (filter.isDefined && translatedFilterToExpr.isDefined) { + translatedFilterToExpr.get(filter.get) = predicate + } + filter + } + } + + protected[sql] def rebuildExpressionFromFilter( + predicate: Predicate, + translatedFilterToExpr: mutable.HashMap[Predicate, Expression]): Expression = { + predicate match { + case and: V2And => + expressions.And( + rebuildExpressionFromFilter(and.left(), translatedFilterToExpr), + rebuildExpressionFromFilter(and.right(), translatedFilterToExpr)) + case or: V2Or => + expressions.Or( + rebuildExpressionFromFilter(or.left(), translatedFilterToExpr), + rebuildExpressionFromFilter(or.right(), translatedFilterToExpr)) + case not: V2Not => + expressions.Not(rebuildExpressionFromFilter(not.child(), translatedFilterToExpr)) + case _ => + translatedFilterToExpr.getOrElse(predicate, + throw new IllegalStateException("Failed to rebuild Expression for filter: " + predicate)) + } + } +} + +/** + * Get the expression of DS V2 to represent catalyst predicate that can be pushed down. + */ +case class PushablePredicate(nestedPredicatePushdownEnabled: Boolean) { + + def unapply(e: Expression): Option[Predicate] = + new V2ExpressionBuilder(e, nestedPredicatePushdownEnabled, true).build().map { v => + assert(v.isInstanceOf[Predicate]) + v.asInstanceOf[Predicate] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala index dbd5cbd874945..5d302055e7d91 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.connector.catalog.CatalogPlugin -import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.errors.QueryCompilationErrors /** * Physical plan node for dropping a namespace. @@ -37,17 +38,11 @@ case class DropNamespaceExec( val nsCatalog = catalog.asNamespaceCatalog val ns = namespace.toArray if (nsCatalog.namespaceExists(ns)) { - // The default behavior of `SupportsNamespace.dropNamespace()` is cascading, - // so make sure the namespace to drop is empty. 
- if (!cascade) { - if (catalog.asTableCatalog.listTables(ns).nonEmpty - || nsCatalog.listNamespaces(ns).nonEmpty) { - throw QueryExecutionErrors.cannotDropNonemptyNamespaceError(namespace) - } - } - - if (!nsCatalog.dropNamespace(ns)) { - throw QueryExecutionErrors.cannotDropNonemptyNamespaceError(namespace) + try { + nsCatalog.dropNamespace(ns, cascade) + } catch { + case _: NonEmptyNamespaceException => + throw QueryCompilationErrors.cannotDropNonemptyNamespaceError(namespace) } } else if (!ifExists) { throw QueryCompilationErrors.noSuchNamespaceError(ns) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 4506bd3d49b5b..8b0328cabc5a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -49,6 +49,8 @@ trait FileScan extends Scan def fileIndex: PartitioningAwareFileIndex + def dataSchema: StructType + /** * Returns the required data schema */ @@ -69,12 +71,6 @@ trait FileScan extends Scan */ def dataFilters: Seq[Expression] - /** - * Create a new `FileScan` instance from the current one - * with different `partitionFilters` and `dataFilters` - */ - def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan - /** * If a file with `path` is unsplittable, return the unsplittable reason, * otherwise return `None`. @@ -187,7 +183,10 @@ trait FileScan extends Scan new Statistics { override def sizeInBytes(): OptionalLong = { val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor - val size = (compressionFactor * fileIndex.sizeInBytes).toLong + val size = (compressionFactor * fileIndex.sizeInBytes / + (dataSchema.defaultSize + fileIndex.partitionSchema.defaultSize) * + (readDataSchema.defaultSize + readPartitionSchema.defaultSize)).toLong + OptionalLong.of(size) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala index 97874e8f4932e..2dc4137d6f9a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala @@ -16,19 +16,30 @@ */ package org.apache.spark.sql.execution.datasources.v2 -import org.apache.spark.sql.SparkSession +import scala.collection.mutable + +import org.apache.spark.sql.{sources, SparkSession} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils} +import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters +import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType abstract class FileScanBuilder( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, - dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns { + dataSchema: StructType) + extends ScanBuilder + with SupportsPushDownRequiredColumns + with SupportsPushDownCatalystFilters { private val 
partitionSchema = fileIndex.partitionSchema private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis protected val supportsNestedSchemaPruning = false protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields) + protected var partitionFilters = Seq.empty[Expression] + protected var dataFilters = Seq.empty[Expression] + protected var pushedDataFilters = Array.empty[Filter] override def pruneColumns(requiredSchema: StructType): Unit = { // [SPARK-30107] While `requiredSchema` might have pruned nested columns, @@ -48,7 +59,7 @@ abstract class FileScanBuilder( StructType(fields) } - protected def readPartitionSchema(): StructType = { + def readPartitionSchema(): StructType = { val requiredNameSet = createRequiredNameSet() val fields = partitionSchema.fields.filter { field => val colName = PartitioningUtils.getColName(field, isCaseSensitive) @@ -57,9 +68,34 @@ abstract class FileScanBuilder( StructType(fields) } + override def pushFilters(filters: Seq[Expression]): Seq[Expression] = { + val (partitionFilters, dataFilters) = + DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filters) + this.partitionFilters = partitionFilters + this.dataFilters = dataFilters + val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter] + for (filterExpr <- dataFilters) { + val translated = DataSourceStrategy.translateFilter(filterExpr, true) + if (translated.nonEmpty) { + translatedFilters += translated.get + } + } + pushedDataFilters = pushDataFilters(translatedFilters.toArray) + dataFilters + } + + override def pushedFilters: Array[Filter] = pushedDataFilters + + /* + * Push down data filters to the file source, so the data filters can be evaluated there to + * reduce the size of the data to be read. By default, data filters are not pushed down. + * File source needs to implement this method to push down data filters. 
+ */ + protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter] + private def createRequiredNameSet(): Set[String] = requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet - private val partitionNameSet: Set[String] = + val partitionNameSet: Set[String] = partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index acc645741819e..2adbd5cf007e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,14 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.expressions.FieldReference -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType @@ -38,9 +35,8 @@ object PushDownUtils extends PredicateHelper { * * @return pushed filter and post-scan filters. */ - def pushFilters( - scanBuilder: ScanBuilder, - filters: Seq[Expression]): (Seq[sources.Filter], Seq[Expression]) = { + def pushFilters(scanBuilder: ScanBuilder, filters: Seq[Expression]) + : (Either[Seq[sources.Filter], Seq[Predicate]], Seq[Expression]) = { scanBuilder match { case r: SupportsPushDownFilters => // A map from translated data source leaf node filters to original catalyst filter @@ -69,41 +65,79 @@ object PushDownUtils extends PredicateHelper { val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter => DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr) } - (r.pushedFilters(), (untranslatableExprs ++ postScanFilters).toSeq) + (Left(r.pushedFilters()), (untranslatableExprs ++ postScanFilters).toSeq) + + case r: SupportsPushDownV2Filters => + // A map from translated data source leaf node filters to original catalyst filter + // expressions. For a `And`/`Or` predicate, it is possible that the predicate is partially + // pushed down. This map can be used to construct a catalyst filter expression from the + // input filter, or a superset(partial push down filter) of the input filter. 
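// Editor's sketch (not part of this patch): the minimal shape of a connector ScanBuilder that
// this new SupportsPushDownV2Filters branch talks to. It keeps every predicate it can
// evaluate (here, hypothetically, everything except ENDS_WITH) and returns the rest, which
// Spark then rebuilds into post-scan filters via rebuildExpressionFromFilter.
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}

class SketchScanBuilder extends SupportsPushDownV2Filters {
  private var pushed: Array[Predicate] = Array.empty

  override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
    val (supported, unsupported) = predicates.partition(_.name() != "ENDS_WITH")
    pushed = supported
    unsupported // returned predicates must be re-evaluated by Spark after the scan
  }

  override def pushedPredicates(): Array[Predicate] = pushed

  override def build(): Scan = throw new UnsupportedOperationException("sketch only")
}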
+ val translatedFilterToExpr = mutable.HashMap.empty[Predicate, Expression] + val translatedFilters = mutable.ArrayBuffer.empty[Predicate] + // Catalyst filter expression that can't be translated to data source filters. + val untranslatableExprs = mutable.ArrayBuffer.empty[Expression] + + for (filterExpr <- filters) { + val translated = + DataSourceV2Strategy.translateFilterV2WithMapping( + filterExpr, Some(translatedFilterToExpr), nestedPredicatePushdownEnabled = true) + if (translated.isEmpty) { + untranslatableExprs += filterExpr + } else { + translatedFilters += translated.get + } + } + + // Data source filters that need to be evaluated again after scanning. which means + // the data source cannot guarantee the rows returned can pass these filters. + // As a result we must return it so Spark can plan an extra filter operator. + val postScanFilters = r.pushPredicates(translatedFilters.toArray).map { predicate => + DataSourceV2Strategy.rebuildExpressionFromFilter(predicate, translatedFilterToExpr) + } + (Right(r.pushedPredicates), (untranslatableExprs ++ postScanFilters).toSeq) - case _ => (Nil, filters) + case f: FileScanBuilder => + val postScanFilters = f.pushFilters(filters) + (Left(f.pushedFilters), postScanFilters) + + case _ => (Left(Nil), filters) } } /** - * Pushes down aggregates to the data source reader - * - * @return pushed aggregation. + * Pushes down TableSample to the data source Scan */ - def pushAggregates( - scanBuilder: ScanBuilder, - aggregates: Seq[AggregateExpression], - groupBy: Seq[Expression]): Option[Aggregation] = { - - def columnAsString(e: Expression): Option[FieldReference] = e match { - case PushableColumnWithoutNestedColumn(name) => - Some(FieldReference(name).asInstanceOf[FieldReference]) - case _ => None + def pushTableSample(scanBuilder: ScanBuilder, sample: TableSampleInfo): Boolean = { + scanBuilder match { + case s: SupportsPushDownTableSample => + s.pushTableSample( + sample.lowerBound, sample.upperBound, sample.withReplacement, sample.seed) + case _ => false } + } + /** + * Pushes down LIMIT to the data source Scan + */ + def pushLimit(scanBuilder: ScanBuilder, limit: Int): Boolean = { scanBuilder match { - case r: SupportsPushDownAggregates if aggregates.nonEmpty => - val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate) - val translatedGroupBys = groupBy.flatMap(columnAsString) - - if (translatedAggregates.length != aggregates.length || - translatedGroupBys.length != groupBy.length) { - return None - } + case s: SupportsPushDownLimit => + s.pushLimit(limit) + case _ => false + } + } - val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray) - Some(agg).filter(r.pushAggregation) - case _ => None + /** + * Pushes down top N to the data source Scan + */ + def pushTopN( + scanBuilder: ScanBuilder, + order: Array[SortOrder], + limit: Int): (Boolean, Boolean) = { + scanBuilder match { + case s: SupportsPushDownTopN if s.pushTopN(order, limit) => + (true, s.isPartiallyPushed) + case _ => (false, false) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala new file mode 100644 index 0000000000000..a95b4593fc397 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.expressions.filter.Predicate + +/** + * Pushed down operators + */ +case class PushedDownOperators( + aggregation: Option[Aggregation], + sample: Option[TableSampleInfo], + limit: Option[Int], + sortValues: Seq[SortOrder], + pushedPredicates: Seq[Predicate]) { + assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala new file mode 100644 index 0000000000000..cb4fb9eb0809a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala @@ -0,0 +1,24 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +case class TableSampleInfo( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala new file mode 100644 index 0000000000000..9fc220f440bc1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.connector.expressions.{Expression, NamedReference} + +object V2ColumnUtils { + def extractV2Column(expr: Expression): Option[String] = expr match { + case r: NamedReference if r. fieldNames.length == 1 => Some(r.fieldNames.head) + case _ => None + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index 046155b55cc2d..cdcae15ef4e24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -19,24 +19,28 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.planning.ScanOperation -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder} +import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum} +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, LongType, StructType} import org.apache.spark.sql.util.SchemaUtils._ -object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { +object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper { import DataSourceV2Implicits._ def apply(plan: LogicalPlan): LogicalPlan = { - applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan)))) + applyColumnPruning( + applyLimit(pushDownAggregates(pushDownFilters(pushDownSample(createScanBuilder(plan)))))) } private def createScanBuilder(plan: LogicalPlan) = plan.transform { @@ -58,12 +62,19 
@@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters( sHolder.builder, normalizedFiltersWithoutSubquery) + val pushedFiltersStr = if (pushedFilters.isLeft) { + pushedFilters.left.get.mkString(", ") + } else { + sHolder.pushedPredicates = pushedFilters.right.get + pushedFilters.right.get.mkString(", ") + } + val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery logInfo( s""" |Pushing operators to ${sHolder.relation.name} - |Pushed Filters: ${pushedFilters.mkString(", ")} + |Pushed Filters: $pushedFiltersStr |Post-Scan Filters: ${postScanFilters.mkString(",")} """.stripMargin) @@ -76,103 +87,168 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) => child match { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) - if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) => + if filters.isEmpty && CollapseProject.canCollapseExpressions( + resultExpressions, project, alwaysInline = true) => sHolder.builder match { - case _: SupportsPushDownAggregates => + case r: SupportsPushDownAggregates => + val aliasMap = getAliasMap(project) + val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap)) + val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap)) + val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int] - var ordinal = 0 - val aggregates = resultExpressions.flatMap { expr => - expr.collect { - // Do not push down duplicated aggregate expressions. For example, - // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one - // `max(a)` to the data source. - case agg: AggregateExpression - if !aggExprToOutputOrdinal.contains(agg.canonicalized) => - aggExprToOutputOrdinal(agg.canonicalized) = ordinal - ordinal += 1 - agg - } - } + val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal) val normalizedAggregates = DataSourceStrategy.normalizeExprs( aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs( - groupingExpressions, sHolder.relation.output) - val pushedAggregates = PushDownUtils.pushAggregates( - sHolder.builder, normalizedAggregates, normalizedGroupingExpressions) - if (pushedAggregates.isEmpty) { + actualGroupExprs, sHolder.relation.output) + val translatedAggregates = DataSourceStrategy.translateAggregation( + normalizedAggregates, normalizedGroupingExpressions) + val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = { + if (translatedAggregates.isEmpty || + r.supportCompletePushDown(translatedAggregates.get) || + translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) { + (actualResultExprs, aggregates, translatedAggregates) + } else { + // scalastyle:off + // The data source doesn't support the complete push-down of this aggregation. + // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be + // pushed, completely or partially. + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT avg(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] 
+ // + // After convert avg(c1#9) to sum(c1#9)/count(c1#9) + // we have the following + // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19] + // +- ScanOperation[...] + // scalastyle:on + val newResultExpressions = actualResultExprs.map { expr => + expr.transform { + case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) => + val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct) + val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct) + avg.evaluateExpression transform { + case a: Attribute if a.semanticEquals(avg.sum) => + addCastIfNeeded(sum, avg.sum.dataType) + case a: Attribute if a.semanticEquals(avg.count) => + addCastIfNeeded(count, avg.count.dataType) + } + } + }.asInstanceOf[Seq[NamedExpression]] + // Because aggregate expressions changed, translate them again. + aggExprToOutputOrdinal.clear() + val newAggregates = + collectAggregates(newResultExpressions, aggExprToOutputOrdinal) + val newNormalizedAggregates = DataSourceStrategy.normalizeExprs( + newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]] + (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation( + newNormalizedAggregates, normalizedGroupingExpressions)) + } + } + + if (finalTranslatedAggregates.isEmpty) { + aggNode // return original plan node + } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) && + !supportPartialAggPushDown(finalTranslatedAggregates.get)) { aggNode // return original plan node } else { - // No need to do column pruning because only the aggregate columns are used as - // DataSourceV2ScanRelation output columns. All the other columns are not - // included in the output. - val scan = sHolder.builder.build() - - // scalastyle:off - // use the group by columns and aggregate columns as the output columns - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation - // We want to have the following logical plan: - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] - // scalastyle:on - val newOutput = scan.readSchema().toAttributes - assert(newOutput.length == groupingExpressions.length + aggregates.length) - val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { - case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) - case (_, b) => b - } - val output = groupAttrs ++ newOutput.drop(groupAttrs.length) - - logInfo( - s""" - |Pushing operators to ${sHolder.relation.name} - |Pushed Aggregate Functions: - | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} - |Pushed Group by: - | ${pushedAggregates.get.groupByColumns.mkString(", ")} - |Output: ${output.mkString(", ")} + val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation) + if (pushedAggregates.isEmpty) { + aggNode // return original plan node + } else { + // No need to do column pruning because only the aggregate columns are used as + // DataSourceV2ScanRelation output columns. All the other columns are not + // included in the output. + val scan = sHolder.builder.build() + + // scalastyle:off + // use the group by columns and aggregate columns as the output columns + // e.g. 
TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation + // We want to have the following logical plan: + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] + // scalastyle:on + val newOutput = scan.readSchema().toAttributes + assert(newOutput.length == groupingExpressions.length + finalAggregates.length) + val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map { + case (a: Attribute, b: Attribute) => b.withExprId(a.exprId) + case (_, b) => b + } + val aggOutput = newOutput.drop(groupAttrs.length) + val output = groupAttrs ++ aggOutput + + logInfo( + s""" + |Pushing operators to ${sHolder.relation.name} + |Pushed Aggregate Functions: + | ${pushedAggregates.get.aggregateExpressions.mkString(", ")} + |Pushed Group by: + | ${pushedAggregates.get.groupByColumns.mkString(", ")} + |Output: ${output.mkString(", ")} """.stripMargin) - val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) - - val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) - - val plan = Aggregate( - output.take(groupingExpressions.length), resultExpressions, scanRelation) - - // scalastyle:off - // Change the optimized logical plan to reflect the pushed down aggregate - // e.g. TABLE t (c1 INT, c2 INT, c3 INT) - // SELECT min(c1), max(c1) FROM t GROUP BY c2; - // The original logical plan is - // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c1#9, c2#10] ... - // - // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] - // we have the following - // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // - // We want to change it to - // == Optimized Logical Plan == - // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] - // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... - // scalastyle:on - val aggOutput = output.drop(groupAttrs.length) - plan.transformExpressions { - case agg: AggregateExpression => - val ordinal = aggExprToOutputOrdinal(agg.canonicalized) - val aggFunction: aggregate.AggregateFunction = - agg.aggregateFunction match { - case max: aggregate.Max => max.copy(child = aggOutput(ordinal)) - case min: aggregate.Min => min.copy(child = aggOutput(ordinal)) - case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal)) - case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal)) - case other => other + val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates) + val scanRelation = + DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output) + if (r.supportCompletePushDown(pushedAggregates.get)) { + val projectExpressions = finalResultExpressions.map { expr => + // TODO At present, only push down group by attribute is supported. + // In future, more attribute conversion is extended here. e.g. 
GetStructField + expr.transform { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val child = + addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType) + Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId) } - agg.copy(aggregateFunction = aggFunction) + }.asInstanceOf[Seq[NamedExpression]] + Project(projectExpressions, scanRelation) + } else { + val plan = Aggregate(output.take(groupingExpressions.length), + finalResultExpressions, scanRelation) + + // scalastyle:off + // Change the optimized logical plan to reflect the pushed down aggregate + // e.g. TABLE t (c1 INT, c2 INT, c3 INT) + // SELECT min(c1), max(c1) FROM t GROUP BY c2; + // The original logical plan is + // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c1#9, c2#10] ... + // + // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22] + // we have the following + // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // + // We want to change it to + // == Optimized Logical Plan == + // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18] + // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ... + // scalastyle:on + plan.transformExpressions { + case agg: AggregateExpression => + val ordinal = aggExprToOutputOrdinal(agg.canonicalized) + val aggAttribute = aggOutput(ordinal) + val aggFunction: aggregate.AggregateFunction = + agg.aggregateFunction match { + case max: aggregate.Max => + max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType)) + case min: aggregate.Min => + min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType)) + case sum: aggregate.Sum => + sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType)) + case _: aggregate.Count => + aggregate.Sum(addCastIfNeeded(aggAttribute, LongType)) + case other => other + } + agg.copy(aggregateFunction = aggFunction) + } + } } } case _ => aggNode @@ -181,6 +257,42 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { } } + private def collectAggregates(resultExpressions: Seq[NamedExpression], + aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = { + var ordinal = 0 + resultExpressions.flatMap { expr => + expr.collect { + // Do not push down duplicated aggregate expressions. For example, + // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one + // `max(a)` to the data source. + case agg: AggregateExpression + if !aggExprToOutputOrdinal.contains(agg.canonicalized) => + aggExprToOutputOrdinal(agg.canonicalized) = ordinal + ordinal += 1 + agg + } + } + } + + private def supportPartialAggPushDown(agg: Aggregation): Boolean = { + // We don't know the agg buffer of `GeneralAggregateFunc`, so can't do partial agg push down. + // If `Sum`, `Count`, `Avg` with distinct, can't do partial agg push down. 
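The restriction above on distinct Sum/Count/Avg exists because distinct aggregates cannot be merged from partial per-partition results, while their non-distinct forms can; it also mirrors the AVG -> SUM/COUNT rewrite earlier in this rule. A minimal, self-contained sketch of that decomposition (hypothetical names, not Spark classes):

object AvgDecompositionSketch {
  // Each data-source partition returns a partial (sum, count) pair for AVG(c1).
  final case class Partial(sum: Long, count: Long)

  // The Spark-side final aggregate: sum the sums, sum the counts, divide once.
  def finalAvg(partials: Seq[Partial]): Double = {
    val totalSum = partials.map(_.sum).sum
    val totalCount = partials.map(_.count).sum
    totalSum.toDouble / totalCount
  }

  def main(args: Array[String]): Unit = {
    // Two partitions of column c1: [1, 2, 3] and [10].
    val partials = Seq(Partial(sum = 6, count = 3), Partial(sum = 10, count = 1))
    // Averaging the per-partition averages would give (2.0 + 10.0) / 2 = 6.0, which is wrong;
    // the sum/count decomposition gives 16 / 4 = 4.0, which matches AVG over all rows.
    assert(finalAvg(partials) == 4.0)
  }
}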
+ agg.aggregateExpressions().exists { + case sum: Sum => !sum.isDistinct + case count: Count => !count.isDistinct + case avg: Avg => !avg.isDistinct + case _: GeneralAggregateFunc => false + case _ => true + } + } + + private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) = + if (expression.dataType == expectedDataType) { + expression + } else { + Cast(expression, expectedDataType) + } + def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform { case ScanOperation(project, filters, sHolder: ScanBuilderHolder) => // column pruning @@ -219,6 +331,69 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { withProjection } + def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform { + case sample: Sample => sample.child match { + case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => + val tableSample = TableSampleInfo( + sample.lowerBound, + sample.upperBound, + sample.withReplacement, + sample.seed) + val pushed = PushDownUtils.pushTableSample(sHolder.builder, tableSample) + if (pushed) { + sHolder.pushedSample = Some(tableSample) + sample.child + } else { + sample + } + + case _ => sample + } + } + + private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match { + case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty => + val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit) + if (limitPushed) { + sHolder.pushedLimit = Some(limit) + } + operation + case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder)) + if filter.isEmpty && CollapseProject.canCollapseExpressions( + order, project, alwaysInline = true) => + val aliasMap = getAliasMap(project) + val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]] + val orders = DataSourceStrategy.translateSortOrders(newOrder) + if (orders.length == order.length) { + val (isPushed, isPartiallyPushed) = + PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit) + if (isPushed) { + sHolder.pushedLimit = Some(limit) + sHolder.sortOrders = orders + if (isPartiallyPushed) { + s + } else { + operation + } + } else { + s + } + } else { + s + } + case p: Project => + val newChild = pushDownLimit(p.child, limit) + p.withNewChildren(Seq(newChild)) + case other => other + } + + def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform { + case globalLimit @ Limit(IntegerLiteral(limitValue), child) => + val newChild = pushDownLimit(child, limitValue) + val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild)) + globalLimit.withNewChildren(Seq(newLocalLimit)) + } + private def getWrappedScan( scan: Scan, sHolder: ScanBuilderHolder, @@ -230,7 +405,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { f.pushedFilters() case _ => Array.empty[sources.Filter] } - V1ScanWrapper(v1, pushedFilters, aggregation) + val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample, + sHolder.pushedLimit, sHolder.sortOrders, sHolder.pushedPredicates) + V1ScanWrapper(v1, pushedFilters, pushedDownOperators) case _ => scan } } @@ -239,13 +416,22 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { case class ScanBuilderHolder( output: Seq[AttributeReference], relation: DataSourceV2Relation, - builder: ScanBuilder) extends LeafNode + builder: ScanBuilder) extends LeafNode { + var pushedLimit: Option[Int] = None + + var sortOrders: 
Seq[V2SortOrder] = Seq.empty[V2SortOrder] + + var pushedSample: Option[TableSampleInfo] = None + + var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate] +} + -// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by -// the physical v1 scan node. +// A wrapper for v1 scan to carry the translated filters and the handled ones, along with +// other pushed down operators. This is required by the physical v1 scan node. case class V1ScanWrapper( v1Scan: V1Scan, handledFilters: Seq[sources.Filter], - pushedAggregate: Option[Aggregation]) extends Scan { + pushedDownOperators: PushedDownOperators) extends Scan { override def readSchema(): StructType = v1Scan.readSchema() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala index 33b8f22e3f88a..fe91cc486967b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala @@ -261,12 +261,11 @@ class V2SessionCatalog(catalog: SessionCatalog) } } - override def dropNamespace(namespace: Array[String]): Boolean = namespace match { + override def dropNamespace( + namespace: Array[String], + cascade: Boolean): Boolean = namespace match { case Array(db) if catalog.databaseExists(db) => - if (catalog.listTables(db).nonEmpty) { - throw QueryExecutionErrors.namespaceNotEmptyError(namespace) - } - catalog.dropDatabase(db, ignoreIfNotExists = false, cascade = false) + catalog.dropDatabase(db, ignoreIfNotExists = false, cascade) true case Array(_) => @@ -293,8 +292,8 @@ private[sql] object V2SessionCatalog { case IdentityTransform(FieldReference(Seq(col))) => identityCols += col - case BucketTransform(numBuckets, FieldReference(Seq(col))) => - bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil)) + case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) => + bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil)) case transform => throw QueryExecutionErrors.unsupportedPartitionTransformError(transform) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala index 3f77b2147f9ca..cc3c146106670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.csv.CSVDataSource -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -84,10 +84,6 @@ case class CSVScan( dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = 
dataFilters) - override def equals(obj: Any): Boolean = obj match { case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options && equivalentFilters(pushedFilters, c.pushedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala index f7a79bf31948e..2b6edd4f357ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.csv import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -32,7 +32,7 @@ case class CSVScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { CSVScan( @@ -42,17 +42,16 @@ case class CSVScanBuilder( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.csvFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala index ef42691e5ca94..f68f78d51fd96 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala @@ -18,17 +18,23 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.connector.expressions.SortOrder +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read.V1Scan import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation -import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan} +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo +import org.apache.spark.sql.sources.{BaseRelation, TableScan} import org.apache.spark.sql.types.StructType case class JDBCScan( relation: JDBCRelation, prunedSchema: StructType, - pushedFilters: Array[Filter], + pushedPredicates: Array[Predicate], pushedAggregateColumn: Array[String] = Array(), - groupByColumns: Option[Array[String]]) extends V1Scan { + groupByColumns: Option[Array[String]], + tableSample: Option[TableSampleInfo], + 
pushedLimit: Int, + sortOrders: Array[SortOrder]) extends V1Scan { override def readSchema(): StructType = prunedSchema @@ -43,7 +49,8 @@ case class JDBCScan( } else { pushedAggregateColumn } - relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns) + relation.buildScan(columnList, prunedSchema, pushedPredicates, groupByColumns, tableSample, + pushedLimit, sortOrders) } }.asInstanceOf[T] } @@ -57,7 +64,7 @@ case class JDBCScan( ("[]", "[]") } super.description() + ", prunedSchema: " + seqToString(prunedSchema) + - ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedPredicates: " + seqToString(pushedPredicates) + ", PushedAggregates: " + aggString + ", PushedGroupBy: " + groupByString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala index b0de7c015c91a..0a1542a42956d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala @@ -20,12 +20,14 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.connector.expressions.SortOrder import org.apache.spark.sql.connector.expressions.aggregate.Aggregation -import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters} import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation} +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.jdbc.JdbcDialects -import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType case class JDBCScanBuilder( @@ -33,40 +35,56 @@ case class JDBCScanBuilder( schema: StructType, jdbcOptions: JDBCOptions) extends ScanBuilder - with SupportsPushDownFilters + with SupportsPushDownV2Filters with SupportsPushDownRequiredColumns with SupportsPushDownAggregates + with SupportsPushDownLimit + with SupportsPushDownTableSample + with SupportsPushDownTopN with Logging { private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis - private var pushedFilter = Array.empty[Filter] + private var pushedPredicate = Array.empty[Predicate] private var finalSchema = schema - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + private var tableSample: Option[TableSampleInfo] = None + + private var pushedLimit = 0 + + private var sortOrders: Array[SortOrder] = Array.empty[SortOrder] + + override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { if (jdbcOptions.pushDownPredicate) { val dialect = JdbcDialects.get(jdbcOptions.url) - val (pushed, unSupported) = filters.partition(JDBCRDD.compileFilter(_, dialect).isDefined) - this.pushedFilter = pushed + val (pushed, unSupported) = predicates.partition(dialect.compileExpression(_).isDefined) + this.pushedPredicate = pushed unSupported } else { - filters + predicates } } - 
override def pushedFilters(): Array[Filter] = pushedFilter + override def pushedPredicates(): Array[Predicate] = pushedPredicate private var pushedAggregateList: Array[String] = Array() private var pushedGroupByCols: Option[Array[String]] = None + override def supportCompletePushDown(aggregation: Aggregation): Boolean = { + lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames() + jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) || + (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 && + jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_))) + } + override def pushAggregation(aggregation: Aggregation): Boolean = { if (!jdbcOptions.pushDownAggregate) return false val dialect = JdbcDialects.get(jdbcOptions.url) - val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect) - if (compiledAgg.isEmpty) return false + val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate) + if (compiledAggs.length != aggregation.aggregateExpressions.length) return false val groupByCols = aggregation.groupByColumns.map { col => if (col.fieldNames.length != 1) return false @@ -77,7 +95,7 @@ case class JDBCScanBuilder( // e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") => // SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee" // GROUP BY "DEPT", "NAME" - val selectList = groupByCols ++ compiledAgg.get + val selectList = groupByCols ++ compiledAggs val groupByClause = if (groupByCols.isEmpty) { "" } else { @@ -98,6 +116,38 @@ case class JDBCScanBuilder( } } + override def pushTableSample( + lowerBound: Double, + upperBound: Double, + withReplacement: Boolean, + seed: Long): Boolean = { + if (jdbcOptions.pushDownTableSample && + JdbcDialects.get(jdbcOptions.url).supportsTableSample) { + this.tableSample = Some(TableSampleInfo(lowerBound, upperBound, withReplacement, seed)) + return true + } + false + } + + override def pushLimit(limit: Int): Boolean = { + if (jdbcOptions.pushDownLimit) { + pushedLimit = limit + return true + } + false + } + + override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = { + if (jdbcOptions.pushDownLimit) { + pushedLimit = limit + sortOrders = orders + return true + } + false + } + + override def isPartiallyPushed(): Boolean = jdbcOptions.numPartitions.map(_ > 1).getOrElse(false) + override def pruneColumns(requiredSchema: StructType): Unit = { // JDBC doesn't support nested column pruning. // TODO (SPARK-32593): JDBC support nested column and nested column pruning. @@ -122,7 +172,7 @@ case class JDBCScanBuilder( // "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from // prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't // be used in sql string. 
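supportCompletePushDown above is the key decision for JDBC aggregate push-down: with a single partition there is nothing left to merge on the Spark side, and when the single group-by column is the partition column, rows of one group never span two JDBC partitions. A hypothetical stand-alone restatement of that condition, simplified to single-part column names:

object JdbcCompletePushDownSketch {
  // Mirrors the shape of JDBCScanBuilder.supportCompletePushDown: either there is
  // effectively one partition, or the single group-by column equals the partition column.
  def canPushCompletely(
      numPartitions: Option[Int],
      partitionColumn: Option[String],
      groupByColumns: Seq[String]): Boolean = {
    numPartitions.forall(_ == 1) ||
      (groupByColumns.length == 1 &&
        partitionColumn.exists(_.equalsIgnoreCase(groupByColumns.head)))
  }

  def main(args: Array[String]): Unit = {
    assert(canPushCompletely(None, None, Seq("DEPT")))             // no explicit partitioning
    assert(canPushCompletely(Some(4), Some("DEPT"), Seq("dept")))  // group by == partition column
    assert(!canPushCompletely(Some(4), Some("DEPT"), Seq("NAME"))) // only partial push-down possible
  }
}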
- JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter, - pushedAggregateList, pushedGroupByCols) + JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate, + pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala index 5e11ea66be4c6..793b72727b9ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala @@ -23,13 +23,16 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.TableCapability._ +import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex} +import org.apache.spark.sql.connector.expressions.NamedReference import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder} -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JdbcUtils} +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOptions) - extends Table with SupportsRead with SupportsWrite { + extends Table with SupportsRead with SupportsWrite with SupportsIndex { override def name(): String = ident.toString @@ -48,4 +51,39 @@ case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOpt jdbcOptions.parameters.originalMap ++ info.options.asCaseSensitiveMap().asScala) JDBCWriteBuilder(schema, mergedOptions) } + + override def createIndex( + indexName: String, + columns: Array[NamedReference], + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String]): Unit = { + JdbcUtils.withConnection(jdbcOptions) { conn => + JdbcUtils.classifyException(s"Failed to create index $indexName in $name", + JdbcDialects.get(jdbcOptions.url)) { + JdbcUtils.createIndex( + conn, indexName, name, columns, columnsProperties, properties, jdbcOptions) + } + } + } + + override def indexExists(indexName: String): Boolean = { + JdbcUtils.withConnection(jdbcOptions) { conn => + JdbcUtils.indexExists(conn, indexName, name, jdbcOptions) + } + } + + override def dropIndex(indexName: String): Unit = { + JdbcUtils.withConnection(jdbcOptions) { conn => + JdbcUtils.classifyException(s"Failed to drop index: $indexName", + JdbcDialects.get(jdbcOptions.url)) { + JdbcUtils.dropIndex(conn, indexName, name, jdbcOptions) + } + } + } + + override def listIndexes(): Array[TableIndex] = { + JdbcUtils.withConnection(jdbcOptions) { conn => + JdbcUtils.listIndexes(conn, name, jdbcOptions) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala index a90ab564ddb50..03200d5a6f371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala @@ -16,12 +16,11 @@ */ package org.apache.spark.sql.execution.datasources.v2.jdbc -import java.sql.{Connection, SQLException} +import java.sql.SQLException import java.util import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.collection.mutable.ArrayBuilder import org.apache.spark.internal.Logging import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange} @@ -57,7 +56,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def listTables(namespace: Array[String]): Array[Identifier] = { checkNamespace(namespace) - withConnection { conn => + JdbcUtils.withConnection(options) { conn => val schemaPattern = if (namespace.length == 1) namespace.head else null val rs = conn.getMetaData .getTables(null, schemaPattern, "%", Array("TABLE")); @@ -72,14 +71,14 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging checkNamespace(ident.namespace()) val writeOptions = new JdbcOptionsInWrite( options.parameters + (JDBCOptions.JDBC_TABLE_NAME -> getTableName(ident))) - classifyException(s"Failed table existence check: $ident") { - withConnection(JdbcUtils.tableExists(_, writeOptions)) + JdbcUtils.classifyException(s"Failed table existence check: $ident", dialect) { + JdbcUtils.withConnection(options)(JdbcUtils.tableExists(_, writeOptions)) } } override def dropTable(ident: Identifier): Boolean = { checkNamespace(ident.namespace()) - withConnection { conn => + JdbcUtils.withConnection(options) { conn => try { JdbcUtils.dropTable(conn, getTableName(ident), options) true @@ -91,8 +90,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { checkNamespace(oldIdent.namespace()) - withConnection { conn => - classifyException(s"Failed table renaming from $oldIdent to $newIdent") { + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed table renaming from $oldIdent to $newIdent", dialect) { JdbcUtils.renameTable(conn, getTableName(oldIdent), getTableName(newIdent), options) } } @@ -151,8 +150,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging val writeOptions = new JdbcOptionsInWrite(tableOptions) val caseSensitive = SQLConf.get.caseSensitiveAnalysis - withConnection { conn => - classifyException(s"Failed table creation: $ident") { + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed table creation: $ident", dialect) { JdbcUtils.createTable(conn, getTableName(ident), schema, caseSensitive, writeOptions) } } @@ -162,8 +161,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def alterTable(ident: Identifier, changes: TableChange*): Table = { checkNamespace(ident.namespace()) - withConnection { conn => - classifyException(s"Failed table altering: $ident") { + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed table altering: $ident", dialect) { JdbcUtils.alterTable(conn, getTableName(ident), changes, options) } loadTable(ident) @@ -172,24 +171,15 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging override def namespaceExists(namespace: Array[String]): Boolean = namespace match { case Array(db) => - withConnection { conn => - val rs = 
conn.getMetaData.getSchemas(null, db) - while (rs.next()) { - if (rs.getString(1) == db) return true; - } - false + JdbcUtils.withConnection(options) { conn => + JdbcUtils.schemaExists(conn, options, db) } case _ => false } override def listNamespaces(): Array[Array[String]] = { - withConnection { conn => - val schemaBuilder = ArrayBuilder.make[Array[String]] - val rs = conn.getMetaData.getSchemas() - while (rs.next()) { - schemaBuilder += Array(rs.getString(1)) - } - schemaBuilder.result + JdbcUtils.withConnection(options) { conn => + JdbcUtils.listSchemas(conn, options) } } @@ -234,9 +224,9 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } } } - withConnection { conn => - classifyException(s"Failed create name space: $db") { - JdbcUtils.createNamespace(conn, options, db, comment) + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed create name space: $db", dialect) { + JdbcUtils.createSchema(conn, options, db, comment) } } @@ -253,8 +243,10 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging changes.foreach { case set: NamespaceChange.SetProperty => if (set.property() == SupportsNamespaces.PROP_COMMENT) { - withConnection { conn => - JdbcUtils.createNamespaceComment(conn, options, db, set.value) + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed create comment on name space: $db", dialect) { + JdbcUtils.alterSchemaComment(conn, options, db, set.value) + } } } else { throw QueryCompilationErrors.cannotSetJDBCNamespaceWithPropertyError(set.property) @@ -262,8 +254,10 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging case unset: NamespaceChange.RemoveProperty => if (unset.property() == SupportsNamespaces.PROP_COMMENT) { - withConnection { conn => - JdbcUtils.removeNamespaceComment(conn, options, db) + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed remove comment on name space: $db", dialect) { + JdbcUtils.removeSchemaComment(conn, options, db) + } } } else { throw QueryCompilationErrors.cannotUnsetJDBCNamespaceWithPropertyError(unset.property) @@ -278,14 +272,13 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } } - override def dropNamespace(namespace: Array[String]): Boolean = namespace match { + override def dropNamespace( + namespace: Array[String], + cascade: Boolean): Boolean = namespace match { case Array(db) if namespaceExists(namespace) => - if (listTables(Array(db)).nonEmpty) { - throw QueryExecutionErrors.namespaceNotEmptyError(namespace) - } - withConnection { conn => - classifyException(s"Failed drop name space: $db") { - JdbcUtils.dropNamespace(conn, options, db) + JdbcUtils.withConnection(options) { conn => + JdbcUtils.classifyException(s"Failed drop name space: $db", dialect) { + JdbcUtils.dropSchema(conn, options, db, cascade) true } } @@ -301,24 +294,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging } } - private def withConnection[T](f: Connection => T): T = { - val conn = JdbcUtils.createConnectionFactory(options)() - try { - f(conn) - } finally { - conn.close() - } - } - private def getTableName(ident: Identifier): String = { (ident.namespace() :+ ident.name()).map(dialect.quoteIdentifier).mkString(".") } - - private def classifyException[T](message: String)(f: => T): T = { - try { - f - } catch { - case e: Throwable => throw dialect.classifyException(message, e) - } - } } diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala index 0e6c72c2cc331..7449f66ee020f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala @@ -20,6 +20,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.connector.write._ import org.apache.spark.sql.execution.datasources.jdbc.{JdbcOptionsInWrite, JdbcUtils} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources.InsertableRelation import org.apache.spark.sql.types.StructType @@ -37,7 +38,8 @@ case class JDBCWriteBuilder(schema: StructType, options: JdbcOptionsInWrite) ext override def toInsertableRelation: InsertableRelation = (data: DataFrame, _: Boolean) => { // TODO (SPARK-32595): do truncate and append atomically. if (isTruncate) { - val conn = JdbcUtils.createConnectionFactory(options)() + val dialect = JdbcDialects.get(options.url) + val conn = dialect.createConnectionFactory(options)(-1) JdbcUtils.truncateTable(conn, options) } JdbcUtils.saveTable(data, Some(schema), SQLConf.get.caseSensitiveAnalysis, options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala index 29eb8bec9a589..9ab367136fc97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.json.JsonDataSource -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -83,10 +83,6 @@ case class JsonScan( dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options && equivalentFilters(pushedFilters, j.pushedFilters) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala index cf1204566ddbd..c581617a4b7e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.json import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.StructFilters -import 
org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} +import org.apache.spark.sql.connector.read.Scan import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter @@ -31,7 +31,7 @@ class JsonScanBuilder ( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { JsonScan( sparkSession, @@ -40,17 +40,16 @@ class JsonScanBuilder ( readDataSchema(), readPartitionSchema(), options, - pushedFilters()) + pushedDataFilters, + partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.jsonFilterPushDown) { - _pushedFilters = StructFilters.pushedFilters(filters, dataSchema) + StructFilters.pushedFilters(dataFilters, dataSchema) + } else { + Array.empty[Filter] } - filters } - - override def pushedFilters(): Array[Filter] = _pushedFilters } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index 414252cc12481..79c34827c0bec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -23,14 +23,15 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType} import org.apache.hadoop.mapreduce.lib.input.FileSplit import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.orc.{OrcConf, OrcFile, TypeDescription} +import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription} import org.apache.orc.mapred.OrcStruct import org.apache.orc.mapreduce.OrcInputFormat import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} -import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitionedFile} import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils} import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf @@ -54,7 +55,8 @@ case class OrcPartitionReaderFactory( dataSchema: StructType, readDataSchema: StructType, partitionSchema: StructType, - filters: Array[Filter]) extends FilePartitionReaderFactory { + filters: Array[Filter], + aggregation: Option[Aggregation]) extends FilePartitionReaderFactory { private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields) private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val capacity = sqlConf.orcVectorizedReaderBatchSize @@ -79,17 +81,14 @@ case class OrcPartitionReaderFactory( override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { val conf = 
broadcastedConf.value.value - - OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) - val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) + if (aggregation.nonEmpty) { + return buildReaderWithAggregates(filePath, conf) + } - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, reader, conf) } @@ -126,17 +125,14 @@ case class OrcPartitionReaderFactory( override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { val conf = broadcastedConf.value.value - - OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) - val filePath = new Path(new URI(file.filePath)) - pushDownPredicates(filePath, conf) + if (aggregation.nonEmpty) { + return buildColumnarReaderWithAggregates(filePath, conf) + } - val fs = filePath.getFileSystem(conf) - val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val resultedColPruneInfo = - Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader => + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, reader, conf) } @@ -171,4 +167,67 @@ case class OrcPartitionReaderFactory( } } + private def createORCReader(filePath: Path, conf: Configuration): Reader = { + OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive) + + pushDownPredicates(filePath, conf) + + val fs = filePath.getFileSystem(conf) + val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) + OrcFile.createReader(filePath, readerOptions) + } + + /** + * Build reader with aggregate push down. + */ + private def buildReaderWithAggregates( + filePath: Path, + conf: Configuration): PartitionReader[InternalRow] = { + new PartitionReader[InternalRow] { + private var hasNext = true + private lazy val row: InternalRow = { + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => + OrcUtils.createAggInternalRowFromFooter( + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema) + } + } + + override def next(): Boolean = hasNext + + override def get(): InternalRow = { + hasNext = false + row + } + + override def close(): Unit = {} + } + } + + /** + * Build columnar reader with aggregate push down. 
+ */ + private def buildColumnarReaderWithAggregates( + filePath: Path, + conf: Configuration): PartitionReader[ColumnarBatch] = { + new PartitionReader[ColumnarBatch] { + private var hasNext = true + private lazy val batch: ColumnarBatch = { + Utils.tryWithResource(createORCReader(filePath, conf)) { reader => + val row = OrcUtils.createAggInternalRowFromFooter( + reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, + readDataSchema) + AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false) + } + } + + override def next(): Boolean = hasNext + + override def get(): ColumnarBatch = { + hasNext = false + batch + } + + override def close(): Unit = {} + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index 8fa7f8dc41ead..6b9d181a7f4c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -21,8 +21,9 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.Filter import org.apache.spark.sql.types.StructType @@ -37,10 +38,25 @@ case class OrcScan( readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, pushedFilters: Array[Filter], partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { - override def isSplitable(path: Path): Boolean = true + override def isSplitable(path: Path): Boolean = { + // If aggregate is pushed down, only the file footer will be read once, + // so file should be not split across multiple tasks. + pushedAggregate.isEmpty + } + + override def readSchema(): StructType = { + // If aggregate is pushed down, schema has already been pruned in `OrcScanBuilder` + // and no need to call super.readSchema() + if (pushedAggregate.nonEmpty) { + readDataSchema + } else { + super.readSchema() + } + } override def createReaderFactory(): PartitionReaderFactory = { val broadcastedConf = sparkSession.sparkContext.broadcast( @@ -48,28 +64,39 @@ case class OrcScan( // The partition values are already truncated in `FileScan.partitions`. // We should use `readPartitionSchema` as the partition schema here. 
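Both ORC aggregate readers above follow the same shape: compute one aggregated row lazily from the file footer, then behave as an exhausted reader. A stripped-down sketch of that pattern, with a generic T and a caller-supplied loadRow function standing in for the Spark types:

final class SingleRowReaderSketch[T](loadRow: () => T) {
  private var hasNextRow = true
  private lazy val row: T = loadRow()

  // True until the single aggregated row has been handed out.
  def next(): Boolean = hasNextRow

  def get(): T = {
    hasNextRow = false
    row
  }

  def close(): Unit = ()
}

object SingleRowReaderSketch {
  def main(args: Array[String]): Unit = {
    val reader = new SingleRowReaderSketch(() => Map("min(c1)" -> 1, "max(c1)" -> 42))
    while (reader.next()) println(reader.get()) // prints the aggregated row exactly once
    reader.close()
  }
}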
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, - dataSchema, readDataSchema, readPartitionSchema, pushedFilters) + dataSchema, readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate) } override def equals(obj: Any): Boolean = obj match { case o: OrcScan => + val pushedDownAggEqual = if (pushedAggregate.nonEmpty && o.pushedAggregate.nonEmpty) { + AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, o.pushedAggregate.get) + } else { + pushedAggregate.isEmpty && o.pushedAggregate.isEmpty + } super.equals(o) && dataSchema == o.dataSchema && options == o.options && - equivalentFilters(pushedFilters, o.pushedFilters) - + equivalentFilters(pushedFilters, o.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { + (seqToString(pushedAggregate.get.aggregateExpressions), + seqToString(pushedAggregate.get.groupByColumns)) + } else { + ("[]", "[]") + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index dc59526bb316b..d2c17fda4a382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2.orc import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf @@ -35,30 +36,59 @@ case class OrcScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownAggregates { + lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. 
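OrcScan stops splitting files once an aggregate is pushed because the answer is assembled from each file's footer rather than from row data. A self-contained sketch of that idea, with a hypothetical FileStats type standing in for the footer statistics; Spark still runs a final aggregate over the one-row-per-file results:

object FooterStatsSketch {
  // Hypothetical footer statistics for one file.
  final case class FileStats(min: Long, max: Long, rowCount: Long)

  // One pushed-down "row" per file, built purely from footer metadata (MIN/MAX/COUNT).
  def perFileRow(stats: FileStats): (Long, Long, Long) =
    (stats.min, stats.max, stats.rowCount)

  def main(args: Array[String]): Unit = {
    val files = Seq(
      FileStats(min = 3, max = 40, rowCount = 100),
      FileStats(min = 1, max = 25, rowCount = 50))
    val rows = files.map(perFileRow)
    // Spark-side final aggregate over the per-file rows.
    val globalMin = rows.map(_._1).min
    val globalMax = rows.map(_._2).max
    val globalCount = rows.map(_._3).sum
    assert(globalMin == 1L && globalMax == 40L && globalCount == 150L)
  }
}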
sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) } + private var finalSchema = new StructType() + + private var pushedAggregations = Option.empty[Aggregation] + override protected val supportsNestedSchemaPruning: Boolean = true override def build(): Scan = { - OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, - readDataSchema(), readPartitionSchema(), options, pushedFilters()) + // the `finalSchema` is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in readDataSchema() (in regular column pruning). These + // two are mutual exclusive. + if (pushedAggregations.isEmpty) { + finalSchema = readDataSchema() + } + OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, + readPartitionSchema(), options, pushedAggregations, pushedDataFilters, partitionFilters, + dataFilters) } - private var _pushedFilters: Array[Filter] = Array.empty - - override def pushFilters(filters: Array[Filter]): Array[Filter] = { + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = { if (sparkSession.sessionState.conf.orcFilterPushDown) { val dataTypeMap = OrcFilters.getSearchableTypeMap( readDataSchema(), SQLConf.get.caseSensitiveAnalysis) - _pushedFilters = OrcFilters.convertibleFilters(dataTypeMap, filters).toArray + OrcFilters.convertibleFilters(dataTypeMap, dataFilters).toArray + } else { + Array.empty[Filter] } - filters } - override def pushedFilters(): Array[Filter] = _pushedFilters + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!sparkSession.sessionState.conf.orcAggregatePushDown) { + return false + } + + AggregatePushDownUtils.getSchemaForPushedAggregation( + aggregation, + schema, + partitionNameSet, + dataFilters) match { + + case Some(schema) => + finalSchema = schema + this.pushedAggregations = Some(aggregation) + true + case _ => false + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala index 058669b0937fa..6f021ff2e97f5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala @@ -25,16 +25,18 @@ import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.parquet.filter2.compat.FilterCompat import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate} -import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS +import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS} import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader} +import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import org.apache.spark.TaskContext import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader} -import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, 
DataSourceUtils, PartitionedFile, RecordReaderIterator} import org.apache.spark.sql.execution.datasources.parquet._ import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.internal.SQLConf @@ -53,6 +55,7 @@ import org.apache.spark.util.SerializableConfiguration * @param readDataSchema Required schema of Parquet files. * @param partitionSchema Schema of partitions. * @param filters Filters to be pushed down in the batch scan. + * @param aggregation Aggregation to be pushed down in the batch scan. * @param parquetOptions The options of Parquet datasource that are set for the read. */ case class ParquetPartitionReaderFactory( @@ -62,6 +65,7 @@ case class ParquetPartitionReaderFactory( readDataSchema: StructType, partitionSchema: StructType, filters: Array[Filter], + aggregation: Option[Aggregation], parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging { private val isCaseSensitive = sqlConf.caseSensitiveAnalysis private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields) @@ -80,6 +84,30 @@ case class ParquetPartitionReaderFactory( private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead + private def getFooter(file: PartitionedFile): ParquetMetadata = { + val conf = broadcastedConf.value.value + val filePath = new Path(new URI(file.filePath)) + + if (aggregation.isEmpty) { + ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS) + } else { + // For aggregate push down, we will get max/min/count from footer statistics. + // We want to read the footer for the whole file instead of reading multiple + // footers for every split of the file. Basically if the start (the beginning of) + // the offset in PartitionedFile is 0, we will read the footer. Otherwise, it means + // that we have already read footer for that file, so we will skip reading again. 
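+      // In other words, each Parquet file contributes at most one aggregated row (none if
+      // it has no row groups), and that row is produced by the split whose start offset is
+      // 0; readers created for the remaining splits of the same file see a null footer
+      // below and therefore emit no rows.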
+ if (file.start != 0) return null + ParquetFooterReader.readFooter(conf, filePath, NO_FILTER) + } + } + + private def getDatetimeRebaseMode( + footerFileMetaData: FileMetaData): LegacyBehaviorPolicy.Value = { + DataSourceUtils.datetimeRebaseMode( + footerFileMetaData.getKeyValueMetaData.get, + datetimeRebaseModeInRead) + } + override def supportColumnarReads(partition: InputPartition): Boolean = { sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled && resultSchema.length <= sqlConf.wholeStageMaxNumFields && @@ -87,18 +115,44 @@ case class ParquetPartitionReaderFactory( } override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = { - val reader = if (enableVectorizedReader) { - createVectorizedReader(file) - } else { - createRowBaseReader(file) - } + val fileReader = if (aggregation.isEmpty) { + val reader = if (enableVectorizedReader) { + createVectorizedReader(file) + } else { + createRowBaseReader(file) + } + + new PartitionReader[InternalRow] { + override def next(): Boolean = reader.nextKeyValue() - val fileReader = new PartitionReader[InternalRow] { - override def next(): Boolean = reader.nextKeyValue() + override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] - override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow] + override def close(): Unit = reader.close() + } + } else { + new PartitionReader[InternalRow] { + private var hasNext = true + private lazy val row: InternalRow = { + val footer = getFooter(file) + if (footer != null && footer.getBlocks.size > 0) { + ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema, + partitionSchema, aggregation.get, readDataSchema, + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + } else { + null + } + } + override def next(): Boolean = { + hasNext && row != null + } - override def close(): Unit = reader.close() + override def get(): InternalRow = { + hasNext = false + row + } + + override def close(): Unit = {} + } } new PartitionReaderWithPartitionValues(fileReader, readDataSchema, @@ -106,17 +160,47 @@ case class ParquetPartitionReaderFactory( } override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = { - val vectorizedReader = createVectorizedReader(file) - vectorizedReader.enableReturningBatches() + val fileReader = if (aggregation.isEmpty) { + val vectorizedReader = createVectorizedReader(file) + vectorizedReader.enableReturningBatches() + + new PartitionReader[ColumnarBatch] { + override def next(): Boolean = vectorizedReader.nextKeyValue() - new PartitionReader[ColumnarBatch] { - override def next(): Boolean = vectorizedReader.nextKeyValue() + override def get(): ColumnarBatch = + vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] - override def get(): ColumnarBatch = - vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch] + override def close(): Unit = vectorizedReader.close() + } + } else { + new PartitionReader[ColumnarBatch] { + private var hasNext = true + private val batch: ColumnarBatch = { + val footer = getFooter(file) + if (footer != null && footer.getBlocks.size > 0) { + val row = ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, + dataSchema, partitionSchema, aggregation.get, readDataSchema, + getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive) + AggregatePushDownUtils.convertAggregatesRowToBatch( + row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined) + } else { + null + } + } + + 
override def next(): Boolean = { + hasNext && batch != null + } + + override def get(): ColumnarBatch = { + hasNext = false + batch + } - override def close(): Unit = vectorizedReader.close() + override def close(): Unit = {} + } } + fileReader } private def buildReaderBase[T]( @@ -131,11 +215,8 @@ case class ParquetPartitionReaderFactory( val filePath = new Path(new URI(file.filePath)) val split = new FileSplit(filePath, file.start, file.length, Array.empty[String]) - lazy val footerFileMetaData = - ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData - val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode( - footerFileMetaData.getKeyValueMetaData.get, - datetimeRebaseModeInRead) + lazy val footerFileMetaData = getFooter(file).getFileMetaData + val datetimeRebaseMode = getDatetimeRebaseMode(footerFileMetaData) // Try to push down filters when filter push-down is enabled. val pushed = if (enableParquetFilterPushDown) { val parquetSchema = footerFileMetaData.getSchema diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala index 60573ba10ccb6..b92ed82190ae8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala @@ -24,8 +24,9 @@ import org.apache.parquet.hadoop.ParquetInputFormat import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation import org.apache.spark.sql.connector.read.PartitionReaderFactory -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport} import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.internal.SQLConf @@ -43,10 +44,17 @@ case class ParquetScan( readPartitionSchema: StructType, pushedFilters: Array[Filter], options: CaseInsensitiveStringMap, + pushedAggregate: Option[Aggregation] = None, partitionFilters: Seq[Expression] = Seq.empty, dataFilters: Seq[Expression] = Seq.empty) extends FileScan { override def isSplitable(path: Path): Boolean = true + override def readSchema(): StructType = { + // If aggregate is pushed down, schema has already been pruned in `ParquetScanBuilder` + // and no need to call super.readSchema() + if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema() + } + override def createReaderFactory(): PartitionReaderFactory = { val readDataSchemaAsJson = readDataSchema.json hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName) @@ -86,27 +94,40 @@ case class ParquetScan( readDataSchema, readPartitionSchema, pushedFilters, + pushedAggregate, new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf)) } override def equals(obj: Any): Boolean = obj match { case p: ParquetScan => + val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) { + AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get) + } else { + pushedAggregate.isEmpty && p.pushedAggregate.isEmpty + } super.equals(p) && dataSchema == p.dataSchema && 
options == p.options && - equivalentFilters(pushedFilters, p.pushedFilters) + equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual case _ => false } override def hashCode(): Int = getClass.hashCode() + lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) { + (seqToString(pushedAggregate.get.aggregateExpressions), + seqToString(pushedAggregate.get.groupByColumns)) + } else { + ("[]", "[]") + } + override def description(): String = { - super.description() + ", PushedFilters: " + seqToString(pushedFilters) + super.description() + ", PushedFilters: " + seqToString(pushedFilters) + + ", PushedAggregation: " + pushedAggregationsStr + + ", PushedGroupBy: " + pushedGroupByStr } override def getMetaData(): Map[String, String] = { - super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) + super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++ + Map("PushedAggregation" -> pushedAggregationsStr) ++ + Map("PushedGroupBy" -> pushedGroupByStr) } - - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala index 4b3f4e7edca6c..d198321eacdb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2.parquet import scala.collection.JavaConverters._ import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters} -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex +import org.apache.spark.sql.connector.expressions.aggregate.Aggregation +import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates} +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex} import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter} import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy @@ -35,7 +36,8 @@ case class ParquetScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) - extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters { + extends FileScanBuilder(sparkSession, fileIndex, dataSchema) + with SupportsPushDownAggregates{ lazy val hadoopConf = { val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap // Hadoop Configurations are case sensitive. @@ -63,25 +65,50 @@ case class ParquetScanBuilder( // The rebase mode doesn't matter here because the filters are used to determine // whether they is convertible. 
LegacyBehaviorPolicy.CORRECTED) - parquetFilters.convertibleFilters(this.filters).toArray + parquetFilters.convertibleFilters(pushedDataFilters).toArray } - override protected val supportsNestedSchemaPruning: Boolean = true + private var finalSchema = new StructType() - private var filters: Array[Filter] = Array.empty + private var pushedAggregations = Option.empty[Aggregation] - override def pushFilters(filters: Array[Filter]): Array[Filter] = { - this.filters = filters - this.filters - } + override protected val supportsNestedSchemaPruning: Boolean = true + + override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters // Note: for Parquet, the actual filter push down happens in [[ParquetPartitionReaderFactory]]. // It requires the Parquet physical schema to determine whether a filter is convertible. // All filters that can be converted to Parquet are pushed down. override def pushedFilters(): Array[Filter] = pushedParquetFilters + override def pushAggregation(aggregation: Aggregation): Boolean = { + if (!sparkSession.sessionState.conf.parquetAggregatePushDown) { + return false + } + + AggregatePushDownUtils.getSchemaForPushedAggregation( + aggregation, + schema, + partitionNameSet, + dataFilters) match { + + case Some(schema) => + finalSchema = schema + this.pushedAggregations = Some(aggregation) + true + case _ => false + } + } + override def build(): Scan = { - ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(), - readPartitionSchema(), pushedParquetFilters, options) + // the `finalSchema` is either pruned in pushAggregation (if aggregates are + // pushed down), or pruned in readDataSchema() (in regular column pruning). These + // two are mutual exclusive. + if (pushedAggregations.isEmpty) { + finalSchema = readDataSchema() + } + ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema, + readPartitionSchema(), pushedParquetFilters, options, pushedAggregations, + partitionFilters, dataFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala index e75de2c4a4079..c7b0fec34b4e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.connector.read.PartitionReaderFactory import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.text.TextOptions -import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan} +import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration @@ -33,6 +33,7 @@ import org.apache.spark.util.SerializableConfiguration case class TextScan( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, + dataSchema: StructType, readDataSchema: StructType, readPartitionSchema: StructType, options: CaseInsensitiveStringMap, @@ -71,10 +72,6 @@ case class TextScan( readPartitionSchema, textOptions) } - override def withFilters( - partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan = - this.copy(partitionFilters = partitionFilters, 
dataFilters = dataFilters) - override def equals(obj: Any): Boolean = obj match { case t: TextScan => super.equals(t) && options == t.options diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala index b2b518c12b01a..0ebb098bfc1df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala @@ -33,6 +33,7 @@ case class TextScanBuilder( extends FileScanBuilder(sparkSession, fileIndex, dataSchema) { override def build(): Scan = { - TextScan(sparkSession, fileIndex, readDataSchema(), readPartitionSchema(), options) + TextScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options, + partitionFilters, dataFilters) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index 0b394db5c8932..9bf25aa0d633f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{SQLException, Types} import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ private object DB2Dialect extends JdbcDialect { @@ -27,6 +30,37 @@ private object DB2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2") + // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARIANCE($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARIANCE_SAMP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => + assert(f.children().length == 2) + Some(s"COVARIANCE(${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => + assert(f.children().length == 2) + Some(s"COVARIANCE_SAMP(${f.children().head}, ${f.children().last})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, @@ -79,4 +113,28 @@ private object DB2Dialect extends JdbcDialect { val nullable = if (isNullable) "DROP NOT NULL" else "SET NOT NULL" 
s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable" } + + override def removeSchemaCommentQuery(schema: String): String = { + s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS ''" + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getSQLState match { + // https://www.ibm.com/docs/en/db2/11.5?topic=messages-sqlstate + case "42893" => throw NonEmptyNamespaceException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case _ => super.classifyException(message, e) + } + } + + override def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)} CASCADE" + } else { + s"DROP SCHEMA ${quoteIdentifier(schema)} RESTRICT" + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala index 020733aaee8c0..36c3c6be4a05c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types._ @@ -29,6 +30,27 @@ private object DerbyDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby") + // See https://db.apache.org/derby/docs/10.15/ref/index.html + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"VAR_POP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"VAR_SAMP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"STDDEV_POP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"STDDEV_SAMP(${f.children().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) Option(FloatType) else None @@ -47,7 +69,7 @@ private object DerbyDialect extends JdbcDialect { override def isCascadingTruncateTable(): Option[Boolean] = Some(false) - // See https://db.apache.org/derby/docs/10.5/ref/rrefsqljrenametablestatement.html + // See https://db.apache.org/derby/docs/10.15/ref/rrefsqljrenametablestatement.html override def renameTable(oldTable: String, newTable: String): String = { s"RENAME TABLE $oldTable TO $newTable" } @@ -57,4 +79,8 @@ private object DerbyDialect extends JdbcDialect { override def getTableCommentQuery(table: String, comment: String): String = { throw QueryExecutionErrors.commentOnTableUnsupportedError() } + + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 9c727957ffab8..6681aee778dbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -20,13 +20,76 @@ package org.apache.spark.sql.jdbc import java.sql.SQLException import java.util.Locale +import scala.util.control.NonFatal + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.connector.expressions.Expression +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} +import org.apache.spark.sql.errors.QueryCompilationErrors private object H2Dialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2") + class H2SQLBuilder extends JDBCSQLBuilder { + override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { + funcName match { + case "WIDTH_BUCKET" => + val functionInfo = super.visitSQLFunction(funcName, inputs) + throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo) + case _ => super.visitSQLFunction(funcName, inputs) + } + } + } + + override def compileExpression(expr: Expression): Option[String] = { + val h2SQLBuilder = new H2SQLBuilder() + try { + Some(h2SQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } + } + + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.children().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.children().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.children().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.children().head}, ${f.children().last})") + case _ => None + } + ) + } + override def classifyException(message: String, e: Throwable): AnalysisException = { if (e.isInstanceOf[SQLException]) { // Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index aa957113b5ca5..397942d7837db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -17,21 +17,30 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, Date, Timestamp} +import java.sql.{Connection, Date, Driver, Statement, Timestamp} import java.time.{Instant, LocalDate} +import java.util import scala.collection.mutable.ArrayBuilder +import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.TableChange._ +import org.apache.spark.sql.connector.catalog.index.TableIndex +import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum} +import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -94,6 +103,29 @@ abstract class JdbcDialect extends Serializable with Logging{ */ def getJDBCType(dt: DataType): Option[JdbcType] = None + /** + * Returns a factory for creating connections to the given JDBC URL. + * In general, creating a connection has nothing to do with JDBC partition id. + * But sometimes it is needed, such as a database with multiple shard nodes. + * @param options - JDBC options that contains url, table and other information. + * @return The factory method for creating JDBC connections with the RDD partition ID. -1 means + the connection is being created at the driver side. + * @throws IllegalArgumentException if the driver could not open a JDBC connection. + */ + @Since("3.3.0") + def createConnectionFactory(options: JDBCOptions): Int => Connection = { + val driverClass: String = options.driverClass + (partitionId: Int) => { + DriverRegistry.register(driverClass) + val driver: Driver = DriverRegistry.get(driverClass) + val connection = + ConnectionProvider.create(driver, options.parameters) + require(connection != null, + s"The driver could not open a JDBC connection. Check the URL: ${options.url}") + connection + } + } + /** * Quotes the identifier. This is used to put quotes around the identifier in case the column * name is a reserved keyword, or in case it contains characters that require quotes (e.g. space). 
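For illustration, a minimal sketch of a third-party dialect using the new `createConnectionFactory` hook to route each RDD partition to a shard. The `jdbc:sharded` URL scheme, the fixed shard count, and the `shard=` URL parameter are assumptions invented for the example, and unlike the default implementation in the patch (which goes through `DriverRegistry` and `ConnectionProvider`) the sketch opens connections directly with `DriverManager` to keep it short.

```scala
import java.sql.{Connection, DriverManager}
import java.util.Locale

import org.apache.spark.sql.execution.datasources.jdbc.JDBCOptions
import org.apache.spark.sql.jdbc.JdbcDialect

object ShardedDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:sharded")

  // Partition id -1 means the connection is created on the driver side; send it to shard 0.
  override def createConnectionFactory(options: JDBCOptions): Int => Connection = {
    (partitionId: Int) => {
      val shard = math.max(partitionId, 0) % 4 // assumed fixed shard count of 4
      DriverManager.getConnection(s"${options.url};shard=$shard") // assumed URL convention
    }
  }
}
```

Such a dialect would then be made visible to Spark with `JdbcDialects.registerDialect(ShardedDialect)`.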
@@ -189,6 +221,110 @@ abstract class JdbcDialect extends Serializable with Logging{ case _ => value } + class JDBCSQLBuilder extends V2ExpressionSQLBuilder { + override def visitLiteral(literal: Literal[_]): String = { + compileValue( + CatalystTypeConverters.convertToScala(literal.value(), literal.dataType())).toString + } + + override def visitNamedReference(namedRef: NamedReference): String = { + if (namedRef.fieldNames().length > 1) { + throw QueryCompilationErrors.commandNotSupportNestedColumnError( + "Filter push down", namedRef.toString) + } + quoteIdentifier(namedRef.fieldNames.head) + } + + override def visitCast(l: String, dataType: DataType): String = { + val databaseTypeDefinition = + getJDBCType(dataType).map(_.databaseTypeDefinition).getOrElse(dataType.typeName) + s"CAST($l AS $databaseTypeDefinition)" + } + } + + /** + * Converts V2 expression to String representing a SQL expression. + * @param expr The V2 expression to be converted. + * @return Converted value. + */ + @Since("3.3.0") + def compileExpression(expr: Expression): Option[String] = { + val jdbcSQLBuilder = new JDBCSQLBuilder() + try { + Some(jdbcSQLBuilder.build(expr)) + } catch { + case NonFatal(e) => + logWarning("Error occurs while compiling V2 expression", e) + None + } + } + + /** + * Converts aggregate function to String representing a SQL expression. + * @param aggFunction The aggregate function to be converted. + * @return Converted value. + */ + @Since("3.3.0") + def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + aggFunction match { + case min: Min => + compileExpression(min.column).map(v => s"MIN($v)") + case max: Max => + compileExpression(max.column).map(v => s"MAX($v)") + case count: Count => + val distinct = if (count.isDistinct) "DISTINCT " else "" + compileExpression(count.column).map(v => s"COUNT($distinct$v)") + case sum: Sum => + val distinct = if (sum.isDistinct) "DISTINCT " else "" + compileExpression(sum.column).map(v => s"SUM($distinct$v)") + case _: CountStar => + Some("COUNT(*)") + case avg: Avg => + val distinct = if (avg.isDistinct) "DISTINCT " else "" + compileExpression(avg.column).map(v => s"AVG($distinct$v)") + case _ => None + } + } + + /** + * Create schema with an optional comment. Empty string means no comment. + */ + def createSchema(statement: Statement, schema: String, comment: String): Unit = { + val schemaCommentQuery = if (comment.nonEmpty) { + // We generate comment query here so that it can fail earlier without creating the schema. + getSchemaCommentQuery(schema, comment) + } else { + comment + } + statement.executeUpdate(s"CREATE SCHEMA ${quoteIdentifier(schema)}") + if (comment.nonEmpty) { + statement.executeUpdate(schemaCommentQuery) + } + } + + /** + * Check schema exists or not. + */ + def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { + val rs = conn.getMetaData.getSchemas(null, schema) + while (rs.next()) { + if (rs.getString(1) == schema) return true; + } + false + } + + /** + * Lists all the schemas in this table. + */ + def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { + val schemaBuilder = ArrayBuilder.make[Array[String]] + val rs = conn.getMetaData.getSchemas() + while (rs.next()) { + schemaBuilder += Array(rs.getString(1)) + } + schemaBuilder.result + } + /** * Return Some[true] iff `TRUNCATE TABLE` causes cascading default. * Some[true] : TRUNCATE TABLE causes cascading. 
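For illustration, a hypothetical dialect following the same `compileAggregate` pattern as the built-in ones below: the base implementation already covers MIN, MAX, COUNT, SUM, AVG and COUNT(*), so an override only adds spellings for the `GeneralAggregateFunc` names the vendor supports. The `jdbc:acme` scheme and the `VARIANCE_POP` function name are invented for the example.

```scala
import java.util.Locale

import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.jdbc.JdbcDialect

object AcmeDialect extends JdbcDialect {
  override def canHandle(url: String): Boolean =
    url.toLowerCase(Locale.ROOT).startsWith("jdbc:acme")

  override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
    // Fall back to the default MIN/MAX/COUNT/SUM/AVG translations first.
    super.compileAggregate(aggFunction).orElse(
      aggFunction match {
        case f: GeneralAggregateFunc if f.name() == "VAR_POP" && !f.isDistinct =>
          assert(f.children().length == 1)
          Some(s"VARIANCE_POP(${f.children().head})") // assumed vendor-specific spelling
        case _ => None
      }
    )
  }
}
```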
@@ -287,6 +423,71 @@ abstract class JdbcDialect extends Serializable with Logging{ s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS NULL" } + def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)} CASCADE" + } else { + s"DROP SCHEMA ${quoteIdentifier(schema)}" + } + } + + /** + * Build a create index SQL statement. + * + * @param indexName the name of the index to be created + * @param tableName the table on which index to be created + * @param columns the columns on which index to be created + * @param columnsProperties the properties of the columns on which index to be created + * @param properties the properties of the index to be created + * @return the SQL statement to use for creating the index. + */ + def createIndex( + indexName: String, + tableName: String, + columns: Array[NamedReference], + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String]): String = { + throw new UnsupportedOperationException("createIndex is not supported") + } + + /** + * Checks whether an index exists + * + * @param indexName the name of the index + * @param tableName the table name on which index to be checked + * @param options JDBCOptions of the table + * @return true if the index with `indexName` exists in the table with `tableName`, + * false otherwise + */ + def indexExists( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Boolean = { + throw new UnsupportedOperationException("indexExists is not supported") + } + + /** + * Build a drop index SQL statement. + * + * @param indexName the name of the index to be dropped. + * @param tableName the table name on which index to be dropped. + * @return the SQL statement to use for dropping the index. + */ + def dropIndex(indexName: String, tableName: String): String = { + throw new UnsupportedOperationException("dropIndex is not supported") + } + + /** + * Lists all the indexes in this table. + */ + def listIndexes( + conn: Connection, + tableName: String, + options: JDBCOptions): Array[TableIndex] = { + throw new UnsupportedOperationException("listIndexes is not supported") + } + /** * Gets a dialect exception, classifies it and wraps it by `AnalysisException`. * @param message The error message to be placed to the returned exception. 
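For illustration, the per-dialect index DDL added here can be inspected through the dialect registry. The index and table names below are made up, and the expected strings follow the MySQL and PostgreSQL overrides that appear later in this patch.

```scala
import org.apache.spark.sql.jdbc.JdbcDialects

// The same logical DROP INDEX is rendered differently per vendor, which is why
// the DDL generation lives on JdbcDialect rather than in the JDBC catalog itself.
val mysql    = JdbcDialects.get("jdbc:mysql://localhost/db")
val postgres = JdbcDialects.get("jdbc:postgresql://localhost/db")

val mysqlSql    = mysql.dropIndex("idx_customer_id", "orders")    // DROP INDEX `idx_customer_id` ON orders
val postgresSql = postgres.dropIndex("idx_customer_id", "orders") // DROP INDEX "idx_customer_id"
```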
@@ -296,6 +497,18 @@ abstract class JdbcDialect extends Serializable with Logging{ def classifyException(message: String, e: Throwable): AnalysisException = { new AnalysisException(message, cause = Some(e)) } + + /** + * returns the LIMIT clause for the SELECT statement + */ + def getLimitClause(limit: Integer): String = { + if (limit > 0 ) s"LIMIT $limit" else "" + } + + def supportsTableSample: Boolean = false + + def getTableSample(sample: TableSampleInfo): String = + throw new UnsupportedOperationException("TableSample is not supported by this data source") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala index ea9834830e373..8d2fbec55f919 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.jdbc +import java.sql.SQLException import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -36,6 +40,33 @@ private object MsSqlServerDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver") + // scalastyle:off line.size.limit + // See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15 + // scalastyle:on line.size.limit + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VARP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDEVP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDEV($distinct${f.children().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (typeName.contains("datetimeoffset")) { @@ -118,4 +149,19 @@ private object MsSqlServerDialect extends JdbcDialect { override def getTableCommentQuery(table: String, comment: String): String = { throw QueryExecutionErrors.commentOnTableUnsupportedError() } + + override def getLimitClause(limit: Integer): String = { + "" + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getErrorCode match { + case 3729 => throw NonEmptyNamespaceException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case _ => super.classifyException(message, e) + } + } } 
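For illustration, the effect of `getLimitClause` can be seen by comparing dialects; the URLs below are placeholders. Dialects whose SQL has no trailing LIMIT syntax (MS SQL Server, Derby and Teradata in this patch) return an empty string, so no LIMIT clause is emitted in the generated query for them.

```scala
import org.apache.spark.sql.jdbc.JdbcDialects

val postgres  = JdbcDialects.get("jdbc:postgresql://localhost/db")
val sqlServer = JdbcDialects.get("jdbc:sqlserver://localhost;databaseName=db")

val pgLimit  = postgres.getLimitClause(5)   // "LIMIT 5" (inherited default)
val mssLimit = sqlServer.getLimitClause(5)  // ""        (no trailing LIMIT for this dialect)
```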
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index ed107707c9d1f..24f9bac74f86d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -17,18 +17,48 @@ package org.apache.spark.sql.jdbc -import java.sql.Types +import java.sql.{Connection, SQLException, Types} +import java.util import java.util.Locale +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException} +import org.apache.spark.sql.connector.catalog.index.TableIndex +import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder} -private case object MySQLDialect extends JdbcDialect { +private case object MySQLDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url : String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql") + // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"VAR_POP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"VAR_SAMP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"STDDEV_POP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"STDDEV_SAMP(${f.children().head})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) { @@ -45,6 +75,25 @@ private case object MySQLDialect extends JdbcDialect { s"`$colName`" } + override def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = { + listSchemas(conn, options).exists(_.head == schema) + } + + override def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = { + val schemaBuilder = ArrayBuilder.make[Array[String]] + try { + JdbcUtils.executeQuery(conn, options, "SHOW SCHEMAS") { rs => + while (rs.next()) { + schemaBuilder += Array(rs.getString("Database")) + } + } + } catch { + case _: Exception => + logWarning("Cannot show schemas.") + } + schemaBuilder.result + } + override def getTableExistsQuery(table: String): String = { s"SELECT 1 FROM $table LIMIT 1" } @@ -102,4 +151,107 @@ private case object MySQLDialect extends JdbcDialect { case FloatType => Option(JdbcType("FLOAT", java.sql.Types.FLOAT)) case _ => 
JdbcUtils.getCommonJDBCType(dt) } + + override def getSchemaCommentQuery(schema: String, comment: String): String = { + throw QueryExecutionErrors.unsupportedCreateNamespaceCommentError() + } + + override def removeSchemaCommentQuery(schema: String): String = { + throw QueryExecutionErrors.unsupportedRemoveNamespaceCommentError() + } + + // CREATE INDEX syntax + // https://dev.mysql.com/doc/refman/8.0/en/create-index.html + override def createIndex( + indexName: String, + tableName: String, + columns: Array[NamedReference], + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String]): String = { + val columnList = columns.map(col => quoteIdentifier(col.fieldNames.head)) + val (indexType, indexPropertyList) = JdbcUtils.processIndexProperties(properties, "mysql") + + // columnsProperties doesn't apply to MySQL so it is ignored + s"CREATE INDEX ${quoteIdentifier(indexName)} $indexType ON" + + s" ${quoteIdentifier(tableName)} (${columnList.mkString(", ")})" + + s" ${indexPropertyList.mkString(" ")}" + } + + // SHOW INDEX syntax + // https://dev.mysql.com/doc/refman/8.0/en/show-index.html + override def indexExists( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Boolean = { + val sql = s"SHOW INDEXES FROM ${quoteIdentifier(tableName)} WHERE key_name = '$indexName'" + JdbcUtils.checkIfIndexExists(conn, sql, options) + } + + override def dropIndex(indexName: String, tableName: String): String = { + s"DROP INDEX ${quoteIdentifier(indexName)} ON $tableName" + } + + // SHOW INDEX syntax + // https://dev.mysql.com/doc/refman/8.0/en/show-index.html + override def listIndexes( + conn: Connection, + tableName: String, + options: JDBCOptions): Array[TableIndex] = { + val sql = s"SHOW INDEXES FROM $tableName" + var indexMap: Map[String, TableIndex] = Map() + try { + JdbcUtils.executeQuery(conn, options, sql) { rs => + while (rs.next()) { + val indexName = rs.getString("key_name") + val colName = rs.getString("column_name") + val indexType = rs.getString("index_type") + val indexComment = rs.getString("Index_comment") + if (indexMap.contains(indexName)) { + val index = indexMap.get(indexName).get + val newIndex = new TableIndex(indexName, indexType, + index.columns() :+ FieldReference(colName), + index.columnProperties, index.properties) + indexMap += (indexName -> newIndex) + } else { + // The only property we are building here is `COMMENT` because it's the only one + // we can get from `SHOW INDEXES`. 
+ val properties = new util.Properties(); + if (indexComment.nonEmpty) properties.put("COMMENT", indexComment) + val index = new TableIndex(indexName, indexType, Array(FieldReference(colName)), + new util.HashMap[NamedReference, util.Properties](), properties) + indexMap += (indexName -> index) + } + } + } + } catch { + case _: Exception => + logWarning("Cannot retrieved index info.") + } + indexMap.values.toArray + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getErrorCode match { + // ER_DUP_KEYNAME + case 1061 => + throw new IndexAlreadyExistsException(message, cause = Some(e)) + case 1091 => + throw new NoSuchIndexException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case unsupported: UnsupportedOperationException => throw unsupported + case _ => super.classifyException(message, e) + } + } + + override def dropSchema(schema: String, cascade: Boolean): String = { + if (cascade) { + s"DROP SCHEMA ${quoteIdentifier(schema)}" + } else { + throw QueryExecutionErrors.unsupportedDropNamespaceRestrictError() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index b741ece8dda9b..40333c1757c4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp, Types} import java.util.{Locale, TimeZone} import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -33,6 +34,38 @@ private case object OracleDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle") + // scalastyle:off line.size.limit + // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848 + // scalastyle:on line.size.limit + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"VAR_POP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"VAR_SAMP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"STDDEV_POP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false => + assert(f.children().length == 1) + Some(s"STDDEV_SAMP(${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => + assert(f.children().length == 2) + Some(s"COVAR_POP(${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => + assert(f.children().length == 2) + Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => + assert(f.children().length == 2) + 
Some(s"CORR(${f.children().head}, ${f.children().last})") + case _ => None + } + ) + } + private def supportTimeZoneTypes: Boolean = { val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone) // TODO: support timezone types when users are not using the JVM timezone, which diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 3ce785ed844c5..a668d66ee2f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -17,18 +17,62 @@ package org.apache.spark.sql.jdbc -import java.sql.{Connection, Types} +import java.sql.{Connection, SQLException, Types} +import java.util import java.util.Locale +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper +import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException} +import org.apache.spark.sql.connector.expressions.NamedReference +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils} +import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo import org.apache.spark.sql.types._ -private object PostgresDialect extends JdbcDialect { +private object PostgresDialect extends JdbcDialect with SQLConfHelper { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql") + // See https://www.postgresql.org/docs/8.4/functions-aggregate.html + override def compileAggregate(aggFunction: AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" => + assert(f.children().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" => + assert(f.children().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" => + assert(f.children().length == 2) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"CORR($distinct${f.children().head}, ${f.children().last})") + case _ => None + } + ) + } + override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { if (sqlType == Types.REAL) { @@ -154,4 +198,66 @@ private object 
PostgresDialect extends JdbcDialect { val nullable = if (isNullable) "DROP NOT NULL" else "SET NOT NULL" s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable" } + + override def supportsTableSample: Boolean = true + + override def getTableSample(sample: TableSampleInfo): String = { + // hard-coded to BERNOULLI for now because Spark doesn't have a way to specify sample + // method name + s"TABLESAMPLE BERNOULLI" + + s" (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE (${sample.seed})" + } + + // CREATE INDEX syntax + // https://www.postgresql.org/docs/14/sql-createindex.html + override def createIndex( + indexName: String, + tableName: String, + columns: Array[NamedReference], + columnsProperties: util.Map[NamedReference, util.Map[String, String]], + properties: util.Map[String, String]): String = { + val columnList = columns.map(col => quoteIdentifier(col.fieldNames.head)) + var indexProperties = "" + val (indexType, indexPropertyList) = JdbcUtils.processIndexProperties(properties, "postgresql") + + if (indexPropertyList.nonEmpty) { + indexProperties = "WITH (" + indexPropertyList.mkString(", ") + ")" + } + + s"CREATE INDEX ${quoteIdentifier(indexName)} ON ${quoteIdentifier(tableName)}" + + s" $indexType (${columnList.mkString(", ")}) $indexProperties" + } + + // SHOW INDEX syntax + // https://www.postgresql.org/docs/14/view-pg-indexes.html + override def indexExists( + conn: Connection, + indexName: String, + tableName: String, + options: JDBCOptions): Boolean = { + val sql = s"SELECT * FROM pg_indexes WHERE tablename = '$tableName' AND" + + s" indexname = '$indexName'" + JdbcUtils.checkIfIndexExists(conn, sql, options) + } + + // DROP INDEX syntax + // https://www.postgresql.org/docs/14/sql-dropindex.html + override def dropIndex(indexName: String, tableName: String): String = { + s"DROP INDEX ${quoteIdentifier(indexName)}" + } + + override def classifyException(message: String, e: Throwable): AnalysisException = { + e match { + case sqlException: SQLException => + sqlException.getSQLState match { + // https://www.postgresql.org/docs/14/errcodes-appendix.html + case "42P07" => throw new IndexAlreadyExistsException(message, cause = Some(e)) + case "42704" => throw new NoSuchIndexException(message, cause = Some(e)) + case "2BP01" => throw NonEmptyNamespaceException(message, cause = Some(e)) + case _ => super.classifyException(message, e) + } + case unsupported: UnsupportedOperationException => throw unsupported + case _ => super.classifyException(message, e) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala index 58fe62cb6e088..79fb710cf03b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc import java.util.Locale +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc} import org.apache.spark.sql.types._ @@ -27,6 +28,42 @@ private case object TeradataDialect extends JdbcDialect { override def canHandle(url: String): Boolean = url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata") + // scalastyle:off line.size.limit + // See https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/Aggregate-Functions + // scalastyle:on line.size.limit + override def compileAggregate(aggFunction: 
AggregateFunc): Option[String] = { + super.compileAggregate(aggFunction).orElse( + aggFunction match { + case f: GeneralAggregateFunc if f.name() == "VAR_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_POP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"VAR_SAMP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_POP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" => + assert(f.children().length == 1) + val distinct = if (f.isDistinct) "DISTINCT " else "" + Some(s"STDDEV_SAMP($distinct${f.children().head})") + case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false => + assert(f.children().length == 2) + Some(s"COVAR_POP(${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false => + assert(f.children().length == 2) + Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})") + case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false => + assert(f.children().length == 2) + Some(s"CORR(${f.children().head}, ${f.children().last})") + case _ => None + } + ) + } + override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR)) case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR)) @@ -55,4 +92,8 @@ private case object TeradataDialect extends JdbcDialect { override def renameTable(oldTable: String, newTable: String): String = { s"RENAME TABLE $oldTable TO $newTable" } + + override def getLimitClause(limit: Integer): String = { + "" + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java new file mode 100644 index 0000000000000..ec532da61042f --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.connector; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.connector.TestingV2Source; +import org.apache.spark.sql.connector.catalog.Table; +import org.apache.spark.sql.connector.expressions.FieldReference; +import org.apache.spark.sql.connector.expressions.Literal; +import org.apache.spark.sql.connector.expressions.LiteralValue; +import org.apache.spark.sql.connector.expressions.filter.Predicate; +import org.apache.spark.sql.connector.read.*; +import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +public class JavaAdvancedDataSourceV2WithV2Filter implements TestingV2Source { + + @Override + public Table getTable(CaseInsensitiveStringMap options) { + return new JavaSimpleBatchTable() { + @Override + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { + return new AdvancedScanBuilderWithV2Filter(); + } + }; + } + + static class AdvancedScanBuilderWithV2Filter implements ScanBuilder, Scan, + SupportsPushDownV2Filters, SupportsPushDownRequiredColumns { + + private StructType requiredSchema = TestingV2Source.schema(); + private Predicate[] predicates = new Predicate[0]; + + @Override + public void pruneColumns(StructType requiredSchema) { + this.requiredSchema = requiredSchema; + } + + @Override + public StructType readSchema() { + return requiredSchema; + } + + @Override + public Predicate[] pushPredicates(Predicate[] predicates) { + Predicate[] supported = Arrays.stream(predicates).filter(f -> { + if (f.name().equals(">")) { + assert(f.children()[0] instanceof FieldReference); + FieldReference column = (FieldReference) f.children()[0]; + assert(f.children()[1] instanceof LiteralValue); + Literal value = (Literal) f.children()[1]; + return column.describe().equals("i") && value.value() instanceof Integer; + } else { + return false; + } + }).toArray(Predicate[]::new); + + Predicate[] unsupported = Arrays.stream(predicates).filter(f -> { + if (f.name().equals(">")) { + assert(f.children()[0] instanceof FieldReference); + FieldReference column = (FieldReference) f.children()[0]; + assert(f.children()[1] instanceof LiteralValue); + Literal value = (LiteralValue) f.children()[1]; + return !column.describe().equals("i") || !(value.value() instanceof Integer); + } else { + return true; + } + }).toArray(Predicate[]::new); + + this.predicates = supported; + return unsupported; + } + + @Override + public Predicate[] pushedPredicates() { + return predicates; + } + + @Override + public Scan build() { + return this; + } + + @Override + public Batch toBatch() { + return new AdvancedBatchWithV2Filter(requiredSchema, predicates); + } + } + + public static class AdvancedBatchWithV2Filter implements Batch { + // Exposed for testing. 
+ public StructType requiredSchema; + public Predicate[] predicates; + + AdvancedBatchWithV2Filter(StructType requiredSchema, Predicate[] predicates) { + this.requiredSchema = requiredSchema; + this.predicates = predicates; + } + + @Override + public InputPartition[] planInputPartitions() { + List res = new ArrayList<>(); + + Integer lowerBound = null; + for (Predicate predicate : predicates) { + if (predicate.name().equals(">")) { + assert(predicate.children()[0] instanceof FieldReference); + FieldReference column = (FieldReference) predicate.children()[0]; + assert(predicate.children()[1] instanceof LiteralValue); + Literal value = (Literal) predicate.children()[1]; + if ("i".equals(column.describe()) && value.value() instanceof Integer) { + lowerBound = (Integer) value.value(); + break; + } + } + } + + if (lowerBound == null) { + res.add(new JavaRangeInputPartition(0, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 4) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 5)); + res.add(new JavaRangeInputPartition(5, 10)); + } else if (lowerBound < 9) { + res.add(new JavaRangeInputPartition(lowerBound + 1, 10)); + } + + return res.stream().toArray(InputPartition[]::new); + } + + @Override + public PartitionReaderFactory createReaderFactory() { + return new AdvancedReaderFactoryWithV2Filter(requiredSchema); + } + } + + static class AdvancedReaderFactoryWithV2Filter implements PartitionReaderFactory { + StructType requiredSchema; + + AdvancedReaderFactoryWithV2Filter(StructType requiredSchema) { + this.requiredSchema = requiredSchema; + } + + @Override + public PartitionReader createReader(InputPartition partition) { + JavaRangeInputPartition p = (JavaRangeInputPartition) partition; + return new PartitionReader() { + private int current = p.start - 1; + + @Override + public boolean next() throws IOException { + current += 1; + return current < p.end; + } + + @Override + public InternalRow get() { + Object[] values = new Object[requiredSchema.size()]; + for (int i = 0; i < values.length; i++) { + if ("i".equals(requiredSchema.apply(i).name())) { + values[i] = current; + } else if ("j".equals(requiredSchema.apply(i).name())) { + values[i] = -current; + } + } + return new GenericInternalRow(values); + } + + @Override + public void close() throws IOException { + + } + }; + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 001b6a00af52f..910f159cc49a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -731,6 +731,28 @@ class FileBasedDataSourceSuite extends QueryTest } } + test("SPARK-36568: FileScan statistics estimation takes read schema into account") { + withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { + withTempDir { dir => + spark.range(1000).map(x => (x / 100, x, x)).toDF("k", "v1", "v2"). 
+ write.partitionBy("k").mode(SaveMode.Overwrite).orc(dir.toString) + val dfAll = spark.read.orc(dir.toString) + val dfK = dfAll.select("k") + val dfV1 = dfAll.select("v1") + val dfV2 = dfAll.select("v2") + val dfV1V2 = dfAll.select("v1", "v2") + + def sizeInBytes(df: DataFrame): BigInt = df.queryExecution.optimizedPlan.stats.sizeInBytes + + assert(sizeInBytes(dfAll) === BigInt(getLocalDirSize(dir))) + assert(sizeInBytes(dfK) < sizeInBytes(dfAll)) + assert(sizeInBytes(dfV1) < sizeInBytes(dfAll)) + assert(sizeInBytes(dfV2) === sizeInBytes(dfV1)) + assert(sizeInBytes(dfV1V2) < sizeInBytes(dfAll)) + } + } + } + test("File source v2: support partition pruning") { withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { allFileBasedDataSources.foreach { format => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala index 4e7fe8455ff93..14b59ba23d09f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala @@ -354,11 +354,11 @@ class FileScanSuite extends FileScanSuiteBase { val scanBuilders = Seq[(String, ScanBuilder, Seq[String])]( ("ParquetScan", (s, fi, ds, rds, rps, f, o, pf, df) => - ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df), + ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df), Seq.empty), ("OrcScan", (s, fi, ds, rds, rps, f, o, pf, df) => - OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df), + OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, None, f, pf, df), Seq.empty), ("CSVScan", (s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df), @@ -367,7 +367,7 @@ class FileScanSuite extends FileScanSuiteBase { (s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, f, pf, df), Seq.empty), ("TextScan", - (s, fi, _, rds, rps, _, o, pf, df) => TextScan(s, fi, rds, rps, o, pf, df), + (s, fi, ds, rds, rps, _, o, pf, df) => TextScan(s, fi, ds, rds, rps, o, pf, df), Seq("dataSchema", "pushedFilters"))) run(scanBuilders) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala index 91ac7db335cc3..e9c8131fe9bec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode} @@ -97,7 +95,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): InMemoryTable = { + properties: java.util.Map[String, String]): InMemoryTable = { new InMemoryTable(name, schema, partitions, properties) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala index d5417be0f229f..e4ba33c619a7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector -import java.util import java.util.Collections import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen} @@ -35,7 +34,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class DataSourceV2FunctionSuite extends DatasourceV2SQLBase { - private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String] private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = { catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index a326b82dbaf1e..7b941ab0d8f7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -1609,6 +1609,24 @@ class DataSourceV2SQLSuite } } + test("create table using - with sorted bucket") { + val identifier = "testcat.table_name" + withTable(identifier) { + sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" + + s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS") + val table = getTableMetadata(identifier) + val describe = spark.sql(s"DESCRIBE $identifier") + val part1 = describe + .filter("col_name = 'Part 0'") + .select("data_type").head.getString(0) + assert(part1 === "c") + val part2 = describe + .filter("col_name = 'Part 1'") + .select("data_type").head.getString(0) + assert(part2 === "bucket(4, b, a)") + } + } + test("REFRESH TABLE: v2 table") { val t = "testcat.ns1.ns2.tbl" withTable(t) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala index b42d48d873fee..cff58d7367317 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala @@ -18,11 +18,8 @@ package org.apache.spark.sql.connector import java.io.File -import java.util import java.util.OptionalLong -import scala.collection.JavaConverters._ - import test.org.apache.spark.sql.connector._ import org.apache.spark.SparkException @@ -30,7 +27,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ -import org.apache.spark.sql.connector.expressions.Transform +import org.apache.spark.sql.connector.expressions.{Literal, Transform} +import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.read._ import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper @@ -54,6 +52,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS }.head } + private def getBatchWithV2Filter(query: DataFrame): AdvancedBatchWithV2Filter = { + 
query.queryExecution.executedPlan.collect { + case d: BatchScanExec => + d.batch.asInstanceOf[AdvancedBatchWithV2Filter] + }.head + } + private def getJavaBatch(query: DataFrame): JavaAdvancedDataSourceV2.AdvancedBatch = { query.queryExecution.executedPlan.collect { case d: BatchScanExec => @@ -61,6 +66,14 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS }.head } + private def getJavaBatchWithV2Filter( + query: DataFrame): JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter = { + query.queryExecution.executedPlan.collect { + case d: BatchScanExec => + d.batch.asInstanceOf[JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter] + }.head + } + test("simplest implementation") { Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -131,6 +144,66 @@ } } + test("advanced implementation with V2 Filter") { + Seq(classOf[AdvancedDataSourceV2WithV2Filter], classOf[JavaAdvancedDataSourceV2WithV2Filter]) + .foreach { cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + checkAnswer(df, (0 until 10).map(i => Row(i, -i))) + + val q1 = df.select('j) + checkAnswer(q1, (0 until 10).map(i => Row(-i))) + if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { + val batch = getBatchWithV2Filter(q1) + assert(batch.predicates.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) + } else { + val batch = getJavaBatchWithV2Filter(q1) + assert(batch.predicates.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) + } + + val q2 = df.filter('i > 3) + checkAnswer(q2, (4 until 10).map(i => Row(i, -i))) + if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { + val batch = getBatchWithV2Filter(q2) + assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i", "j")) + } else { + val batch = getJavaBatchWithV2Filter(q2) + assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i", "j")) + } + + val q3 = df.select('i).filter('i > 6) + checkAnswer(q3, (7 until 10).map(i => Row(i))) + if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { + val batch = getBatchWithV2Filter(q3) + assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i")) + } else { + val batch = getJavaBatchWithV2Filter(q3) + assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i")) + assert(batch.requiredSchema.fieldNames === Seq("i")) + } + + val q4 = df.select('j).filter('j < -10) + checkAnswer(q4, Nil) + if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) { + val batch = getBatchWithV2Filter(q4) + // 'j < -10 is not supported by the testing data source. + assert(batch.predicates.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) + } else { + val batch = getJavaBatchWithV2Filter(q4) + // 'j < -10 is not supported by the testing data source. 
+ assert(batch.predicates.isEmpty) + assert(batch.requiredSchema.fieldNames === Seq("j")) + } + } + } + } + test("columnar batch scan implementation") { Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls => withClue(cls.getName) { @@ -466,7 +539,7 @@ abstract class SimpleBatchTable extends Table with SupportsRead { override def name(): String = this.getClass.toString - override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava + override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ) } abstract class SimpleScanBuilder extends ScanBuilder @@ -489,7 +562,7 @@ trait TestingV2Source extends TableProvider { override def getTable( schema: StructType, partitioning: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { getTable(new CaseInsensitiveStringMap(properties)) } @@ -597,6 +670,75 @@ class AdvancedBatch(val filters: Array[Filter], val requiredSchema: StructType) } } +class AdvancedDataSourceV2WithV2Filter extends TestingV2Source { + + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new AdvancedScanBuilderWithV2Filter() + } + } +} + +class AdvancedScanBuilderWithV2Filter extends ScanBuilder + with Scan with SupportsPushDownV2Filters with SupportsPushDownRequiredColumns { + + var requiredSchema = TestingV2Source.schema + var predicates = Array.empty[Predicate] + + override def pruneColumns(requiredSchema: StructType): Unit = { + this.requiredSchema = requiredSchema + } + + override def readSchema(): StructType = requiredSchema + + override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = { + val (supported, unsupported) = predicates.partition { + case p: Predicate if p.name() == ">" => true + case _ => false + } + this.predicates = supported + unsupported + } + + override def pushedPredicates(): Array[Predicate] = predicates + + override def build(): Scan = this + + override def toBatch: Batch = new AdvancedBatchWithV2Filter(predicates, requiredSchema) +} + +class AdvancedBatchWithV2Filter( + val predicates: Array[Predicate], + val requiredSchema: StructType) extends Batch { + + override def planInputPartitions(): Array[InputPartition] = { + val lowerBound = predicates.collectFirst { + case p: Predicate if p.name().equals(">") => + val value = p.children()(1) + assert(value.isInstanceOf[Literal[_]]) + value.asInstanceOf[Literal[_]] + } + + val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition] + + if (lowerBound.isEmpty) { + res.append(RangeInputPartition(0, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get.value.asInstanceOf[Integer] < 4) { + res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 5)) + res.append(RangeInputPartition(5, 10)) + } else if (lowerBound.get.value.asInstanceOf[Integer] < 9) { + res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 10)) + } + + res.toArray + } + + override def createReaderFactory(): PartitionReaderFactory = { + new AdvancedReaderFactory(requiredSchema) + } +} + class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory { override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { @@ -640,7 +782,7 @@ class SchemaRequiredDataSource extends TableProvider { override def getTable( schema: StructType, 
partitioning: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala index db71eeb75eae0..e3d61a846fdb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.connector -import java.util - -import scala.collection.JavaConverters._ - import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} @@ -63,7 +59,7 @@ class TestLocalScanCatalog extends BasicInMemoryTableCatalog { ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val table = new TestLocalScanTable(ident.toString) tables.put(ident, table) table @@ -78,7 +74,8 @@ object TestLocalScanTable { class TestLocalScanTable(override val name: String) extends Table with SupportsRead { override def schema(): StructType = TestLocalScanTable.schema - override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(TableCapability.BATCH_READ) override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new TestLocalScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala index bb2acecc782b2..64c893ed74fdb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.connector import java.io.{BufferedReader, InputStreamReader, IOException} -import java.util import scala.collection.JavaConverters._ @@ -138,8 +137,8 @@ class SimpleWritableDataSource extends TestingV2Source { new MyWriteBuilder(path, info) } - override def capabilities(): util.Set[TableCapability] = - Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava + override def capabilities(): java.util.Set[TableCapability] = + java.util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE) } override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala index ce94d3b5c2fc0..5f2e0b28aeccc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.connector -import java.util - -import scala.collection.JavaConverters._ - import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} import 
org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} @@ -217,7 +213,11 @@ private case object TestRelation extends LeafNode with NamedRelation { private case class CapabilityTable(_capabilities: TableCapability*) extends Table { override def name(): String = "capability_test_table" override def schema(): StructType = TableCapabilityCheckSuite.schema - override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava + override def capabilities(): java.util.Set[TableCapability] = { + val set = java.util.EnumSet.noneOf(classOf[TableCapability]) + _capabilities.foreach(set.add) + set + } } private class TestStreamSourceProvider extends StreamSourceProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala index bf2749d1afc53..0a0aaa8021996 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.connector -import java.util import java.util.concurrent.ConcurrentHashMap import java.util.concurrent.atomic.AtomicBoolean @@ -35,7 +34,7 @@ import org.apache.spark.sql.types.StructType */ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension { - protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() + protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]() private val tableCreated: AtomicBoolean = new AtomicBoolean(false) @@ -48,7 +47,7 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): T + properties: java.util.Map[String, String]): T override def loadTable(ident: Identifier): Table = { if (tables.containsKey(ident)) { @@ -69,12 +68,12 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY val propsWithLocation = if (properties.containsKey(key)) { // Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified. 
if (!properties.containsKey(TableCatalog.PROP_LOCATION)) { - val newProps = new util.HashMap[String, String]() + val newProps = new java.util.HashMap[String, String]() newProps.putAll(properties) newProps.put(TableCatalog.PROP_LOCATION, "file:/abc") newProps diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala index 847953e09cef7..c5be222645b19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql.connector -import java.util - -import scala.collection.JavaConverters._ - import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext} import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability} @@ -106,7 +102,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog { ident: Identifier, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): Table = { + properties: java.util.Map[String, String]): Table = { // To simplify the test implementation, only support fixed schema. if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) { throw new UnsupportedOperationException @@ -131,8 +127,8 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp override def schema(): StructType = V1ReadFallbackCatalog.schema - override def capabilities(): util.Set[TableCapability] = { - Set(TableCapability.BATCH_READ).asJava + override def capabilities(): java.util.Set[TableCapability] = { + java.util.EnumSet.of(TableCapability.BATCH_READ) } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala index 7effc747ab323..992c46cc6cdb1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.connector -import java.util - import scala.collection.JavaConverters._ import scala.collection.mutable @@ -223,7 +221,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV name: String, schema: StructType, partitions: Array[Transform], - properties: util.Map[String, String]): InMemoryTableWithV1Fallback = { + properties: java.util.Map[String, String]): InMemoryTableWithV1Fallback = { val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties) InMemoryV1Provider.tables.put(name, t) tables.put(Identifier.of(Array("default"), name), t) @@ -321,7 +319,7 @@ class InMemoryTableWithV1Fallback( override val name: String, override val schema: StructType, override val partitioning: Array[Transform], - override val properties: util.Map[String, String]) + override val properties: java.util.Map[String, String]) extends Table with SupportsWrite with SupportsRead { @@ -331,11 +329,11 @@ class InMemoryTableWithV1Fallback( } } - override def capabilities: util.Set[TableCapability] = Set( + override def capabilities: java.util.Set[TableCapability] = java.util.EnumSet.of( TableCapability.BATCH_READ, TableCapability.V1_BATCH_WRITE, 
TableCapability.OVERWRITE_BY_FILTER, - TableCapability.TRUNCATE).asJava + TableCapability.TRUNCATE) @volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala new file mode 100644 index 0000000000000..c787493fbdcc1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala @@ -0,0 +1,617 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.sql.{Date, Timestamp} + +import org.apache.spark.SparkConf +import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.sql.execution.datasources.orc.OrcTest +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.functions.min +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType} + +/** + * A test suite that tests aggregate push down for Parquet and ORC. + */ +trait FileSourceAggregatePushDownSuite + extends QueryTest + with FileBasedDataSourceTest + with SharedSparkSession + with ExplainSuiteHelper { + + import testImplicits._ + + protected def format: String + // The SQL config key for enabling aggregate push down. 
+ protected val aggPushDownEnabledKey: String + + test("nested column: Max(top level column) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { + val max = sql("SELECT Max(_1) FROM t") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + } + } + } + + test("nested column: Count(top level column) push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { + val count = sql("SELECT Count(_1) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(_1)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + + test("nested column: Max(nested sub-field) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf(aggPushDownEnabledKey-> "true") { + withDataSourceTable(data, "t") { + val max = sql("SELECT Max(_1._2[0]) FROM t") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + } + } + } + + test("nested column: Count(nested sub-field) not push down") { + val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i")))) + withSQLConf(aggPushDownEnabledKey -> "true") { + withDataSourceTable(data, "t") { + val count = sql("SELECT Count(_1._2[0]) FROM t") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + + test("Max(partition column): not push down") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + withSQLConf(aggPushDownEnabledKey -> "true") { + val max = sql("SELECT Max(p) FROM tmp") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + checkAnswer(max, Seq(Row(2))) + } + } + } + } + + test("Count(partition column): push down") { + withTempPath { dir => + spark.range(10).selectExpr("if(id % 2 = 0, null, id) AS n", "id % 3 as p") + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + val enableVectorizedReader = Seq("false", "true") + for (testVectorizedReader <- enableVectorizedReader) { + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> testVectorizedReader) { + val count = sql("SELECT COUNT(p) FROM tmp") + count.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [COUNT(p)]" + checkKeywordsExistsInExplain(count, expected_plan_fragment) + } + checkAnswer(count, Seq(Row(10))) + } + } + } + } + } + + test("filter alias over 
aggregate") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1), MAX(_1)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(7))) + } + } + } + + test("alias over aggregate") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-1, 0))) + } + } + } + + test("aggregate over alias push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + val df = spark.table("t") + val query = df.select($"_1".as("col1")).agg(min($"col1")) + query.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_1)]" + checkKeywordsExistsInExplain(query, expected_plan_fragment) + } + checkAnswer(query, Seq(Row(-2))) + } + } + } + + test("query with group by not push down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + // aggregate not pushed down if there is group by + val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3))) + } + } + } + + test("aggregate with data filter cannot be pushed down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + // aggregate not pushed down if there is filter + val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(2))) + } + } + } + + test("aggregate with partition filter can be pushed down") { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + Seq("false", "true").foreach { enableVectorizedReader => + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> enableVectorizedReader) { + val max = sql("SELECT max(id), min(id), 
count(id) FROM tmp WHERE p = 0") + max.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(id), MIN(id), COUNT(id)]" + checkKeywordsExistsInExplain(max, expected_plan_fragment) + } + checkAnswer(max, Seq(Row(9, 0, 4))) + } + } + } + } + } + + test("push down only if all the aggregates can be pushed down") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 7)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + // not push down since sum can't be pushed down + val selectAgg = sql("SELECT min(_1), sum(_3) FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(-2, 41))) + } + } + } + + test("aggregate push down - MIN/MAX/COUNT") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," + + " count(*), count(_1), count(_2), count(_3) FROM t") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(_3), " + + "MAX(_3), " + + "MIN(_1), " + + "MAX(_1), " + + "COUNT(*), " + + "COUNT(_1), " + + "COUNT(_2), " + + "COUNT(_3)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + + checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6))) + } + } + } + + test("aggregate not push down - MIN/MAX/COUNT with CASE WHEN") { + val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19), + (9, "mno", 7), (2, null, 6)) + withDataSourceTable(data, "t") { + withSQLConf(aggPushDownEnabledKey -> "true") { + val selectAgg = sql( + """ + |SELECT + | min(CASE WHEN _1 < 0 THEN 0 ELSE _1 END), + | min(CASE WHEN _3 > 5 THEN 1 ELSE 0 END), + | max(CASE WHEN _1 < 0 THEN 0 ELSE _1 END), + | max(CASE WHEN NOT(_3 > 5) THEN 1 ELSE 0 END), + | count(CASE WHEN _1 < 0 AND _2 IS NOT NULL THEN 0 ELSE _1 END), + | count(CASE WHEN _3 != 5 OR _2 IS NULL THEN 1 ELSE 0 END) + |FROM t + """.stripMargin) + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + + checkAnswer(selectAgg, Seq(Row(0, 0, 9, 1, 6, 6))) + } + } + } + + private def testPushDownForAllDataTypes( + inputRows: Seq[Row], + expectedMinWithAllTypes: Seq[Row], + expectedMinWithOutTSAndBinary: Seq[Row], + expectedMaxWithAllTypes: Seq[Row], + expectedMaxWithOutTSAndBinary: Seq[Row], + expectedCount: Seq[Row]): Unit = { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val schema = StructType(List(StructField("StringCol", StringType, true), + StructField("BooleanCol", BooleanType, false), + StructField("ByteCol", ByteType, false), + StructField("BinaryCol", BinaryType, false), + StructField("ShortCol", ShortType, false), + StructField("IntegerCol", IntegerType, true), + StructField("LongCol", LongType, false), + StructField("FloatCol", FloatType, false), + 
StructField("DoubleCol", DoubleType, false), + StructField("DecimalCol", DecimalType(25, 5), true), + StructField("DateCol", DateType, false), + StructField("TimestampCol", TimestampType, false)).toArray) + + val rdd = sparkContext.parallelize(inputRows) + withTempPath { file => + spark.createDataFrame(rdd, schema).write.format(format).save(file.getCanonicalPath) + withTempView("test") { + spark.read.format(format).load(file.getCanonicalPath).createOrReplaceTempView("test") + Seq("false", "true").foreach { enableVectorizedReader => + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> enableVectorizedReader) { + + val testMinWithAllTypes = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " + + "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + // In addition, Parquet Binary min/max could be truncated, so we disable aggregate + // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType). + // Also do not push down for ORC with same reason. + testMinWithAllTypes.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMinWithAllTypes, expected_plan_fragment) + } + + checkAnswer(testMinWithAllTypes, expectedMinWithAllTypes) + + val testMinWithOutTSAndBinary = sql("SELECT min(BooleanCol), min(ByteCol), " + + "min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " + + "min(DoubleCol), min(DateCol) FROM test") + + testMinWithOutTSAndBinary.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MIN(BooleanCol), " + + "MIN(ByteCol), " + + "MIN(ShortCol), " + + "MIN(IntegerCol), " + + "MIN(LongCol), " + + "MIN(FloatCol), " + + "MIN(DoubleCol), " + + "MIN(DateCol)]" + checkKeywordsExistsInExplain(testMinWithOutTSAndBinary, expected_plan_fragment) + } + + checkAnswer(testMinWithOutTSAndBinary, expectedMinWithOutTSAndBinary) + + val testMaxWithAllTypes = sql("SELECT max(StringCol), max(BooleanCol), " + + "max(ByteCol), max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), " + + "max(FloatCol), max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) " + + "FROM test") + + // INT96 (Timestamp) sort order is undefined, parquet doesn't return stats for this type + // so aggregates are not pushed down + // In addition, Parquet Binary min/max could be truncated, so we disable aggregate + // push down for Parquet Binary (could be Spark StringType, BinaryType or DecimalType). + // Also do not push down for ORC with same reason. 
+ testMaxWithAllTypes.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: []" + checkKeywordsExistsInExplain(testMaxWithAllTypes, expected_plan_fragment) + } + + checkAnswer(testMaxWithAllTypes, expectedMaxWithAllTypes) + + val testMaxWithoutTSAndBinary = sql("SELECT max(BooleanCol), max(ByteCol), " + + "max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " + + "max(DoubleCol), max(DateCol) FROM test") + + testMaxWithoutTSAndBinary.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(BooleanCol), " + + "MAX(ByteCol), " + + "MAX(ShortCol), " + + "MAX(IntegerCol), " + + "MAX(LongCol), " + + "MAX(FloatCol), " + + "MAX(DoubleCol), " + + "MAX(DateCol)]" + checkKeywordsExistsInExplain(testMaxWithoutTSAndBinary, expected_plan_fragment) + } + + checkAnswer(testMaxWithoutTSAndBinary, expectedMaxWithOutTSAndBinary) + + val testCount = sql("SELECT count(StringCol), count(BooleanCol)," + + " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," + + " count(LongCol), count(FloatCol), count(DoubleCol)," + + " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test") + + testCount.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [" + + "COUNT(StringCol), " + + "COUNT(BooleanCol), " + + "COUNT(ByteCol), " + + "COUNT(BinaryCol), " + + "COUNT(ShortCol), " + + "COUNT(IntegerCol), " + + "COUNT(LongCol), " + + "COUNT(FloatCol), " + + "COUNT(DoubleCol), " + + "COUNT(DecimalCol), " + + "COUNT(DateCol), " + + "COUNT(TimestampCol)]" + checkKeywordsExistsInExplain(testCount, expected_plan_fragment) + } + + checkAnswer(testCount, expectedCount) + } + } + } + } + } + + test("aggregate push down - different data types") { + implicit class StringToDate(s: String) { + def date: Date = Date.valueOf(s) + } + + implicit class StringToTs(s: String) { + def ts: Timestamp = Timestamp.valueOf(s) + } + + val rows = + Seq( + Row( + "a string", + true, + 10.toByte, + "Spark SQL".getBytes, + 12.toShort, + 3, + Long.MaxValue, + 0.15.toFloat, + 0.75D, + Decimal("12.345678"), + ("2021-01-01").date, + ("2015-01-01 23:50:59.123").ts), + Row( + "test string", + false, + 1.toByte, + "Parquet".getBytes, + 2.toShort, + null, + Long.MinValue, + 0.25.toFloat, + 0.85D, + Decimal("1.2345678"), + ("2015-01-01").date, + ("2021-01-01 23:50:59.123").ts), + Row( + null, + true, + 10000.toByte, + "Spark ML".getBytes, + 222.toShort, + 113, + 11111111L, + 0.25.toFloat, + 0.75D, + Decimal("12345.678"), + ("2004-06-19").date, + ("1999-08-26 10:43:59.123").ts) + ) + + testPushDownForAllDataTypes( + rows, + Seq(Row("a string", false, 1.toByte, + "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, + 1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)), + Seq(Row(false, 1.toByte, + 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date)), + Seq(Row("test string", true, 16.toByte, + "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, + 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)), + Seq(Row(true, 16.toByte, + 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date)), + Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3)) + ) + + // Test for 0 row (empty file) + val nullRow = Row.fromSeq((1 to 12).map(_ => null)) + val nullRowWithOutTSAndBinary = Row.fromSeq((1 to 
8).map(_ => null)) + val zeroCount = Row.fromSeq((1 to 12).map(_ => 0)) + testPushDownForAllDataTypes(Seq.empty, Seq(nullRow), Seq(nullRowWithOutTSAndBinary), + Seq(nullRow), Seq(nullRowWithOutTSAndBinary), Seq(zeroCount)) + } + + test("column name case sensitivity") { + Seq("false", "true").foreach { enableVectorizedReader => + withSQLConf(aggPushDownEnabledKey -> "true", + vectorizedReaderEnabledKey -> enableVectorizedReader) { + withTempPath { dir => + spark.range(10).selectExpr("id", "id % 3 as p") + .write.partitionBy("p").format(format).save(dir.getCanonicalPath) + withTempView("tmp") { + spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp") + val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp") + selectAgg.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregation: [MAX(id), MIN(id)]" + checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment) + } + checkAnswer(selectAgg, Seq(Row(9, 0))) + } + } + } + } + } +} + +abstract class ParquetAggregatePushDownSuite + extends FileSourceAggregatePushDownSuite with ParquetTest { + + override def format: String = "parquet" + override protected val aggPushDownEnabledKey: String = + SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key +} + +class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet") +} + +class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") +} + +abstract class OrcAggregatePushDownSuite extends OrcTest with FileSourceAggregatePushDownSuite { + + override def format: String = "orc" + override protected val aggPushDownEnabledKey: String = + SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED.key +} + +class OrcV1AggregatePushDownSuite extends OrcAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "orc") +} + +class OrcV2AggregatePushDownSuite extends OrcAggregatePushDownSuite { + + override protected def sparkConf: SparkConf = + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "") +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 58921485b207d..e71f3b8c35e25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -2965,16 +2965,14 @@ class JsonV2Suite extends JsonSuite { withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) - assert(scanBuilder.pushFilters(filters) === filters) - assert(scanBuilder.pushedFilters() === filters) + assert(scanBuilder.pushDataFilters(filters) === filters) } } withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "false") { withTempPath { file => val scanBuilder = getBuilder(file.getCanonicalPath) - assert(scanBuilder.pushFilters(filters) === filters) - assert(scanBuilder.pushedFilters() === Array.empty[sources.Filter]) + assert(scanBuilder.pushDataFilters(filters) === Array.empty[sources.Filter]) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala new file mode 100644 index 0000000000000..6296da47cca51 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter.Predicate +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.BooleanType + +class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession { + test("SPARK-36644: Push down boolean column filter") { + testTranslateFilter(Symbol("col").boolean, + Some(new Predicate("=", Array(FieldReference("col"), LiteralValue(true, BooleanType))))) + } + + /** + * Translate the given Catalyst [[Expression]] into data source V2 [[Predicate]] + * then verify against the given [[Predicate]]. + */ + def testTranslateFilter(catalystFilter: Expression, result: Option[Predicate]): Unit = { + assertResult(result) { + DataSourceV2Strategy.translateFilterV2(catalystFilter, true) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala new file mode 100644 index 0000000000000..2d6e6fcf16174 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue} +import org.apache.spark.sql.connector.expressions.filter._ +import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref +import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +class V2PredicateSuite extends SparkFunSuite { + + test("nested columns") { + val predicate1 = + new Predicate("=", Array[Expression](ref("a", "B"), LiteralValue(1, IntegerType))) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a.B")) + assert(predicate1.describe.equals("a.B = 1")) + + val predicate2 = + new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, IntegerType))) + assert(predicate2.references.map(_.describe()).toSeq == Seq("a.`b.c`")) + assert(predicate2.describe.equals("a.`b.c` = 1")) + + val predicate3 = + new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, IntegerType))) + assert(predicate3.references.map(_.describe()).toSeq == Seq("```a``.b`.c")) + assert(predicate3.describe.equals("```a``.b`.c = 1")) + } + + test("AlwaysTrue") { + val predicate1 = new AlwaysTrue + val predicate2 = new AlwaysTrue + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).length == 0) + assert(predicate1.describe.equals("TRUE")) + } + + test("AlwaysFalse") { + val predicate1 = new AlwaysFalse + val predicate2 = new AlwaysFalse + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).length == 0) + assert(predicate1.describe.equals("FALSE")) + } + + test("EqualTo") { + val predicate1 = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + val predicate2 = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a = 1")) + } + + test("EqualNullSafe") { + val predicate1 = new Predicate("<=>", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + val predicate2 = new Predicate("<=>", Array[Expression](ref("a"), LiteralValue(1, IntegerType))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("(a = 1) OR (a IS NULL AND 1 IS NULL)")) + } + + test("In") { + val predicate1 = new Predicate("IN", + Array(ref("a"), LiteralValue(1, IntegerType), LiteralValue(2, IntegerType), + LiteralValue(3, IntegerType), LiteralValue(4, IntegerType))) + val predicate2 = new Predicate("IN", + Array(ref("a"), LiteralValue(4, IntegerType), LiteralValue(2, IntegerType), + LiteralValue(3, IntegerType), LiteralValue(1, IntegerType))) + assert(!predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a IN (1, 2, 3, 4)")) + val values: Array[Literal[_]] = new Array[Literal[_]](1000) + var expected = "a IN (" + for (i <- 0 until 1000) { + values(i) = LiteralValue(i, IntegerType) + expected += i + ", " + } + val predicate3 = new Predicate("IN", (ref("a") +: values).toArray[Expression]) + expected = expected.dropRight(2) // remove the last ", " + expected += ")" + assert(predicate3.describe.equals(expected)) + } + + test("IsNull") { + val predicate1 = new Predicate("IS_NULL", Array[Expression](ref("a"))) + val 
predicate2 = new Predicate("IS_NULL", Array[Expression](ref("a"))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a IS NULL")) + } + + test("IsNotNull") { + val predicate1 = new Predicate("IS_NOT_NULL", Array[Expression](ref("a"))) + val predicate2 = new Predicate("IS_NOT_NULL", Array[Expression](ref("a"))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a IS NOT NULL")) + } + + test("Not") { + val predicate1 = new Not( + new Predicate("<", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))) + val predicate2 = new Not( + new Predicate("<", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("NOT (a < 1)")) + } + + test("And") { + val predicate1 = new And( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) + val predicate2 = new And( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b")) + assert(predicate1.describe.equals("(a = 1) AND (b = 1)")) + } + + test("Or") { + val predicate1 = new Or( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) + val predicate2 = new Or( + new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))), + new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType)))) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b")) + assert(predicate1.describe.equals("(a = 1) OR (b = 1)")) + } + + test("StringStartsWith") { + val literal = LiteralValue(UTF8String.fromString("str"), StringType) + val predicate1 = new Predicate("STARTS_WITH", + Array[Expression](ref("a"), literal)) + val predicate2 = new Predicate("STARTS_WITH", + Array[Expression](ref("a"), literal)) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a LIKE 'str%'")) + } + + test("StringEndsWith") { + val literal = LiteralValue(UTF8String.fromString("str"), StringType) + val predicate1 = new Predicate("ENDS_WITH", + Array[Expression](ref("a"), literal)) + val predicate2 = new Predicate("ENDS_WITH", + Array[Expression](ref("a"), literal)) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a LIKE '%str'")) + } + + test("StringContains") { + val literal = LiteralValue(UTF8String.fromString("str"), StringType) + val predicate1 = new Predicate("CONTAINS", + Array[Expression](ref("a"), literal)) + val predicate2 = new Predicate("CONTAINS", + Array[Expression](ref("a"), literal)) + assert(predicate1.equals(predicate2)) + assert(predicate1.references.map(_.describe()).toSeq == Seq("a")) + assert(predicate1.describe.equals("a LIKE '%str%'")) + } +} + +object V2PredicateSuite { + private[sql] def ref(parts: String*): FieldReference = { + new 
FieldReference(parts) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala index 1a4f08418f8d3..1a52dc4da009f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala @@ -67,10 +67,10 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite { override protected def afterAll(): Unit = { val catalog = newCatalog() - catalog.dropNamespace(Array("db")) - catalog.dropNamespace(Array("db2")) - catalog.dropNamespace(Array("ns")) - catalog.dropNamespace(Array("ns2")) + catalog.dropNamespace(Array("db"), cascade = true) + catalog.dropNamespace(Array("db2"), cascade = true) + catalog.dropNamespace(Array("ns"), cascade = true) + catalog.dropNamespace(Array("ns2"), cascade = true) super.afterAll() } @@ -806,7 +806,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.listNamespaces(Array()) === Array(testNs, defaultNs)) assert(catalog.listNamespaces(testNs) === Array()) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("listNamespaces: fail if missing namespace") { @@ -844,7 +844,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(metadata.asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("loadNamespaceMetadata: empty metadata") { @@ -859,7 +859,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(metadata.asScala, emptyProps.asScala) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: basic behavior") { @@ -879,7 +879,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map("property" -> "value")) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: initialize location") { @@ -895,7 +895,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map.empty) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: relative location") { @@ -912,7 +912,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(metadata, Map.empty) assert(expectedPath === metadata("location")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: fail if namespace already exists") { @@ -928,7 +928,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) checkMetadata(catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("createNamespace: fail nested namespace") { @@ -943,7 +943,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(exc.getMessage.contains("Invalid namespace name: db.nested")) - catalog.dropNamespace(Array("db")) + 
catalog.dropNamespace(Array("db"), cascade = false) } test("createTable: fail if namespace does not exist") { @@ -964,7 +964,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === false) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === false) } @@ -976,7 +976,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(catalog.namespaceExists(testNs) === true) - val ret = catalog.dropNamespace(testNs) + val ret = catalog.dropNamespace(testNs, cascade = false) assert(ret === true) assert(catalog.namespaceExists(testNs) === false) @@ -988,8 +988,8 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.createNamespace(testNs, Map("property" -> "value").asJava) catalog.createTable(testIdent, schema, Array.empty, emptyProps) - val exc = intercept[IllegalStateException] { - catalog.dropNamespace(testNs) + val exc = intercept[AnalysisException] { + catalog.dropNamespace(testNs, cascade = false) } assert(exc.getMessage.contains(testNs.quoted)) @@ -997,7 +997,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { checkMetadata(catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) catalog.dropTable(testIdent) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: basic behavior") { @@ -1022,7 +1022,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value")) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: update namespace location") { @@ -1045,7 +1045,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { catalog.alterNamespace(testNs, NamespaceChange.setProperty("location", "relativeP")) assert(newRelativePath === spark.catalog.getDatabase(testNs(0)).locationUri) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: update namespace comment") { @@ -1060,7 +1060,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(newComment === spark.catalog.getDatabase(testNs(0)).description) - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } test("alterNamespace: fail if namespace doesn't exist") { @@ -1087,6 +1087,6 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite { assert(exc.getMessage.contains(s"Cannot remove reserved property: $p")) } - catalog.dropNamespace(testNs) + catalog.dropNamespace(testNs, cascade = false) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 8842db2a2aca4..8f690eeaff901 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -24,7 +24,6 @@ import java.util.{Calendar, GregorianCalendar, Properties, TimeZone} import scala.collection.JavaConverters._ -import org.h2.jdbc.JdbcSQLException import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} @@ -38,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeTestUtils import org.apache.spark.sql.execution.{DataSourceScanExec, ExtendedMode} import 
org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCommand} import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils} +import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils} import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ @@ -54,7 +53,8 @@ class JDBCSuite extends QueryTest val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass" var conn: java.sql.Connection = null - val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) + val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) ++ + Array.fill(15)(0.toByte) val testH2Dialect = new JdbcDialect { override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2") @@ -87,7 +87,6 @@ class JDBCSuite extends QueryTest val properties = new Properties() properties.setProperty("user", "testUser") properties.setProperty("password", "testPass") - properties.setProperty("rowId", "false") conn = DriverManager.getConnection(url, properties) conn.prepareStatement("create schema test").executeUpdate() @@ -162,7 +161,7 @@ class JDBCSuite extends QueryTest |OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass') """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)" + conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP(7))" ).executeUpdate() conn.prepareStatement("insert into test.timetypes values ('12:34:56', " + "'1996-01-01', '2002-02-20 11:22:33.543543543')").executeUpdate() @@ -177,12 +176,12 @@ class JDBCSuite extends QueryTest """.stripMargin.replaceAll("\n", " ")) conn.prepareStatement("CREATE TABLE test.timezone (tz TIMESTAMP WITH TIME ZONE) " + - "AS SELECT '1999-01-08 04:05:06.543543543 GMT-08:00'") + "AS SELECT '1999-01-08 04:05:06.543543543-08:00'") .executeUpdate() conn.commit() - conn.prepareStatement("CREATE TABLE test.array (ar ARRAY) " + - "AS SELECT '(1, 2, 3)'") + conn.prepareStatement("CREATE TABLE test.array_table (ar Integer ARRAY) " + + "AS SELECT ARRAY[1, 2, 3]") .executeUpdate() conn.commit() @@ -638,7 +637,7 @@ class JDBCSuite extends QueryTest assert(rows(0).getAs[Array[Byte]](0).sameElements(testBytes)) assert(rows(0).getString(1).equals("Sensitive")) assert(rows(0).getString(2).equals("Insensitive")) - assert(rows(0).getString(3).equals("Twenty-byte CHAR")) + assert(rows(0).getString(3).equals("Twenty-byte CHAR ")) assert(rows(0).getAs[Array[Byte]](4).sameElements(testBytes)) assert(rows(0).getString(5).equals("I am a clob!")) } @@ -729,20 +728,6 @@ class JDBCSuite extends QueryTest assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12) } - test("Pass extra properties via OPTIONS") { - // We set rowId to false during setup, which means that _ROWID_ column should be absent from - // all tables. If rowId is true (default), the query below doesn't throw an exception. 
- intercept[JdbcSQLException] { - sql( - s""" - |CREATE OR REPLACE TEMPORARY VIEW abc - |USING org.apache.spark.sql.jdbc - |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)', - | user 'testUser', password 'testPass') - """.stripMargin.replaceAll("\n", " ")) - } - } - test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties()) @@ -788,33 +773,36 @@ class JDBCSuite extends QueryTest } test("compile filters") { - val compileFilter = PrivateMethod[Option[String]](Symbol("compileFilter")) def doCompileFilter(f: Filter): String = - JDBCRDD invokePrivate compileFilter(f, JdbcDialects.get("jdbc:")) getOrElse("") - assert(doCompileFilter(EqualTo("col0", 3)) === """"col0" = 3""") - assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === """(NOT ("col1" = 'abc'))""") - assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) - === """("col0" = 0) AND ("col1" = 'def')""") - assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi"))) - === """("col0" = 2) OR ("col1" = 'ghi')""") - assert(doCompileFilter(LessThan("col0", 5)) === """"col0" < 5""") - assert(doCompileFilter(LessThan("col3", - Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col3" < '1995-11-21 00:00:00.0'""") - assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) - === """"col4" < '1983-08-04'""") - assert(doCompileFilter(LessThanOrEqual("col0", 5)) === """"col0" <= 5""") - assert(doCompileFilter(GreaterThan("col0", 3)) === """"col0" > 3""") - assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === """"col0" >= 3""") - assert(doCompileFilter(In("col1", Array("jkl"))) === """"col1" IN ('jkl')""") - assert(doCompileFilter(In("col1", Array.empty)) === - """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""") - assert(doCompileFilter(Not(In("col1", Array("mno", "pqr")))) - === """(NOT ("col1" IN ('mno', 'pqr')))""") - assert(doCompileFilter(IsNull("col1")) === """"col1" IS NULL""") - assert(doCompileFilter(IsNotNull("col1")) === """"col1" IS NOT NULL""") - assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) - === """((NOT ("col0" != 'abc' OR "col0" IS NULL OR 'abc' IS NULL) """ - + """OR ("col0" IS NULL AND 'abc' IS NULL))) AND ("col1" = 'def')""") + JdbcDialects.get("jdbc:").compileExpression(f.toV2).getOrElse("") + + Seq(("col0", "col1"), ("`col0`", "`col1`")).foreach { case(col0, col1) => + assert(doCompileFilter(EqualTo(col0, 3)) === """"col0" = 3""") + assert(doCompileFilter(Not(EqualTo(col1, "abc"))) === """NOT ("col1" = 'abc')""") + assert(doCompileFilter(And(EqualTo(col0, 0), EqualTo(col1, "def"))) + === """("col0" = 0) AND ("col1" = 'def')""") + assert(doCompileFilter(Or(EqualTo(col0, 2), EqualTo(col1, "ghi"))) + === """("col0" = 2) OR ("col1" = 'ghi')""") + assert(doCompileFilter(LessThan(col0, 5)) === """"col0" < 5""") + assert(doCompileFilter(LessThan(col0, + Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col0" < '1995-11-21 00:00:00.0'""") + assert(doCompileFilter(LessThan(col0, Date.valueOf("1983-08-04"))) + === """"col0" < '1983-08-04'""") + assert(doCompileFilter(LessThanOrEqual(col0, 5)) === """"col0" <= 5""") + assert(doCompileFilter(GreaterThan(col0, 3)) === """"col0" > 3""") + assert(doCompileFilter(GreaterThanOrEqual(col0, 3)) === """"col0" >= 3""") + assert(doCompileFilter(In(col1, Array("jkl"))) === """"col1" IN ('jkl')""") + assert(doCompileFilter(In(col1, Array.empty)) === + """CASE WHEN "col1" IS NULL THEN NULL ELSE 
FALSE END""") + assert(doCompileFilter(Not(In(col1, Array("mno", "pqr")))) + === """NOT ("col1" IN ('mno', 'pqr'))""") + assert(doCompileFilter(IsNull(col1)) === """"col1" IS NULL""") + assert(doCompileFilter(IsNotNull(col1)) === """"col1" IS NOT NULL""") + assert(doCompileFilter(And(EqualNullSafe(col0, "abc"), EqualTo(col1, "def"))) + === """(("col0" = 'abc') OR ("col0" IS NULL AND 'abc' IS NULL))""" + + """ AND ("col1" = 'def')""") + } + assert(doCompileFilter(EqualTo("col0.nested", 3)).isEmpty) } test("Dialect unregister") { @@ -1375,7 +1363,7 @@ class JDBCSuite extends QueryTest }.getMessage assert(e.contains("Unsupported type TIMESTAMP_WITH_TIMEZONE")) e = intercept[SQLException] { - spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY", new Properties()).collect() + spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY_TABLE", new Properties()).collect() }.getMessage assert(e.contains("Unsupported type ARRAY")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala index 526dad91e5e19..94f044a0a6755 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala @@ -20,13 +20,14 @@ package org.apache.spark.sql.jdbc import java.sql.{Connection, DriverManager} import java.util.Properties -import org.apache.spark.SparkConf -import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row} +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException -import org.apache.spark.sql.catalyst.plans.logical.Filter -import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper} import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog -import org.apache.spark.sql.functions.{lit, sum, udf} +import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.util.Utils @@ -42,6 +43,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .set("spark.sql.catalog.h2.url", url) .set("spark.sql.catalog.h2.driver", "org.h2.Driver") .set("spark.sql.catalog.h2.pushDownAggregate", "true") + .set("spark.sql.catalog.h2.pushDownLimit", "true") private def withConnection[T](f: Connection => T): T = { val conn = DriverManager.getConnection(url, new Properties()) @@ -67,17 +69,40 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate() conn.prepareStatement( "CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," + - " bonus DOUBLE)").executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)") - .executeUpdate() - conn.prepareStatement("INSERT 
INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)") - .executeUpdate() - conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)") + " bonus DOUBLE, is_manager BOOLEAN)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000, true)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200, false)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200, false)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300, true)").executeUpdate() + conn.prepareStatement( + "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200, true)").executeUpdate() + conn.prepareStatement( + "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate() + + // scalastyle:off + conn.prepareStatement( + "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate() + // scalastyle:on + conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate() + conn.prepareStatement( + """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate() + conn.prepareStatement( + """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate() + + conn.prepareStatement( + "CREATE TABLE \"test\".\"item\" (id INTEGER, name TEXT(32), price NUMERIC(23, 3))") .executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " + + "(1, 'bottle', 11111111111111111111.123)").executeUpdate() + conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " + + "(1, 'bottle', 99999999999999999999.123)").executeUpdate() } } @@ -92,42 +117,369 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2))) } - test("scan with filter push-down") { - val df = spark.table("h2.test.people").filter($"id" > 1) - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.isEmpty) - + private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String): Unit = { df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + checkKeywordsExistsInExplain(df, expectedPlanFragment) } + } - checkAnswer(df, Row("mary", 2)) + // TABLESAMPLE ({integer_expression | decimal_expression} PERCENT) and + // TABLESAMPLE (BUCKET integer_expression OUT OF integer_expression) + // are tested in JDBC dialect tests because TABLESAMPLE is not supported by all the DBMS + test("TABLESAMPLE (integer_expression ROWS) is the same as LIMIT") { + val df = sql("SELECT NAME FROM h2.test.employee TABLESAMPLE (3 ROWS)") + checkSchemaNames(df, Seq("NAME")) + checkPushedInfo(df, "PushedFilters: [], PushedLimit: LIMIT 3, ") + checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy"))) } - test("scan with column pruning") { - val df = spark.table("h2.test.people").select("id") + private def checkSchemaNames(df: DataFrame, names: Seq[String]): Unit = { 
val scan = df.queryExecution.optimizedPlan.collectFirst { case s: DataSourceV2ScanRelation => s }.get - assert(scan.schema.names.sameElements(Seq("ID"))) + assert(scan.schema.names.sameElements(names)) + } + + test("simple scan with LIMIT") { + val df1 = spark.read.table("h2.test.employee") + .where($"dept" === 1).limit(1) + checkPushedInfo(df1, + "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 1, ") + checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0, true))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .limit(1) + checkPushedInfo(df2, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ") + checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0, false))) + + val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1") + checkSchemaNames(df3, Seq("NAME")) + checkPushedInfo(df3, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ") + checkAnswer(df3, Seq(Row("alex"))) + + val df4 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .limit(1) + checkPushedInfo(df4, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT], ") + checkAnswer(df4, Seq(Row(1, 19000.00))) + + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val df5 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .limit(1) + // LIMIT is pushed down only if all the filters are pushed down + checkPushedInfo(df5, "PushedFilters: [], ") + checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy"))) + } + + private def checkSortRemoved(df: DataFrame, removed: Boolean = true): Unit = { + val sorts = df.queryExecution.optimizedPlan.collect { + case s: Sort => s + } + if (removed) { + assert(sorts.isEmpty) + } else { + assert(sorts.nonEmpty) + } + } + + test("simple scan with top N") { + val df1 = spark.read + .table("h2.test.employee") + .sort("salary") + .limit(1) + checkSortRemoved(df1) + checkPushedInfo(df1, + "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ") + checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + .where($"dept" === 1) + .orderBy($"salary") + .limit(1) + checkSortRemoved(df2) + checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " + + "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ") + checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false))) + + val df3 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .filter($"dept" > 1) + .orderBy($"salary".desc) + .limit(1) + checkSortRemoved(df3, false) + checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + + "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ") + checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false))) + + val df4 = + sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1") + checkSchemaNames(df4, Seq("NAME")) + checkSortRemoved(df4) + checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + + "PushedTopN: 
ORDER BY [salary ASC NULLS LAST] LIMIT 1, ") + checkAnswer(df4, Seq(Row("david"))) + + val df5 = spark.read.table("h2.test.employee") + .where($"dept" === 1).orderBy($"salary") + checkSortRemoved(df5, false) + checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ") + checkAnswer(df5, + Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true))) + + val df6 = spark.read + .table("h2.test.employee") + .groupBy("DEPT").sum("SALARY") + .orderBy("DEPT") + .limit(1) + checkSortRemoved(df6, false) + checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," + + " PushedFilters: [], PushedGroupByColumns: [DEPT], ") + checkAnswer(df6, Seq(Row(1, 19000.00))) + + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val df7 = spark.read + .table("h2.test.employee") + .select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter(name($"shortName")) + .sort($"SALARY".desc) + .limit(1) + // LIMIT is pushed down only if all the filters are pushed down + checkSortRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: [], ") + checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy"))) + + val df8 = spark.read + .table("h2.test.employee") + .sort(sub($"NAME")) + .limit(1) + checkSortRemoved(df8, false) + checkPushedInfo(df8, "PushedFilters: [], ") + checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false))) + } + + test("simple scan with top N: order by with alias") { + val df1 = spark.read + .table("h2.test.employee") + .select($"NAME", $"SALARY".as("mySalary")) + .sort("mySalary") + .limit(1) + checkSortRemoved(df1) + checkPushedInfo(df1, + "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ") + checkAnswer(df1, Seq(Row("cathy", 9000.00))) + + val df2 = spark.read + .table("h2.test.employee") + .select($"DEPT", $"NAME", $"SALARY".as("mySalary")) + .filter($"DEPT" > 1) + .sort("mySalary") + .limit(1) + checkSortRemoved(df2) + checkPushedInfo(df2, + "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " + + "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ") + checkAnswer(df2, Seq(Row(2, "david", 10000.00))) + } + + test("scan with filter push-down") { + val df = spark.table("h2.test.people").filter($"id" > 1) + checkFiltersRemoved(df) + checkPushedInfo(df, "PushedFilters: [ID IS NOT NULL, ID > 1], ") + checkAnswer(df, Row("mary", 2)) + + val df2 = spark.table("h2.test.employee").filter($"name".isin("amy", "cathy")) + checkFiltersRemoved(df2) + checkPushedInfo(df2, "PushedFilters: [NAME IN ('amy', 'cathy')]") + checkAnswer(df2, Seq(Row(1, "amy", 10000, 1000, true), Row(1, "cathy", 9000, 1200, false))) + + val df3 = spark.table("h2.test.employee").filter($"name".startsWith("a")) + checkFiltersRemoved(df3) + checkPushedInfo(df3, "PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%']") + checkAnswer(df3, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false))) + + val df4 = spark.table("h2.test.employee").filter($"is_manager") + checkFiltersRemoved(df4) + checkPushedInfo(df4, "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true]") + checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "david", 10000, 1300, true), + Row(6, "jen", 12000, 1200, true))) + + val df5 = spark.table("h2.test.employee").filter($"is_manager".and($"salary" > 10000)) + checkFiltersRemoved(df5) + checkPushedInfo(df5, "PushedFilters: [IS_MANAGER IS NOT NULL, SALARY IS NOT NULL, " + + "IS_MANAGER = true, SALARY > 10000.00]") + checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true))) + + 
val df6 = spark.table("h2.test.employee").filter($"is_manager".or($"salary" > 10000)) + checkFiltersRemoved(df6) + checkPushedInfo(df6, "PushedFilters: [(IS_MANAGER = true) OR (SALARY > 10000.00)], ") + checkAnswer(df6, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false), + Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) + + val df7 = spark.table("h2.test.employee").filter(not($"is_manager") === true) + checkFiltersRemoved(df7) + checkPushedInfo(df7, "PushedFilters: [IS_MANAGER IS NOT NULL, NOT (IS_MANAGER = true)], ") + checkAnswer(df7, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "alex", 12000, 1200, false))) + + val df8 = spark.table("h2.test.employee").filter($"is_manager" === true) + checkFiltersRemoved(df8) + checkPushedInfo(df8, "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true], ") + checkAnswer(df8, Seq(Row(1, "amy", 10000, 1000, true), + Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) + + val df9 = spark.table("h2.test.employee") + .filter(when($"dept" > 1, true).when($"is_manager", false).otherwise($"dept" > 3)) + checkFiltersRemoved(df9) + checkPushedInfo(df9, "PushedFilters: [CASE WHEN DEPT > 1 THEN TRUE " + + "WHEN IS_MANAGER = true THEN FALSE ELSE DEPT > 3 END], ") + checkAnswer(df9, Seq(Row(2, "alex", 12000, 1200, false), + Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true))) + + val df10 = spark.table("h2.test.people") + .select($"NAME".as("myName"), $"ID".as("myID")) + .filter($"myID" > 1) + checkFiltersRemoved(df10) + checkPushedInfo(df10, "PushedFilters: [ID IS NOT NULL, ID > 1], ") + checkAnswer(df10, Row("mary", 2)) + } + + test("scan with filter push-down with ansi mode") { + Seq(false, true).foreach { ansiMode => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val df = spark.table("h2.test.people").filter($"id" + 1 > 1) + checkFiltersRemoved(df, ansiMode) + val expectedPlanFragment = if (ansiMode) { + "PushedFilters: [ID IS NOT NULL, (ID + 1) > 1]" + } else { + "PushedFilters: [ID IS NOT NULL]" + } + checkPushedInfo(df, expectedPlanFragment) + checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2))) + + val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1) + checkFiltersRemoved(df2, ansiMode) + val expectedPlanFragment2 = if (ansiMode) { + "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], " + } else { + "PushedFilters: [ID IS NOT NULL], " + } + checkPushedInfo(df2, expectedPlanFragment2) + if (ansiMode) { + val e = intercept[SparkException] { + checkAnswer(df2, Seq.empty) + } + assert(e.getMessage.contains( + "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\"")) + } else { + checkAnswer(df2, Seq.empty) + } + + val df3 = sql(""" + |SELECT * FROM h2.test.employee + |WHERE (CASE WHEN SALARY > 10000 THEN BONUS ELSE BONUS + 200 END) > 1200 + |""".stripMargin) + + checkFiltersRemoved(df3, ansiMode) + val expectedPlanFragment3 = if (ansiMode) { + "PushedFilters: [(CASE WHEN SALARY > 10000.00 THEN BONUS" + + " ELSE BONUS + 200.0 END) > 1200.0]" + } else { + "PushedFilters: []" + } + checkPushedInfo(df3, expectedPlanFragment3) + checkAnswer(df3, + Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) + + val df4 = spark.table("h2.test.employee") + .filter(($"salary" > 1000d).and($"salary" < 12000d)) + checkFiltersRemoved(df4, ansiMode) + val expectedPlanFragment4 = if (ansiMode) { + "PushedFilters: [SALARY IS NOT NULL, " + + "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], 
" + } else { + "PushedFilters: [SALARY IS NOT NULL], " + } + checkPushedInfo(df4, expectedPlanFragment4) + checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true), + Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true))) + + val df5 = spark.table("h2.test.employee") + .filter(abs($"dept" - 3) > 1) + .filter(coalesce($"salary", $"bonus") > 2000) + checkFiltersRemoved(df5, ansiMode) + val expectedPlanFragment5 = if (ansiMode) { + "PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1, " + + "(COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]" + } else { + "PushedFilters: [DEPT IS NOT NULL]" + } + checkPushedInfo(df5, expectedPlanFragment5) + checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true), + Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true))) + + val df6 = spark.table("h2.test.employee") + .filter(ln($"dept") > 1) + .filter(exp($"salary") > 2000) + .filter(pow($"dept", 2) > 4) + .filter(sqrt($"salary") > 100) + .filter(floor($"dept") > 1) + .filter(ceil($"dept") > 1) + checkFiltersRemoved(df6, ansiMode) + val expectedPlanFragment6 = if (ansiMode) { + "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " + + "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...," + } else { + "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]" + } + checkPushedInfo(df6, expectedPlanFragment6) + checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true))) + + // H2 does not support width_bucket + val df7 = sql(""" + |SELECT * FROM h2.test.employee + |WHERE width_bucket(dept, 1, 6, 3) > 1 + |""".stripMargin) + checkFiltersRemoved(df7, false) + checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]") + checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true))) + } + } + } + + test("scan with column pruning") { + val df = spark.table("h2.test.people").select("id") + checkSchemaNames(df, Seq("ID")) checkAnswer(df, Seq(Row(1), Row(2))) } test("scan with filter push-down and column pruning") { val df = spark.table("h2.test.people").filter($"id" > 1).select("name") - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.isEmpty) - val scan = df.queryExecution.optimizedPlan.collectFirst { - case s: DataSourceV2ScanRelation => s - }.get - assert(scan.schema.names.sameElements(Seq("NAME"))) + checkFiltersRemoved(df) + checkSchemaNames(df, Seq("NAME")) checkAnswer(df, Row("mary")) } @@ -168,7 +520,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel test("show tables") { checkAnswer(sql("SHOW TABLES IN h2.test"), Seq(Row("test", "people", false), Row("test", "empty_table", false), - Row("test", "employee", false))) + Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false), + Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false))) } test("SQL API: create table as select") { @@ -238,167 +591,195 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel } } - test("scan with aggregate push-down: MAX MIN with filter and group by") { - val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" + - " group by DePt") - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f + private def checkAggregateRemoved(df: DataFrame, removed: Boolean = true): Unit = { + val aggregates = df.queryExecution.optimizedPlan.collect { + case agg: Aggregate => agg } - assert(filters.isEmpty) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val 
expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupby: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + if (removed) { + assert(aggregates.isEmpty) + } else { + assert(aggregates.nonEmpty) } - checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200))) } - test("scan with aggregate push-down: MAX MIN with filter without group by") { - val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0") + test("scan with aggregate push-down: MAX AVG with filter and group by") { + val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" + + " group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " + + "PushedGroupByColumns: [DEPT], ") + checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0))) + } + + private def checkFiltersRemoved(df: DataFrame, removed: Boolean = true): Unit = { val filters = df.queryExecution.optimizedPlan.collect { case f: Filter => f } - assert(filters.isEmpty) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(ID), MIN(ID)], " + - "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " + - "PushedGroupby: []" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + if (removed) { + assert(filters.isEmpty) + } else { + assert(filters.nonEmpty) + } + } + + test("scan with aggregate push-down: MAX AVG with filter without group by") { + val df = sql("select MAX(ID), AVG(ID) FROM h2.test.people where id > 0") + checkFiltersRemoved(df) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [MAX(ID), AVG(ID)], " + + "PushedFilters: [ID IS NOT NULL, ID > 0], " + + "PushedGroupByColumns: [], ") + checkAnswer(df, Seq(Row(2, 1.5))) + } + + test("partitioned scan with aggregate push-down: complete push-down only") { + withTempView("v") { + spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .createTempView("v") + val df = sql("select AVG(SALARY) FROM v GROUP BY name") + // Partitioned JDBC Scan doesn't support complete aggregate push-down, and AVG requires + // complete push-down so aggregate is not pushed at the end. 
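// For contrast, partial push-down is exercised near the end of this suite ("partial push-down
// AVG with overflow"): there the final Aggregate is kept but the scan still reports
// PushedAggregates: [SUM(PRICE), COUNT(PRICE)], i.e. AVG is decomposed so each partition only
// returns partial sums and counts. Roughly (a sketch of that test's read, not executed here):
//
//   spark.read
//     .option("partitionColumn", "id")
//     .option("lowerBound", "0")
//     .option("upperBound", "2")
//     .option("numPartitions", "2")
//     .table("h2.test.item")
//     .agg(avg($"PRICE"))   // pushed as SUM(PRICE) and COUNT(PRICE); Spark merges them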
+ checkAggregateRemoved(df, removed = false) + checkAnswer(df, Seq(Row(9000.0), Row(10000.0), Row(10000.0), Row(12000.0), Row(12000.0))) } - checkAnswer(df, Seq(Row(2, 1))) } test("scan with aggregate push-down: aggregate + number") { val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee") + checkAggregateRemoved(df) df.queryExecution.optimizedPlan.collect { case _: DataSourceV2ScanRelation => val expected_plan_fragment = "PushedAggregates: [MAX(SALARY)]" checkKeywordsExistsInExplain(df, expected_plan_fragment) } + checkPushedInfo(df, "PushedAggregates: [MAX(SALARY)]") checkAnswer(df, Seq(Row(12001))) } test("scan with aggregate push-down: COUNT(*)") { val df = sql("select COUNT(*) FROM h2.test.employee") - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COUNT(*)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [COUNT(*)]") checkAnswer(df, Seq(Row(5))) } test("scan with aggregate push-down: COUNT(col)") { val df = sql("select COUNT(DEPT) FROM h2.test.employee") - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COUNT(DEPT)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [COUNT(DEPT)]") checkAnswer(df, Seq(Row(5))) } test("scan with aggregate push-down: COUNT(DISTINCT col)") { val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee") - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [COUNT(DISTINCT DEPT)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [COUNT(DISTINCT DEPT)]") + checkAnswer(df, Seq(Row(3))) + } + + test("scan with aggregate push-down: cannot partial push down COUNT(DISTINCT col)") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .agg(count_distinct($"DEPT")) + checkAggregateRemoved(df, false) checkAnswer(df, Seq(Row(3))) } test("scan with aggregate push-down: SUM without filer and group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee") - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)]") checkAnswer(df, Seq(Row(53000))) } test("scan with aggregate push-down: DISTINCT SUM without filer and group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee") - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(DISTINCT SALARY)]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)]") checkAnswer(df, Seq(Row(31000))) } test("scan with aggregate push-down: SUM with group by") { val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT") - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: 
[SUM(SALARY)], " + - "PushedFilters: [], " + - "PushedGroupby: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)], " + + "PushedFilters: [], PushedGroupByColumns: [DEPT], ") checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } test("scan with aggregate push-down: DISTINCT SUM with group by") { val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT") - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(DISTINCT SALARY)], " + - "PushedFilters: [], " + - "PushedGroupby: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)], " + + "PushedFilters: [], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000))) } test("scan with aggregate push-down: with multiple group by columns") { val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DEPT, NAME") - val filters11 = df.queryExecution.optimizedPlan.collect { + checkFiltersRemoved(df) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") + checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), + Row(10000, 1000), Row(12000, 1200))) + } + + test("scan with aggregate push-down: with concat multiple group key in project") { + val df1 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) FROM h2.test.employee" + + " where dept > 0 group by DEPT, NAME") + val filters1 = df1.queryExecution.optimizedPlan.collect { case f: Filter => f } - assert(filters11.isEmpty) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupby: [DEPT, NAME]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) + assert(filters1.isEmpty) + checkAggregateRemoved(df1) + checkPushedInfo(df1, "PushedAggregates: [MAX(SALARY)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") + checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000), + Row("2#david", 10000), Row("6#jen", 12000))) + + val df2 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" + + " FROM h2.test.employee where dept > 0 group by DEPT, NAME") + val filters2 = df2.queryExecution.optimizedPlan.collect { + case f: Filter => f } - checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300), - Row(10000, 1000), Row(12000, 1200))) + assert(filters2.isEmpty) + checkAggregateRemoved(df2) + checkPushedInfo(df2, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]") + checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200), + Row("2#david", 11300), Row("6#jen", 13200))) + + val df3 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" + + " FROM h2.test.employee where dept > 0 group by concat_ws('#', DEPT, NAME)") + checkFiltersRemoved(df3) + checkAggregateRemoved(df3, false) + checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], ") + checkAnswer(df3, Seq(Row("1#amy", 11000), 
Row("1#cathy", 10200), Row("2#alex", 13200), + Row("2#david", 11300), Row("6#jen", 13200))) } test("scan with aggregate push-down: with having clause") { val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" + " group by DEPT having MIN(BONUS) > 1000") - val filters = df.queryExecution.optimizedPlan.collect { - case f: Filter => f // filter over aggregate not push down - } - assert(filters.nonEmpty) - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupby: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + // filter over aggregate not push down + checkFiltersRemoved(df, false) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200))) } @@ -406,14 +787,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = sql("select * from h2.test.employee") .groupBy($"DEPT") .min("SALARY").as("total") - df.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [MIN(SALARY)], " + - "PushedFilters: [], " + - "PushedGroupby: [DEPT]" - checkKeywordsExistsInExplain(df, expected_plan_fragment) - } + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [MIN(SALARY)], " + + "PushedFilters: [], PushedGroupByColumns: [DEPT]") checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000))) } @@ -425,18 +801,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel .agg(sum($"SALARY").as("total")) .filter($"total" > 1000) .orderBy($"total") - val filters = query.queryExecution.optimizedPlan.collect { - case f: Filter => f - } - assert(filters.nonEmpty) // filter over aggregate not pushed down - query.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY)], " + - "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " + - "PushedGroupby: [DEPT]" - checkKeywordsExistsInExplain(query, expected_plan_fragment) - } + checkFiltersRemoved(query, false)// filter over aggregate not pushed down + checkAggregateRemoved(query) + checkPushedInfo(query, "PushedAggregates: [SUM(SALARY)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000))) } @@ -444,25 +812,366 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel val df = spark.table("h2.test.employee") val decrease = udf { (x: Double, y: Double) => x - y } val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value")) - query.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: [SUM(SALARY), SUM(BONUS)" - checkKeywordsExistsInExplain(query, expected_plan_fragment) - } + checkAggregateRemoved(query) + checkPushedInfo(query, "PushedAggregates: [SUM(SALARY), SUM(BONUS)], ") checkAnswer(query, Seq(Row(47100.0))) } - test("scan with aggregate push-down: aggregate over alias NOT push down") { - val cols = Seq("a", "b", "c", "d") + test("scan with aggregate push-down: partition columns are same as group by columns") { + val 
df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"dept") + .count() + checkAggregateRemoved(df) + checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1))) + } + + test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") { + val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" + + " group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + } + + test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") { + val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" + + " where dept > 0 group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null))) + } + + test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") { + val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" + + " FROM h2.test.employee where dept > 0 group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null))) + } + + test("scan with aggregate push-down: CORR with filter and group by") { + val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" + + " group by DePt") + checkFiltersRemoved(df) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " + + "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]") + checkAnswer(df, Seq(Row(1d), Row(1d), Row(null))) + } + + test("scan with aggregate push-down: aggregate over alias push down") { + val cols = Seq("a", "b", "c", "d", "e") val df1 = sql("select * from h2.test.employee").toDF(cols: _*) val df2 = df1.groupBy().sum("c") + checkAggregateRemoved(df2) df2.queryExecution.optimizedPlan.collect { - case _: DataSourceV2ScanRelation => - val expected_plan_fragment = - "PushedAggregates: []" // aggregate over alias not push down - checkKeywordsExistsInExplain(df2, expected_plan_fragment) + case relation: DataSourceV2ScanRelation => + val expectedPlanFragment = + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []" + checkKeywordsExistsInExplain(df2, expectedPlanFragment) + relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.aggregation.nonEmpty) + } } checkAnswer(df2, Seq(Row(53000.00))) } + + test("scan with aggregate push-down: aggregate with partially pushed down filters" + + "will NOT push down") { + val df = spark.table("h2.test.employee") + val name = udf { (x: String) => x.matches("cat|dav|amy") } + val sub = udf { (x: String) => x.substring(0, 3) } + val query = df.select($"SALARY", $"BONUS", sub($"NAME").as("shortName")) + .filter("SALARY > 100") + .filter(name($"shortName")) + 
.agg(sum($"SALARY").as("sum_salary")) + checkAggregateRemoved(query, false) + query.queryExecution.optimizedPlan.collect { + case relation: DataSourceV2ScanRelation => relation.scan match { + case v1: V1ScanWrapper => + assert(v1.pushedDownOperators.aggregation.isEmpty) + } + } + checkAnswer(query, Seq(Row(29000.0))) + } + + test("scan with aggregate push-down: aggregate function with CASE WHEN") { + val df = sql( + """ + |SELECT + | COUNT(CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY > 8000 AND SALARY <= 13000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY > 11000 OR SALARY < 10000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY >= 12000 OR SALARY < 9000 THEN SALARY ELSE 0 END), + | COUNT(CASE WHEN SALARY >= 12000 OR NOT(SALARY >= 9000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 10000) AND SALARY >= 8000 THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 10000) OR SALARY > 8000 THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 10000) AND NOT(SALARY < 8000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY != 0) OR NOT(SALARY < 8000) THEN SALARY ELSE 0 END), + | MAX(CASE WHEN NOT(SALARY > 8000 AND SALARY > 8000) THEN 0 ELSE SALARY END), + | MIN(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NULL) THEN SALARY ELSE 0 END), + | SUM(CASE WHEN SALARY > 10000 THEN 2 WHEN SALARY > 8000 THEN 1 END), + | AVG(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NOT NULL) THEN SALARY ELSE 0 END) + |FROM h2.test.employee GROUP BY DEPT + """.stripMargin) + checkAggregateRemoved(df) + checkPushedInfo(df, + "PushedAggregates: [COUNT(CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00)" + + " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " + + "PushedFilters: [], " + + "PushedGroupByColumns: [DEPT], ") + checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 2, 0d), + Row(2, 2, 2, 2, 2, 10000d, 10000d, 10000d, 10000d, 10000d, 0d, 2, 0d), + Row(2, 2, 2, 2, 2, 10000d, 12000d, 10000d, 12000d, 12000d, 0d, 3, 0d))) + } + + test("scan with aggregate push-down: aggregate function with binary arithmetic") { + Seq(false, true).foreach { ansiMode => + withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) { + val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee") + checkAggregateRemoved(df, ansiMode) + val expectedPlanFragment = if (ansiMode) { + "PushedAggregates: [SUM(2147483647 + DEPT)], " + + "PushedFilters: [], " + + "PushedGroupByColumns: []" + } else { + "PushedFilters: []" + } + checkPushedInfo(df, expectedPlanFragment) + if (ansiMode) { + val e = intercept[SparkException] { + checkAnswer(df, Seq(Row(-10737418233L))) + } + assert(e.getMessage.contains( + "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\"")) + } else { + checkAnswer(df, Seq(Row(-10737418233L))) + } + } + } + } + + test("scan with aggregate push-down: aggregate function with UDF") { + val df = spark.table("h2.test.employee") + val decrease = udf { (x: Double, y: Double) => x - y } + val query = df.select(sum(decrease($"SALARY", $"BONUS")).as("value")) + checkAggregateRemoved(query, false) + checkPushedInfo(query, "PushedFilters: []") + checkAnswer(query, Seq(Row(47100.0))) + } + + test("scan with aggregate push-down: partition columns with multi group by columns") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"dept", $"name") + .count() + checkAggregateRemoved(df, false) + 
checkAnswer(df, Seq(Row(1, "amy", 1), Row(1, "cathy", 1), + Row(2, "alex", 1), Row(2, "david", 1), Row(6, "jen", 1))) + } + + test("scan with aggregate push-down: partition columns is different from group by columns") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"name") + .count() + checkAggregateRemoved(df, false) + checkAnswer(df, + Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1))) + } + + test("column name with composite field") { + checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2))) + val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept") + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [COUNT(`dept id`)]") + checkAnswer(df, Seq(Row(2))) + } + + test("column name with non-ascii") { + // scalastyle:off + checkAnswer(sql("SELECT `名` FROM h2.test.person"), Seq(Row(1), Row(2))) + val df = sql("SELECT COUNT(`名`) FROM h2.test.person") + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [COUNT(`名`)]") + checkAnswer(df, Seq(Row(2))) + // scalastyle:on + } + + test("scan with aggregate push-down: complete push-down SUM, AVG, COUNT") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]") + checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "1") + .table("h2.test.employee") + .groupBy($"name") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df) + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]") + checkAnswer(df2, Seq( + Row("alex", 12000.00, 12000.000000, 1), + Row("amy", 10000.00, 10000.000000, 1), + Row("cathy", 9000.00, 9000.000000, 1), + Row("david", 10000.00, 10000.000000, 1), + Row("jen", 12000.00, 12000.000000, 1))) + } + + test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") { + val df = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df, false) + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]") + checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5))) + + val df2 = spark.read + .option("partitionColumn", "dept") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .groupBy($"name") + .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count")) + checkAggregateRemoved(df, false) + checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]") + checkAnswer(df2, Seq( + Row("alex", 12000.00, 12000.000000, 1), + Row("amy", 10000.00, 10000.000000, 1), + Row("cathy", 9000.00, 9000.000000, 1), + Row("david", 10000.00, 10000.000000, 1), + Row("jen", 12000.00, 12000.000000, 1))) + } + + test("SPARK-37895: JDBC push down with 
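The two tests above contrast complete and partial push-down: with a single read partition the source can answer the whole query, so AVG is pushed as-is; with several partitions each one returns partial results, so Spark pushes SUM and COUNT and finishes the average itself. A sketch of that contrast, reusing the hypothetical "h2" catalog and implicits from the previous sketch:

    import org.apache.spark.sql.functions.avg

    // numPartitions = 1: the source computes the whole aggregate.
    val complete = spark.read
      .option("partitionColumn", "dept")
      .option("lowerBound", "0")
      .option("upperBound", "2")
      .option("numPartitions", "1")
      .table("h2.test.employee")
      .agg(avg($"SALARY"))
    complete.explain()   // expect PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)] style output

    // numPartitions = 2: only SUM and COUNT are pushed; Spark applies the
    // final aggregate over the per-partition outputs.
    val partial = spark.read
      .option("partitionColumn", "dept")
      .option("lowerBound", "0")
      .option("upperBound", "2")
      .option("numPartitions", "2")
      .table("h2.test.employee")
      .agg(avg($"SALARY"))
    partial.explain()    // expect PushedAggregates: [SUM(SALARY), COUNT(SALARY)]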
delimited special identifiers") { + val df = sql( + """SELECT h2.test.view1.`|col1`, h2.test.view1.`|col2`, h2.test.view2.`|col3` + |FROM h2.test.view1 LEFT JOIN h2.test.view2 + |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin) + checkAnswer(df, Seq.empty[Row]) + } + + test("scan with aggregate push-down: complete push-down aggregate with alias") { + val df = spark.table("h2.test.employee") + .select($"DEPT", $"SALARY".as("mySalary")) + .groupBy($"DEPT") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df) + checkPushedInfo(df, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]") + checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) + + val df2 = spark.table("h2.test.employee") + .select($"DEPT".as("myDept"), $"SALARY".as("mySalary")) + .groupBy($"myDept") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df2) + checkPushedInfo(df2, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]") + checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00))) + } + + test("scan with aggregate push-down: partial push-down aggregate with alias") { + val df = spark.read + .option("partitionColumn", "DEPT") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .select($"NAME", $"SALARY".as("mySalary")) + .groupBy($"NAME") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df, false) + checkPushedInfo(df, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]") + checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00), + Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) + + val df2 = spark.read + .option("partitionColumn", "DEPT") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.employee") + .select($"NAME".as("myName"), $"SALARY".as("mySalary")) + .groupBy($"myName") + .agg(sum($"mySalary").as("total")) + .filter($"total" > 1000) + checkAggregateRemoved(df2, false) + checkPushedInfo(df2, + "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]") + checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00), + Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00))) + } + + test("scan with aggregate push-down: partial push-down AVG with overflow") { + def createDataFrame: DataFrame = spark.read + .option("partitionColumn", "id") + .option("lowerBound", "0") + .option("upperBound", "2") + .option("numPartitions", "2") + .table("h2.test.item") + .agg(avg($"PRICE").as("avg")) + + Seq(true, false).foreach { ansiEnabled => + withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) { + val df = createDataFrame + checkAggregateRemoved(df, false) + df.queryExecution.optimizedPlan.collect { + case _: DataSourceV2ScanRelation => + val expected_plan_fragment = + "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]" + checkKeywordsExistsInExplain(df, expected_plan_fragment) + } + if (ansiEnabled) { + val e = intercept[SparkException] { + df.collect() + } + assert(e.getCause.isInstanceOf[ArithmeticException]) + assert(e.getCause.getMessage.contains("cannot be represented as Decimal") || + e.getCause.getMessage.contains("Overflow in sum of decimals")) + } else { + checkAnswer(df, Seq(Row(null))) + } + } + } + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index efa2773bfd692..79952e5a6c288 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -227,7 +227,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { JdbcDialects.registerDialect(testH2Dialect) val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val m = intercept[org.h2.jdbc.JdbcSQLException] { + val m = intercept[org.h2.jdbc.JdbcSQLSyntaxErrorException] { df.write.option("createTableOptions", "ENGINE tableEngineName") .jdbc(url1, "TEST.CREATETBLOPTS", properties) }.getMessage @@ -326,7 +326,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { test("save errors if wrong user/password combination") { val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val e = intercept[org.h2.jdbc.JdbcSQLException] { + val e = intercept[org.h2.jdbc.JdbcSQLInvalidAuthorizationSpecException] { df.write.format("jdbc") .option("dbtable", "TEST.SAVETEST") .option("url", url1) @@ -427,7 +427,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { // verify the data types of the created table by reading the database catalog of H2 val query = """ - |(SELECT column_name, type_name, character_maximum_length + |(SELECT column_name, data_type, character_maximum_length | FROM information_schema.columns WHERE table_name = 'DBCOLTYPETEST') """.stripMargin val rows = spark.read.jdbc(url1, query, properties).collect() @@ -436,7 +436,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { val typeName = row.getString(1) // For CHAR and VARCHAR, we also compare the max length if (typeName.contains("CHAR")) { - val charMaxLength = row.getInt(2) + val charMaxLength = row.getLong(2) assert(expectedTypes(row.getString(0)) == s"$typeName($charMaxLength)") } else { assert(expectedTypes(row.getString(0)) == typeName) @@ -452,15 +452,18 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { val df = spark.createDataFrame(sparkContext.parallelize(data), schema) // out-of-order - val expected1 = Map("id" -> "BIGINT", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + val expected1 = + Map("id" -> "BIGINT", "first#name" -> "CHARACTER VARYING(123)", "city" -> "CHARACTER(20)") testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), id BIGINT, city CHAR(20)", expected1) // partial schema - val expected2 = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)") + val expected2 = + Map("id" -> "INTEGER", "first#name" -> "CHARACTER VARYING(123)", "city" -> "CHARACTER(20)") testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), city CHAR(20)", expected2) withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { // should still respect the original column names - val expected = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CLOB") + val expected = Map("id" -> "INTEGER", "first#name" -> "CHARACTER VARYING(123)", + "city" -> "CHARACTER LARGE OBJECT(9223372036854775807)") testUserSpecifiedColTypes(df, "`FiRsT#NaMe` VARCHAR(123)", expected) } @@ -470,7 +473,9 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter { StructField("First#Name", StringType) :: StructField("city", StringType) :: Nil) val df = spark.createDataFrame(sparkContext.parallelize(data), schema) - val expected = 
Map("id" -> "INTEGER", "First#Name" -> "VARCHAR(123)", "city" -> "CLOB") + val expected = + Map("id" -> "INTEGER", "First#Name" -> "CHARACTER VARYING(123)", + "city" -> "CHARACTER LARGE OBJECT(9223372036854775807)") testUserSpecifiedColTypes(df, "`First#Name` VARCHAR(123)", expected) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 0e62be40607a1..ba0b599f2245d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -22,6 +22,7 @@ import java.net.URI import java.nio.file.Files import java.util.{Locale, UUID} +import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.util.control.NonFatal @@ -459,7 +460,9 @@ private[sql] trait SQLTestUtilsBase */ def getLocalDirSize(file: File): Long = { assert(file.isDirectory) - file.listFiles.filter(f => DataSourceUtils.isDataFile(f.getName)).map(_.length).sum + Files.walk(file.toPath).iterator().asScala + .filter(p => Files.isRegularFile(p) && DataSourceUtils.isDataFile(p.getFileName.toString)) + .map(_.toFile.length).sum } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index f1dcddd806525..dd3dabb82cc67 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 82bdeaf4e6608..e6bb5d5f49dd2 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../../pom.xml diff --git a/streaming/pom.xml b/streaming/pom.xml index 3a0f9a2f00c71..91db9435a87d4 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index c2b09a8508e2a..2d5830ad83d1c 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.12 - 3.2.0-kylin-4.x-r60 + 3.2.0-kylin-4.x-r61 ../pom.xml