$default$6"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.numericRDDToDoubleRDDFunctions"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intToIntWritable"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intWritableConverter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.writableWritableConverter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToAsyncRDDActions"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.boolToBoolWritable"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longToLongWritable"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleWritableConverter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToOrderedRDDFunctions"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatWritableConverter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.booleanWritableConverter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringToText"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleToDoubleWritable"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesWritableConverter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToSequenceFileRDDFunctions"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesToBytesWritable"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longWritableConverter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringWritableConverter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatToFloatWritable"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions$default$4"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addOnCompleteCallback"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.runningLocally"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.attemptId"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.defaultMinSplits"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.runJob"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.runJob"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.tachyonFolderName"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.initLocalProperties"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearJars"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearFiles"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.this"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith$default$2"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.toArray"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith$default$2"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.filterWith"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.foreachWith"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit$default$2"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.SequenceFileRDDFunctions.this"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.splits"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.externalBlockStoreFolderName"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockManager"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.storage.ExternalBlockStore")
- ) ++ Seq(
- // SPARK-12149 Added new fields to ExecutorSummary
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this")
- ) ++
- // SPARK-12665 Remove deprecated and unused classes
- Seq(
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.graphx.GraphKryoRegistrator"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$")
- ) ++ Seq(
- // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge")
- ) ++ Seq(
- // SPARK-12510 Refactor ActorReceiver to support Java
- ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver")
- ) ++ Seq(
- // SPARK-12895 Implement TaskMetrics using accumulators
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.internalMetricsToAccumulators"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectInternalAccumulators"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.collectAccumulators")
- ) ++ Seq(
- // SPARK-12896 Send only accumulator updates to driver, not TaskMetrics
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.Accumulable.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Accumulator.this"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.Accumulator.initialValue")
- ) ++ Seq(
- // SPARK-12692 Scala style: Fix the style violation (Space before "," or ":")
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log_"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkSink.org$apache$spark$streaming$flume$sink$Logging$$log__="),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log_"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.SparkAvroCallbackHandler.org$apache$spark$streaming$flume$sink$Logging$$log__="),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log__="),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$log_"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.Logging.org$apache$spark$streaming$flume$sink$Logging$$_log_="),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log_"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.flume.sink.TransactionProcessor.org$apache$spark$streaming$flume$sink$Logging$$log__=")
- ) ++ Seq(
- // SPARK-12689 Migrate DDL parsing to the newly absorbed parser
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLParser"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.execution.datasources.DDLException"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.ddlParser")
- ) ++ Seq(
- // SPARK-7799 Add "streaming-akka" project
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$6"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.zeromq.ZeroMQUtils.createStream$default$5"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$4"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream$default$3"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.actorStream"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.actorStream"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.streaming.zeromq.ZeroMQReceiver"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver$Supervisor")
- ) ++ Seq(
- // SPARK-12348 Remove deprecated Streaming APIs.
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.dstream.DStream.foreach"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions$default$4"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.awaitTermination"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.StreamingContext.networkStream"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.api.java.JavaStreamingContextFactory"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.awaitTermination"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.sc"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreachRDD"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.api.java.JavaDStreamLike.foreach"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.streaming.api.java.JavaStreamingContext.getOrCreate")
- ) ++ Seq(
- // SPARK-12847 Remove StreamingListenerBus and post all Streaming events to the same thread as Spark events
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.AsynchronousListenerBus")
- ) ++ Seq(
- // SPARK-11622 Make LibSVMRelation extends HadoopFsRelation and Add LibSVMOutputWriter
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.source.libsvm.DefaultSource"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.source.libsvm.DefaultSource.createRelation")
- ) ++ Seq(
- // SPARK-6363 Make Scala 2.11 the default Scala version
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.cleanup"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metadataCleaner"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnDriverEndpoint"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.cluster.YarnSchedulerBackend$YarnSchedulerEndpoint")
- ) ++ Seq(
- // SPARK-7889
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.org$apache$spark$deploy$history$HistoryServer$@tachSparkUI"),
- // SPARK-13296
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.UDFRegistration.register"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedPythonFunction"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.UserDefinedFunction$")
- ) ++ Seq(
- // SPARK-12995 Remove deprecated APIs in graphx
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.lib.SVDPlusPlus.runSVDPlusPlus"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.Graph.mapReduceTriplets$default$3"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.graphx.impl.GraphImpl.mapReduceTriplets")
- ) ++ Seq(
- // SPARK-13426 Remove the support of SIMR
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkMasterRegex.SIMR_REGEX")
- ) ++ Seq(
- // SPARK-13413 Remove SparkContext.metricsSystem/schedulerBackend_ setter
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.metricsSystem"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.schedulerBackend_=")
- ) ++ Seq(
- // SPARK-13220 Deprecate yarn-client and yarn-cluster mode
- ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler")
- ) ++ Seq(
- // SPARK-13465 TaskContext.
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addTaskFailureListener")
- ) ++ Seq (
- // SPARK-7729 Executor which has been killed should also be displayed on Executor Tab
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.ExecutorSummary.this")
- ) ++ Seq(
- // SPARK-13526 Move SQLContext per-session states to new class
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.sql.UDFRegistration.this")
- ) ++ Seq(
- // [SPARK-13486][SQL] Move SQLConf into an internal package
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
- ) ++ Seq(
- //SPARK-11011 UserDefinedType serialization should be strongly typed
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
- // SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
- ) ++ Seq(
- // [SPARK-13244][SQL] Migrates DataFrame to Dataset
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.tables"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.sql"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.baseRelationToDataFrame"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.table"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrame.apply"),
-
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrame$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.LegacyFunctions"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.DataFrameHolder$"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.localSeqToDataFrameHolder"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.stringRddToDataFrameHolder"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.rddToDataFrameHolder"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.longRddToDataFrameHolder"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLImplicits.intRddToDataFrameHolder"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.GroupedDataset"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.subtract"),
-
- // [SPARK-14451][SQL] Move encoder definition into Aggregator interface
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.toColumn"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.bufferEncoder"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.expressions.Aggregator.outputEncoder"),
-
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MultilabelMetrics.this"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionSummary.predictions")
- ) ++ Seq(
- // [SPARK-13686][MLLIB][STREAMING] Add a constructor parameter `reqParam` to (Streaming)LinearRegressionWithSGD
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.regression.LinearRegressionWithSGD.this")
- ) ++ Seq(
- // SPARK-15250 Remove deprecated json API in DataFrameReader
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameReader.json")
- ) ++ Seq(
- // SPARK-13920: MIMA checks should apply to @Experimental and @DeveloperAPI APIs
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineCombinersByKey"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.Aggregator.combineValuesByKey"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.run"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.runJob"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ComplexFutureAction.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.actorSystem"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.cacheManager"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getConfigurationFromJobContext"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.getTaskAttemptIDFromTaskAttemptContext"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkHadoopUtil.newConfiguration"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.bytesReadCallback_="),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.canEqual"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.copy"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productArity"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productElement"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productIterator"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.productPrefix"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.setBytesReadCallback"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.updateBytesRead"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.canEqual"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.copy"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productArity"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productElement"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productIterator"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.productPrefix"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decFetchWaitTime"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decLocalBlocksFetched"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRecordsRead"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBlocksFetched"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleReadMetrics.decRemoteBytesRead"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleBytesWritten"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleRecordsWritten"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.decShuffleWriteTime"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleBytesWritten"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleRecordsWritten"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.incShuffleWriteTime"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.ShuffleWriteMetrics.setShuffleRecordsWritten"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.PCAModel.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithContext"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.AccumulableInfo.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate.taskMetrics"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.TaskInfo.attempt"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.ExperimentalMethods.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUDF"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.callUdf"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.cumeDist"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.denseRank"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.inputFileName"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.isNaN"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.percentRank"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.rowNumber"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.functions.sparkPartitionId"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.apply"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.copy"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.externalBlockStoreSize"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.BlockStatus.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsed"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatus.offHeapUsedByRdd"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.StorageStatusListener.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.streaming.scheduler.BatchInfo.streamIdToNumRecords"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.storageStatusList"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.storage.StorageListener.storageStatusList"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.apply"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.copy"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ExceptionFailure.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.InputMetrics.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.executor.OutputMetrics.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PredictionModel.transformImpl"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.extractLabeledPoints"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Predictor.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionTrainingSummary.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.GBTClassifier.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassifier.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayes.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRest.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.OneVsRestModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.RandomForestClassifier.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeans.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logLikelihood"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.logPerplexity"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.clustering.LDAModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator.evaluate"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.RegressionEvaluator.evaluate"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Binarizer.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Bucketizer.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelector.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizer.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.CountVectorizerModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.HashingTF.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDF.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IDFModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.IndexToString.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Interaction.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScaler.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.OneHotEncoder.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCA.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.PCAModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.QuantileDiscretizer.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormula.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.RFormulaModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.SQLTransformer.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StopWordsRemover.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexer.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StringIndexerModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorAssembler.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexer.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorIndexerModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.VectorSlicer.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2Vec.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALS.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegression.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.GBTRegressor.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.extractWeightedLabeledPoints"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegression.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.extractWeightedLabeledPoints"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegression.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionTrainingSummary.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressor.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplit.fit"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.TrainValidationSplitModel.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.BinaryClassificationMetrics.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.MulticlassMetrics.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.evaluation.RegressionMetrics.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.DataFrameWriter.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.broadcast"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.functions.callUDF"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.sources.InsertableRelation.insert"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.fMeasureByThreshold"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.pr"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.precisionByThreshold"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.predictions"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.recallByThreshold"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.BinaryLogisticRegressionSummary.roc"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.describeTopics"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.getVectors"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.itemFactors"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.recommendation.ALSModel.userFactors"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.predictions"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.residuals"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.name"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.scheduler.AccumulableInfo.value"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.drop"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.fill"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameNaFunctions.replace"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.jdbc"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.json"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.load"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.orc"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.parquet"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.table"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameReader.text"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.crosstab"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.freqItems"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.DataFrameStatFunctions.sampleBy"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.createExternalTable"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.emptyDataFrame"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.range"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.functions.udf"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.scheduler.JobLogger"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorHelper"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.ActorSupervisorStrategy$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.streaming.receiver.Statistics$"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.InputMetrics$"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.executor.OutputMetrics$"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.functions$"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Estimator.fit"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Predictor.train"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.Transformer.transform"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.evaluate"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert")
- ) ++ Seq(
- // [SPARK-13926] Automatically use Kryo serializer when shuffling RDDs with simple types
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ShuffleDependency.this"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ShuffleDependency.serializer"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.Serializer$")
- ) ++ Seq(
- // SPARK-13927: add row/column iterator to local matrices
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"),
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter")
- ) ++ Seq(
- // SPARK-13948: MiMa Check should catch if the visibility change to `private`
- // TODO(josh): Some of these may be legitimate incompatibilities; we should follow up before the 2.0.0 release
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.Dataset.toDS"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.askTimeout"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.RpcUtils.lookupTimeout"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.UnaryTransformer.transform"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.DecisionTreeClassifier.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.DecisionTreeRegressor.train"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.groupBy"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.groupBy"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.select"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.Dataset.toDF"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.Logging.initializeLogIfNecessary"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerEvent.logEvent"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.OutputWriterFactory.newInstance")
- ) ++ Seq(
- // [SPARK-14014] Replace existing analysis.Catalog with SessionCatalog
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.this")
- ) ++ Seq(
- // [SPARK-13928] Move org.apache.spark.Logging into org.apache.spark.internal.Logging
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Logging"),
- (problem: Problem) => problem match {
- case MissingTypesProblem(_, missing)
- if missing.map(_.fullName).sameElements(Seq("org.apache.spark.Logging")) => false
- case _ => true
- }
- ) ++ Seq(
- // [SPARK-13990] Automatically pick serializer when caching RDDs
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.uploadBlock")
- ) ++ Seq(
- // [SPARK-14089][CORE][MLLIB] Remove methods that has been deprecated since 1.1, 1.2, 1.3, 1.4, and 1.5
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.getThreadLocal"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeReduce"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.rdd.RDDFunctions.treeAggregate"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.tree.configuration.Strategy.defaultStategy"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLibSVMFile"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.saveLabeledData"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.util.MLUtils.loadLabeledData"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.optimization.LBFGS.setMaxNumIterations"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.setScoreCol")
- ) ++ Seq(
- // [SPARK-14205][SQL] remove trait Queryable
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.sql.Dataset")
- ) ++ Seq(
- // [SPARK-11262][ML] Unit test for gradient, loss layers, memory management
- // for multilayer perceptron.
- // This class is marked as `private`.
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction")
- ) ++ Seq(
- // [SPARK-13674][SQL] Add wholestage codegen support to Sample
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this")
- ) ++ Seq(
- // [SPARK-13430][ML] moved featureCol from LinearRegressionModelSummary to LinearRegressionSummary
- ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.this")
- ) ++ Seq(
- // [SPARK-14437][Core] Use the address that NettyBlockTransferService listens to create BlockManagerId
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.NettyBlockTransferService.this")
- ) ++ Seq(
- // [SPARK-13048][ML][MLLIB] keepLastCheckpoint option for LDA EM optimizer
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.clustering.DistributedLDAModel.this")
- ) ++ Seq(
- // [SPARK-14475] Propagate user-defined context from driver to executors
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.getLocalProperty"),
- // [SPARK-14617] Remove deprecated APIs in TaskMetrics
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.InputMetrics$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.executor.OutputMetrics$"),
- // [SPARK-14628] Simplify task metrics by always tracking read/write metrics
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.readMethod"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.writeMethod")
- ) ++ Seq(
- // SPARK-14628: Always track input/output/shuffle metrics
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.totalBlocksFetched"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.this"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.inputMetrics"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.outputMetrics"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleWriteMetrics"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.shuffleReadMetrics"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.inputMetrics"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.outputMetrics"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleWriteMetrics"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleReadMetrics"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
- ) ++ Seq(
- // SPARK-13643: Move functionality from SQLContext to SparkSession
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.getSchema")
- ) ++ Seq(
- // [SPARK-14407] Hides HadoopFsRelation related data source API into execution package
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriter"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriterFactory")
- ) ++ Seq(
- // SPARK-14734: Add conversions between mllib and ml Vector, Matrix types
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asML"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asML")
- ) ++ Seq(
- // SPARK-14704: Create accumulators in TaskMetrics
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.InputMetrics.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.executor.OutputMetrics.this")
- ) ++ Seq(
- // SPARK-14861: Replace internal usages of SQLContext with SparkSession
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.ml.clustering.LocalLDAModel.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.ml.clustering.DistributedLDAModel.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.ml.clustering.LDAModel.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem](
- "org.apache.spark.ml.clustering.LDAModel.sqlContext"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.sql.Dataset.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem](
- "org.apache.spark.sql.DataFrameReader.this")
- ) ++ Seq(
- // SPARK-14542 configurable buffer size for pipe RDD
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.rdd.RDD.pipe"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.pipe")
- ) ++ Seq(
- // [SPARK-4452][Core]Shuffle data structures can starve others on the same thread for memory
- ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.util.collection.Spillable")
- ) ++ Seq(
- // [SPARK-14952][Core][ML] Remove methods deprecated in 1.6
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.input.PortableDataStream.close"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.weights"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionModel.weights")
- ) ++ Seq(
- // [SPARK-10653] [Core] Remove unnecessary things from SparkEnv
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.sparkFilesDir"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkEnv.blockTransferService")
- ) ++ Seq(
- // SPARK-14654: New accumulator API
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ExceptionFailure$"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.apply"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.metrics"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.copy"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ExceptionFailure.this"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.remoteBlocksFetched"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.totalBlocksFetched"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.executor.ShuffleReadMetrics.localBlocksFetched"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.remoteBlocksFetched"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ShuffleReadMetrics.localBlocksFetched")
- ) ++ Seq(
- // [SPARK-14615][ML] Use the new ML Vector and Matrix in the ML pipeline based algorithms
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.getOldDocConcentration"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.estimatedDocConcentration"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.LDAModel.topicsMatrix"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.clustering.KMeansModel.clusterCenters"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LabelConverter.decodeLabel"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LabelConverter.encodeLabeledPoint"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.weights"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.predict"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.predictRaw"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.raw2probabilityInPlace"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.theta"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.pi"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.NaiveBayesModel.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.probability2prediction"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.predictRaw"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.raw2prediction"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.raw2probabilityInPlace"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.predict"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.coefficients"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.raw2prediction"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.ClassificationModel.predictRaw"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.classification.ClassificationModel.predictRaw"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.ElementwiseProduct.getScalingVec"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.ElementwiseProduct.setScalingVec"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.PCAModel.pc"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.originalMax"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.originalMin"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.MinMaxScalerModel.this"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.Word2VecModel.findSynonyms"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.IDFModel.idf"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.mean"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.this"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.std"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.predict"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.coefficients"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.predictQuantiles"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.AFTSurvivalRegressionModel.this"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.predictions"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.IsotonicRegressionModel.boundaries"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.predict"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.coefficients"),
- ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.regression.LinearRegressionModel.this")
- ) ++ Seq(
- // [SPARK-15290] Move annotations, like @Since / @DeveloperApi, into spark-tags
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.package$"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.package"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Private"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.AlphaComponent"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.Experimental"),
- ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.annotation.DeveloperApi")
- ) ++ Seq(
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.asBreeze"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.asBreeze")
- ) ++ Seq(
- // [SPARK-15914] Binary compatibility is broken since consolidation of Dataset and DataFrame
- // in Spark 2.0. However, source level compatibility is still maintained.
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.load"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jsonRDD"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jsonFile"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.jdbc"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.parquetFile"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.SQLContext.applySchema")
- ) ++ Seq(
- // SPARK-17096: Improve exception string reported through the StreamingQueryListener
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.stackTrace"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryListener#QueryTerminated.this")
- ) ++ Seq(
- // SPARK-17406 limit timeline executor events
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorIdToData"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksActive"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksComplete"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputRecords"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleRead"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksFailed"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToShuffleWrite"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToDuration"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToInputBytes"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToLogUrls"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputBytes"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToOutputRecords"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTotalCores"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToTasksMax"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ui.exec.ExecutorsListener.executorToJvmGCTime")
- ) ++ Seq(
- // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature.
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this")
- ) ++ Seq(
- // [SPARK-17498] StringIndexer enhancement for handling unseen labels
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel")
- ) ++ Seq(
- // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext")
- ) ++ Seq(
- // [SPARK-12221] Add CPU time to metrics
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
- ) ++ Seq(
- // [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"),
- ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"),
- ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"),
- ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"),
- ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
- ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
- ) ++ Seq(
- // [SPARK-21680][ML][MLLIB]optimize Vector compress
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.mllib.linalg.Vector.toSparseWithSize"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Vector.toSparseWithSize")
- ) ++ Seq(
- // [SPARK-3181][ML]Implement huber loss for LinearRegression.
- ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.org$apache$spark$ml$param$shared$HasLoss$_setter_$loss_="),
- ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.getLoss"),
- ProblemFilters.exclude[InheritedNewAbstractMethodProblem]("org.apache.spark.ml.param.shared.HasLoss.loss")
- )
- }
-
def excludes(version: String) = version match {
+ case v if v.startsWith("3.4") => v34excludes
+ case v if v.startsWith("3.3") => v33excludes
case v if v.startsWith("3.2") => v32excludes
- case v if v.startsWith("3.1") => v31excludes
- case v if v.startsWith("3.0") => v30excludes
- case v if v.startsWith("2.4") => v24excludes
- case v if v.startsWith("2.3") => v23excludes
- case v if v.startsWith("2.2") => v22excludes
- case v if v.startsWith("2.1") => v21excludes
- case v if v.startsWith("2.0") => v20excludes
case _ => Seq()
}
}
diff --git a/repl/pom.xml b/repl/pom.xml
index 36d9b0e5e43aa..714fdf9d0d8a5 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.12</artifactId>
- <version>3.2.0-kylin-4.x-r60</version>
+ <version>3.2.0-kylin-4.x-r61</version>
<relativePath>../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml
index 77f4385e277a2..dcd4ceace7fad 100644
--- a/resource-managers/kubernetes/core/pom.xml
+++ b/resource-managers/kubernetes/core/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.12</artifactId>
- <version>3.2.0-kylin-4.x-r60</version>
+ <version>3.2.0-kylin-4.x-r61</version>
<relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml
index 1d12e2ebce1c7..95ea5e12c35bc 100644
--- a/resource-managers/kubernetes/integration-tests/pom.xml
+++ b/resource-managers/kubernetes/integration-tests/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.12</artifactId>
- <version>3.2.0-kylin-4.x-r60</version>
+ <version>3.2.0-kylin-4.x-r61</version>
<relativePath>../../../pom.xml</relativePath>
diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml
index 301462026b190..0c764d83c503a 100644
--- a/resource-managers/mesos/pom.xml
+++ b/resource-managers/mesos/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.12</artifactId>
- <version>3.2.0-kylin-4.x-r60</version>
+ <version>3.2.0-kylin-4.x-r61</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml
index db7e3e03107ec..d049e217637d3 100644
--- a/resource-managers/yarn/pom.xml
+++ b/resource-managers/yarn/pom.xml
@@ -20,7 +20,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.12</artifactId>
- <version>3.2.0-kylin-4.x-r60</version>
+ <version>3.2.0-kylin-4.x-r61</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml
index 6c089f9feb3e3..631edbd8eb3e4 100644
--- a/sql/catalyst/pom.xml
+++ b/sql/catalyst/pom.xml
@@ -22,7 +22,7 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent_2.12</artifactId>
- <version>3.2.0-kylin-4.x-r60</version>
+ <version>3.2.0-kylin-4.x-r61</version>
<relativePath>../../pom.xml</relativePath>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java
index 34f07b12b3666..66e8a431458f9 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingCatalogExtension.java
@@ -20,10 +20,7 @@
import java.util.Map;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException;
-import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
-import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
-import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException;
+import org.apache.spark.sql.catalyst.analysis.*;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
@@ -147,8 +144,10 @@ public void alterNamespace(
}
@Override
- public boolean dropNamespace(String[] namespace) throws NoSuchNamespaceException {
- return asNamespaceCatalog().dropNamespace(namespace);
+ public boolean dropNamespace(
+ String[] namespace,
+ boolean cascade) throws NoSuchNamespaceException, NonEmptyNamespaceException {
+ return asNamespaceCatalog().dropNamespace(namespace, cascade);
}
private TableCatalog asTableCatalog() {
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java
index f70746b612e92..c1a4960068d24 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/SupportsNamespaces.java
@@ -20,6 +20,7 @@
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.catalyst.analysis.NamespaceAlreadyExistsException;
import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
+import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException;
import java.util.Map;
@@ -136,15 +137,20 @@ void alterNamespace(
NamespaceChange... changes) throws NoSuchNamespaceException;
/**
- * Drop a namespace from the catalog, recursively dropping all objects within the namespace.
+ * Drop a namespace from the catalog with cascade mode, recursively dropping all objects
+ * within the namespace if cascade is true.
*
* If the catalog implementation does not support this operation, it may throw
* {@link UnsupportedOperationException}.
*
* @param namespace a multi-part namespace
+ * @param cascade When true, deletes all objects under the namespace
* @return true if the namespace was dropped
* @throws NoSuchNamespaceException If the namespace does not exist (optional)
+ * @throws NonEmptyNamespaceException If the namespace is non-empty and cascade is false
* @throws UnsupportedOperationException If drop is not a supported operation
*/
- boolean dropNamespace(String[] namespace) throws NoSuchNamespaceException;
+ boolean dropNamespace(
+ String[] namespace,
+ boolean cascade) throws NoSuchNamespaceException, NonEmptyNamespaceException;
}
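For catalog implementors, the behavioral change is that a non-cascade drop must now detect a non-empty namespace and refuse it rather than silently clearing it. The following is a minimal sketch of that contract, not part of the patch; the helpers listObjectsIn, dropObject, removeNamespaceMetadata and nonEmptyNamespaceError stand in for a catalog's own storage layer and are hypothetical.

    // Hedged sketch only; the storage helpers below are hypothetical, not Spark APIs.
    @Override
    public boolean dropNamespace(String[] namespace, boolean cascade)
        throws NoSuchNamespaceException, NonEmptyNamespaceException {
      if (!namespaceExists(namespace)) {
        throw new NoSuchNamespaceException(namespace);
      }
      java.util.List<String> objects = listObjectsIn(namespace);   // hypothetical
      if (!cascade && !objects.isEmpty()) {
        // Without cascade, a non-empty namespace must be rejected, not emptied.
        throw nonEmptyNamespaceError(namespace);                   // hypothetical factory
      }
      for (String name : objects) {
        dropObject(namespace, name);                               // hypothetical
      }
      return removeNamespaceMetadata(namespace);                   // hypothetical
    }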
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java
new file mode 100644
index 0000000000000..1419e975f5695
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/SupportsIndex.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.index;
+
+import java.util.Map;
+import java.util.Properties;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.catalyst.analysis.IndexAlreadyExistsException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchIndexException;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * Table methods for working with index
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public interface SupportsIndex extends Table {
+
+ /**
+ * A reserved property to specify the index type.
+ */
+ String PROP_TYPE = "type";
+
+ /**
+ * Creates an index.
+ *
+ * @param indexName the name of the index to be created
+ * @param columns the columns on which index to be created
+ * @param columnsProperties the properties of the columns on which index to be created
+ * @param properties the properties of the index to be created
+ * @throws IndexAlreadyExistsException If the index already exists.
+ */
+ void createIndex(String indexName,
+ NamedReference[] columns,
+ Map<NamedReference, Map<String, String>> columnsProperties,
+ Map<String, String> properties)
+ throws IndexAlreadyExistsException;
+
+ /**
+ * Drops the index with the given name.
+ *
+ * @param indexName the name of the index to be dropped.
+ * @throws NoSuchIndexException If the index does not exist.
+ */
+ void dropIndex(String indexName) throws NoSuchIndexException;
+
+ /**
+ * Checks whether an index exists in this table.
+ *
+ * @param indexName the name of the index
+ * @return true if the index exists, false otherwise
+ */
+ boolean indexExists(String indexName);
+
+ /**
+ * Lists all the indexes in this table.
+ */
+ TableIndex[] listIndexes();
+}
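As a usage illustration for the new interface, a caller holding a connector Table that implements SupportsIndex could create and then list an index roughly as below; the price column, the price_idx name, and the "BTREE" type value are placeholders, since valid index types are connector-specific.

    static void ensurePriceIndex(SupportsIndex table) throws IndexAlreadyExistsException {
      // Throwaway column reference; only fieldNames() is abstract on NamedReference.
      NamedReference priceCol = new NamedReference() {
        @Override public String[] fieldNames() { return new String[]{"price"}; }
        @Override public String toString() { return "price"; }
      };
      if (!table.indexExists("price_idx")) {
        table.createIndex(
            "price_idx",
            new NamedReference[]{priceCol},
            java.util.Collections.emptyMap(),                                       // per-column properties
            java.util.Collections.singletonMap(SupportsIndex.PROP_TYPE, "BTREE"));  // index properties
      }
      for (TableIndex idx : table.listIndexes()) {
        System.out.println(idx.indexName() + " (" + idx.indexType() + ")");
      }
    }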
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java
new file mode 100644
index 0000000000000..977ed8d6c7528
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/index/TableIndex.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.catalog.index;
+
+import java.util.Collections;
+import java.util.Map;
+import java.util.Properties;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+
+/**
+ * Index in a table
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class TableIndex {
+ private String indexName;
+ private String indexType;
+ private NamedReference[] columns;
+ private Map<NamedReference, Properties> columnProperties = Collections.emptyMap();
+ private Properties properties;
+
+ public TableIndex(
+ String indexName,
+ String indexType,
+ NamedReference[] columns,
+ Map<NamedReference, Properties> columnProperties,
+ Properties properties) {
+ this.indexName = indexName;
+ this.indexType = indexType;
+ this.columns = columns;
+ this.columnProperties = columnProperties;
+ this.properties = properties;
+ }
+
+ /**
+ * @return the Index name.
+ */
+ public String indexName() { return indexName; }
+
+ /**
+ * @return the indexType of this Index.
+ */
+ public String indexType() { return indexType; }
+
+ /**
+ * @return the column(s) this Index is on. Could be multi columns (a multi-column index).
+ */
+ public NamedReference[] columns() { return columns; }
+
+ /**
+ * @return the map of column and column property map.
+ */
+ public Map<NamedReference, Properties> columnProperties() { return columnProperties; }
+
+ /**
+ * Returns the index properties.
+ */
+ public Properties properties() { return properties; }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java
new file mode 100644
index 0000000000000..26b97b46fe2ef
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Cast.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions;
+
+import java.io.Serializable;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.types.DataType;
+
+/**
+ * Represents a cast expression in the public logical expression API.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public class Cast implements Expression, Serializable {
+ private Expression expression;
+ private DataType dataType;
+
+ public Cast(Expression expression, DataType dataType) {
+ this.expression = expression;
+ this.dataType = dataType;
+ }
+
+ public Expression expression() { return expression; }
+ public DataType dataType() { return dataType; }
+
+ @Override
+ public Expression[] children() { return new Expression[]{ expression() }; }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java
index 6540c91597582..76dfe73f666cf 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java
@@ -17,6 +17,8 @@
package org.apache.spark.sql.connector.expressions;
+import java.util.Arrays;
+
import org.apache.spark.annotation.Evolving;
/**
@@ -26,8 +28,23 @@
*/
@Evolving
public interface Expression {
+ Expression[] EMPTY_EXPRESSION = new Expression[0];
+
/**
* Format the expression as a human readable SQL-like string.
*/
- String describe();
+ default String describe() { return this.toString(); }
+
+ /**
+ * Returns an array of the children of this node. Children should not change.
+ */
+ Expression[] children();
+
+ /**
+ * List of fields or columns that are referenced by this expression.
+ */
+ default NamedReference[] references() {
+ return Arrays.stream(children()).map(e -> e.references())
+ .flatMap(Arrays::stream).distinct().toArray(NamedReference[]::new);
+ }
}
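The practical effect of these defaults is that leaf nodes only have to declare children(), and references() is derived by walking the tree. A small hedged illustration using the Cast node added above and a throwaway NamedReference (the column name and target type are arbitrary):

    NamedReference ts = new NamedReference() {
      @Override public String[] fieldNames() { return new String[]{"ts"}; }
    };
    Cast toLong = new Cast(ts, org.apache.spark.sql.types.DataTypes.LongType);

    // Cast declares children() = {ts}; the default references() flattens children,
    // and NamedReference.references() returns itself, so this prints a single field.
    for (NamedReference ref : toLong.references()) {
      System.out.println(String.join(".", ref.fieldNames()));      // ts
    }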
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
new file mode 100644
index 0000000000000..58082d5ee09c1
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Objects;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder;
+
+/**
+ * The general representation of SQL scalar expressions, which contains the upper-cased
+ * expression name and all the children expressions. Please also see {@link Predicate}
+ * for the supported predicate expressions.
+ *
+ * The currently supported SQL scalar expressions:
+ * <ul>
+ *   <li>Name: {@code +}; SQL semantic: {@code expr1 + expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code -}; SQL semantic: {@code expr1 - expr2} or {@code - expr}; Since version: 3.3.0</li>
+ *   <li>Name: {@code *}; SQL semantic: {@code expr1 * expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code /}; SQL semantic: {@code expr1 / expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code %}; SQL semantic: {@code expr1 % expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code &}; SQL semantic: {@code expr1 & expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code |}; SQL semantic: {@code expr1 | expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code ^}; SQL semantic: {@code expr1 ^ expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code ~}; SQL semantic: {@code ~ expr}; Since version: 3.3.0</li>
+ *   <li>Name: {@code CASE_WHEN}; SQL semantic:
+ *     {@code CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END}; Since version: 3.3.0</li>
+ *   <li>Name: {@code ABS}; SQL semantic: {@code ABS(expr)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code COALESCE}; SQL semantic: {@code COALESCE(expr1, expr2)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code LN}; SQL semantic: {@code LN(expr)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code EXP}; SQL semantic: {@code EXP(expr)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code POWER}; SQL semantic: {@code POWER(expr, number)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code SQRT}; SQL semantic: {@code SQRT(expr)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code FLOOR}; SQL semantic: {@code FLOOR(expr)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code CEIL}; SQL semantic: {@code CEIL(expr)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code WIDTH_BUCKET}; SQL semantic: {@code WIDTH_BUCKET(expr)}; Since version: 3.3.0</li>
+ * </ul>
+ *
+ * Note: SQL semantic conforms ANSI standard, so some expressions are not supported when ANSI off,
+ * including: add, subtract, multiply, divide, remainder, pmod.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public class GeneralScalarExpression implements Expression, Serializable {
+ private String name;
+ private Expression[] children;
+
+ public GeneralScalarExpression(String name, Expression[] children) {
+ this.name = name;
+ this.children = children;
+ }
+
+ public String name() { return name; }
+ public Expression[] children() { return children; }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) return true;
+ if (o == null || getClass() != o.getClass()) return false;
+ GeneralScalarExpression that = (GeneralScalarExpression) o;
+ return Objects.equals(name, that.name) && Arrays.equals(children, that.children);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(name, children);
+ }
+
+ @Override
+ public String toString() {
+ V2ExpressionSQLBuilder builder = new V2ExpressionSQLBuilder();
+ try {
+ return builder.build(this);
+ } catch (Throwable e) {
+ return name + "(" +
+ Arrays.stream(children).map(child -> child.toString()).reduce((a,b) -> a + "," + b) + ")";
+ }
+ }
+}
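For instance, a scalar tree such as ABS(price * quantity) would reach a data source as nested GeneralScalarExpression nodes. The sketch below builds that shape by hand; the two column references are throwaway stubs used only for illustration.

    NamedReference price = new NamedReference() {
      @Override public String[] fieldNames() { return new String[]{"price"}; }
    };
    NamedReference quantity = new NamedReference() {
      @Override public String[] fieldNames() { return new String[]{"quantity"}; }
    };
    Expression product = new GeneralScalarExpression("*", new Expression[]{price, quantity});
    Expression abs = new GeneralScalarExpression("ABS", new Expression[]{product});

    // The default references() walks both levels of children and de-duplicates,
    // yielding the two referenced columns.
    System.out.println(abs.references().length);                   // 2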
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java
index df9e58fa319fd..5e8aeafe74515 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Literal.java
@@ -40,4 +40,7 @@ public interface Literal extends Expression {
* Returns the SQL data type of the literal.
*/
DataType dataType();
+
+ @Override
+ default Expression[] children() { return EMPTY_EXPRESSION; }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java
index 167432fa0e86a..8c0f029a35832 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/NamedReference.java
@@ -32,4 +32,10 @@ public interface NamedReference extends Expression {
* Each string in the returned array represents a field name.
*/
String[] fieldNames();
+
+ @Override
+ default Expression[] children() { return EMPTY_EXPRESSION; }
+
+ @Override
+ default NamedReference[] references() { return new NamedReference[]{ this }; }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java
index 72252457df26e..51401786ca5d7 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/SortOrder.java
@@ -40,4 +40,7 @@ public interface SortOrder extends Expression {
* Returns the null ordering.
*/
NullOrdering nullOrdering();
+
+ @Override
+ default Expression[] children() { return new Expression[]{ expression() }; }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java
index 297205825c6a4..e9ead7fc5fd2a 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java
@@ -34,13 +34,11 @@ public interface Transform extends Expression {
*/
String name();
- /**
- * Returns all field references in the transform arguments.
- */
- NamedReference[] references();
-
/**
* Returns the arguments passed to the transform function.
*/
Expression[] arguments();
+
+ @Override
+ default Expression[] children() { return arguments(); }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java
new file mode 100644
index 0000000000000..d09e5f7ba28a3
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Avg.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+
+/**
+ * An aggregate function that returns the mean of all the values in a group.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class Avg implements AggregateFunc {
+ private final Expression input;
+ private final boolean isDistinct;
+
+ public Avg(Expression column, boolean isDistinct) {
+ this.input = column;
+ this.isDistinct = isDistinct;
+ }
+
+ public Expression column() { return input; }
+ public boolean isDistinct() { return isDistinct; }
+
+ @Override
+ public Expression[] children() { return new Expression[]{ input }; }
+
+ @Override
+ public String toString() {
+ if (isDistinct) {
+ return "AVG(DISTINCT " + input.describe() + ")";
+ } else {
+ return "AVG(" + input.describe() + ")";
+ }
+ }
+}
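Because the input is now typed as Expression rather than NamedReference (the same change applied to Count, Max, Min and Sum in the hunks below), an aggregate can wrap a derived expression instead of only a bare column. A hedged sketch, with price and quantity being NamedReference stubs as in the earlier sketches:

    Expression revenue = new GeneralScalarExpression("*", new Expression[]{price, quantity});
    Avg avgRevenue = new Avg(revenue, false);   // conceptually AVG(price * quantity)
    Avg avgDistinct = new Avg(price, true);     // conceptually AVG(DISTINCT price)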
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java
index 1273886e297bf..c840b29ad2546 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Count.java
@@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
/**
* An aggregate function that returns the number of the specific row in a group.
@@ -27,26 +27,26 @@
*/
@Evolving
public final class Count implements AggregateFunc {
- private final NamedReference column;
+ private final Expression input;
private final boolean isDistinct;
- public Count(NamedReference column, boolean isDistinct) {
- this.column = column;
+ public Count(Expression column, boolean isDistinct) {
+ this.input = column;
this.isDistinct = isDistinct;
}
- public NamedReference column() { return column; }
+ public Expression column() { return input; }
public boolean isDistinct() { return isDistinct; }
+ @Override
+ public Expression[] children() { return new Expression[]{ input }; }
+
@Override
public String toString() {
if (isDistinct) {
- return "COUNT(DISTINCT " + column.describe() + ")";
+ return "COUNT(DISTINCT " + input.describe() + ")";
} else {
- return "COUNT(" + column.describe() + ")";
+ return "COUNT(" + input.describe() + ")";
}
}
-
- @Override
- public String describe() { return this.toString(); }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java
index f566ad164b8ef..ff8639cbd05a2 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/CountStar.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
/**
* An aggregate function that returns the number of rows in a group.
@@ -31,8 +32,8 @@ public CountStar() {
}
@Override
- public String toString() { return "COUNT(*)"; }
+ public Expression[] children() { return EMPTY_EXPRESSION; }
@Override
- public String describe() { return this.toString(); }
+ public String toString() { return "COUNT(*)"; }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
new file mode 100644
index 0000000000000..7016644543447
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.aggregate;
+
+import java.util.Arrays;
+import java.util.stream.Collectors;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+
+/**
+ * The general implementation of {@link AggregateFunc}, which contains the upper-cased function
+ * name, the `isDistinct` flag and all the inputs. Note that Spark cannot push down partial
+ * aggregate with this function to the source, but can only push down the entire aggregate.
+ *
+ * The currently supported SQL aggregate functions:
+ * <ul>
+ *   <li>{@code VAR_POP(input1)}; Since 3.3.0</li>
+ *   <li>{@code VAR_SAMP(input1)}; Since 3.3.0</li>
+ *   <li>{@code STDDEV_POP(input1)}; Since 3.3.0</li>
+ *   <li>{@code STDDEV_SAMP(input1)}; Since 3.3.0</li>
+ *   <li>{@code COVAR_POP(input1, input2)}; Since 3.3.0</li>
+ *   <li>{@code COVAR_SAMP(input1, input2)}; Since 3.3.0</li>
+ *   <li>{@code CORR(input1, input2)}; Since 3.3.0</li>
+ * </ul>
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class GeneralAggregateFunc implements AggregateFunc {
+ private final String name;
+ private final boolean isDistinct;
+ private final Expression[] children;
+
+ public String name() { return name; }
+ public boolean isDistinct() { return isDistinct; }
+
+ public GeneralAggregateFunc(String name, boolean isDistinct, Expression[] children) {
+ this.name = name;
+ this.isDistinct = isDistinct;
+ this.children = children;
+ }
+
+ @Override
+ public Expression[] children() { return children; }
+
+ @Override
+ public String toString() {
+ String inputsString = Arrays.stream(children)
+ .map(Expression::describe)
+ .collect(Collectors.joining(", "));
+ if (isDistinct) {
+ return name + "(DISTINCT " + inputsString + ")";
+ } else {
+ return name + "(" + inputsString + ")";
+ }
+ }
+}
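These are plain value objects; a two-argument correlation aggregate, for example, would be represented as below (hedged sketch, with x and y being NamedReference stubs as before). Because Spark cannot split this into a partial aggregate, the source either takes the whole CORR or none of it.

    // CORR(x, y): upper-cased name, isDistinct = false, and the two inputs as children.
    GeneralAggregateFunc corr =
        new GeneralAggregateFunc("CORR", false, new Expression[]{x, y});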
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java
index ed07cc9e32187..089d2bd751763 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Max.java
@@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
/**
* An aggregate function that returns the maximum value in a group.
@@ -27,15 +27,15 @@
*/
@Evolving
public final class Max implements AggregateFunc {
- private final NamedReference column;
+ private final Expression input;
- public Max(NamedReference column) { this.column = column; }
+ public Max(Expression column) { this.input = column; }
- public NamedReference column() { return column; }
+ public Expression column() { return input; }
@Override
- public String toString() { return "MAX(" + column.describe() + ")"; }
+ public Expression[] children() { return new Expression[]{ input }; }
@Override
- public String describe() { return this.toString(); }
+ public String toString() { return "MAX(" + input.describe() + ")"; }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java
index 2e761037746fb..253cdea41dd76 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Min.java
@@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
/**
* An aggregate function that returns the minimum value in a group.
@@ -27,15 +27,15 @@
*/
@Evolving
public final class Min implements AggregateFunc {
- private final NamedReference column;
+ private final Expression input;
- public Min(NamedReference column) { this.column = column; }
+ public Min(Expression column) { this.input = column; }
- public NamedReference column() { return column; }
+ public Expression column() { return input; }
@Override
- public String toString() { return "MIN(" + column.describe() + ")"; }
+ public Expression[] children() { return new Expression[]{ input }; }
@Override
- public String describe() { return this.toString(); }
+ public String toString() { return "MIN(" + input.describe() + ")"; }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java
index 057ebd89f7a19..4e01b92d8c369 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/Sum.java
@@ -18,7 +18,7 @@
package org.apache.spark.sql.connector.expressions.aggregate;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Expression;
/**
* An aggregate function that returns the summation of all the values in a group.
@@ -27,26 +27,26 @@
*/
@Evolving
public final class Sum implements AggregateFunc {
- private final NamedReference column;
+ private final Expression input;
private final boolean isDistinct;
- public Sum(NamedReference column, boolean isDistinct) {
- this.column = column;
+ public Sum(Expression column, boolean isDistinct) {
+ this.input = column;
this.isDistinct = isDistinct;
}
- public NamedReference column() { return column; }
+ public Expression column() { return input; }
public boolean isDistinct() { return isDistinct; }
+ @Override
+ public Expression[] children() { return new Expression[]{ input }; }
+
@Override
public String toString() {
if (isDistinct) {
- return "SUM(DISTINCT " + column.describe() + ")";
+ return "SUM(DISTINCT " + input.describe() + ")";
} else {
- return "SUM(" + column.describe() + ")";
+ return "SUM(" + input.describe() + ")";
}
}
-
- @Override
- public String describe() { return this.toString(); }
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java
new file mode 100644
index 0000000000000..accdd1acd7d0e
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.filter;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+
+/**
+ * A predicate that always evaluates to {@code false}.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class AlwaysFalse extends Predicate implements Literal {
+
+ public AlwaysFalse() {
+ super("ALWAYS_FALSE", new Predicate[]{});
+ }
+
+ public Boolean value() {
+ return false;
+ }
+
+ public DataType dataType() {
+ return DataTypes.BooleanType;
+ }
+
+ public String toString() { return "FALSE"; }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java
new file mode 100644
index 0000000000000..5a14f64b9b7e2
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.filter;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+
+/**
+ * A predicate that always evaluates to {@code true}.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class AlwaysTrue extends Predicate implements Literal {
+
+ public AlwaysTrue() {
+ super("ALWAYS_TRUE", new Predicate[]{});
+ }
+
+ public Boolean value() {
+ return true;
+ }
+
+ public DataType dataType() {
+ return DataTypes.BooleanType;
+ }
+
+ public String toString() { return "TRUE"; }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java
new file mode 100644
index 0000000000000..179a4b3c6349d
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/And.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.filter;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A predicate that evaluates to {@code true} iff both {@code left} and {@code right} evaluate to
+ * {@code true}.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class And extends Predicate {
+
+ public And(Predicate left, Predicate right) {
+ super("AND", new Predicate[]{left, right});
+ }
+
+ public Predicate left() { return (Predicate) children()[0]; }
+ public Predicate right() { return (Predicate) children()[1]; }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java
new file mode 100644
index 0000000000000..d65c9f0b6c3d9
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Not.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.filter;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A predicate that evaluates to {@code true} iff {@code child} is evaluated to {@code false}.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class Not extends Predicate {
+
+ public Not(Predicate child) {
+ super("NOT", new Predicate[]{child});
+ }
+
+ public Predicate child() { return (Predicate) children()[0]; }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java
new file mode 100644
index 0000000000000..7f1717cc7da58
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Or.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.filter;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A predicate that evaluates to {@code true} iff at least one of {@code left} or {@code right}
+ * evaluates to {@code true}.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public final class Or extends Predicate {
+
+ public Or(Predicate left, Predicate right) {
+ super("OR", new Predicate[]{left, right});
+ }
+
+ public Predicate left() { return (Predicate) children()[0]; }
+ public Predicate right() { return (Predicate) children()[1]; }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java
new file mode 100644
index 0000000000000..e58cddc274c5f
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/Predicate.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.expressions.filter;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
+
+/**
+ * The general representation of predicate expressions, which contains the upper-cased expression
+ * name and all the children expressions. You can also use these concrete subclasses for better
+ * type safety: {@link And}, {@link Or}, {@link Not}, {@link AlwaysTrue}, {@link AlwaysFalse}.
+ *
+ * The currently supported predicate expressions:
+ * <ul>
+ *   <li>Name: {@code IS_NULL}; SQL semantic: {@code expr IS NULL}; Since version: 3.3.0</li>
+ *   <li>Name: {@code IS_NOT_NULL}; SQL semantic: {@code expr IS NOT NULL}; Since version: 3.3.0</li>
+ *   <li>Name: {@code STARTS_WITH}; SQL semantic: {@code expr1 LIKE 'expr2%'}; Since version: 3.3.0</li>
+ *   <li>Name: {@code ENDS_WITH}; SQL semantic: {@code expr1 LIKE '%expr2'}; Since version: 3.3.0</li>
+ *   <li>Name: {@code CONTAINS}; SQL semantic: {@code expr1 LIKE '%expr2%'}; Since version: 3.3.0</li>
+ *   <li>Name: {@code IN}; SQL semantic: {@code expr IN (expr1, expr2, ...)}; Since version: 3.3.0</li>
+ *   <li>Name: {@code =}; SQL semantic: {@code expr1 = expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code <>}; SQL semantic: {@code expr1 <> expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code <=>}; SQL semantic: null-safe version of {@code expr1 = expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code <}; SQL semantic: {@code expr1 < expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code <=}; SQL semantic: {@code expr1 <= expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code >}; SQL semantic: {@code expr1 > expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code >=}; SQL semantic: {@code expr1 >= expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code AND}; SQL semantic: {@code expr1 AND expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code OR}; SQL semantic: {@code expr1 OR expr2}; Since version: 3.3.0</li>
+ *   <li>Name: {@code NOT}; SQL semantic: {@code NOT expr}; Since version: 3.3.0</li>
+ *   <li>Name: {@code ALWAYS_TRUE}; SQL semantic: {@code TRUE}; Since version: 3.3.0</li>
+ *   <li>Name: {@code ALWAYS_FALSE}; SQL semantic: {@code FALSE}; Since version: 3.3.0</li>
+ * </ul>
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public class Predicate extends GeneralScalarExpression {
+
+ public Predicate(String name, Expression[] children) {
+ super(name, children);
+ }
+}
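
A minimal sketch (not part of the patch) of how a connector might assemble one of these predicate trees; the `col` and `limit` expressions are assumed to come from the connector's own expression handling:

    import org.apache.spark.sql.connector.expressions.{Expression, NamedReference}
    import org.apache.spark.sql.connector.expressions.filter.{Or, Predicate}

    object PredicateSketch {
      // Builds `col IS NOT NULL OR col >= limit` as a V2 predicate tree.
      def isNotNullOrGreaterEqual(col: NamedReference, limit: Expression): Predicate = {
        val notNull = new Predicate("IS_NOT_NULL", Array[Expression](col))
        val ge = new Predicate(">=", Array[Expression](col, limit))
        new Or(notNull, ge) // name() is "OR"; left()/right() return the two children
      }
    }
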
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
index b46f620d4fedb..27ee534d804ff 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/ScanBuilder.java
@@ -21,9 +21,9 @@
/**
* An interface for building the {@link Scan}. Implementations can mixin SupportsPushDownXYZ
- * interfaces to do operator pushdown, and keep the operator pushdown result in the returned
- * {@link Scan}. When pushing down operators, Spark pushes down filters first, then pushes down
- * aggregates or applies column pruning.
+ * interfaces to do operator push down, and keep the operator push down result in the returned
+ * {@link Scan}. When pushing down operators, the push down order is:
+ * sample -> filter -> aggregate -> limit -> column pruning.
*
* @since 3.0.0
*/
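
A minimal sketch of a builder mixing two of the push-down interfaces added later in this patch; the class and field names are hypothetical, and the individual mix-ins introduced below follow the same shape. Spark offers the operators in the order documented above, so pushTableSample is called before pushLimit:

    import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownLimit, SupportsPushDownTableSample}
    import org.apache.spark.sql.types.StructType

    class SampleAndLimitScanBuilder(schema: StructType) extends ScanBuilder
        with SupportsPushDownTableSample with SupportsPushDownLimit {
      private var sample: Option[(Double, Double, Boolean, Long)] = None
      private var limit: Option[Int] = None

      override def pushTableSample(
          lower: Double, upper: Double, withReplacement: Boolean, seed: Long): Boolean = {
        sample = Some((lower, upper, withReplacement, seed)); true
      }

      override def pushLimit(n: Int): Boolean = { limit = Some(n); true }

      override def build(): Scan = new Scan {
        override def readSchema(): StructType = schema // a real Scan would honor sample/limit
      }
    }
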
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
index 3e643b5493310..4d88ec19c897b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownAggregates.java
@@ -22,18 +22,20 @@
/**
* A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
- * push down aggregates. Spark assumes that the data source can't fully complete the
- * grouping work, and will group the data source output again. For queries like
- * "SELECT min(value) AS m FROM t GROUP BY key", after pushing down the aggregate
- * to the data source, the data source can still output data with duplicated keys, which is OK
- * as Spark will do GROUP BY key again. The final query plan can be something like this:
+ * push down aggregates.
+ *
+ * If the data source can't fully complete the grouping work, then
+ * {@link #supportCompletePushDown(Aggregation)} should return false, and Spark will group the data
+ * source output again. For queries like "SELECT min(value) AS m FROM t GROUP BY key", after
+ * pushing down the aggregate to the data source, the data source can still output data with
+ * duplicated keys, which is OK as Spark will do GROUP BY key again. The final query plan can be
+ * something like this:
*
- * Aggregate [key#1], [min(min(value)#2) AS m#3]
- * +- RelationV2[key#1, min(value)#2]
+ * Aggregate [key#1], [min(min_value#2) AS m#3]
+ * +- RelationV2[key#1, min_value#2]
*
* Similarly, if there is no grouping expression, the data source can still output more than one
* rows.
- *
*
* When pushing down operators, Spark pushes down filters to the data source first, then push down
* aggregates or apply column pruning. Depends on data source implementation, aggregates may or
@@ -45,11 +47,21 @@
@Evolving
public interface SupportsPushDownAggregates extends ScanBuilder {
+ /**
+ * Whether the datasource supports complete aggregation push-down. Spark will do grouping again
+ * if this method returns false.
+ *
+ * @param aggregation Aggregation in SQL statement.
+ * @return true if the aggregation can be pushed down to datasource completely, false otherwise.
+ */
+ default boolean supportCompletePushDown(Aggregation aggregation) { return false; }
+
/**
* Pushes down Aggregation to datasource. The order of the datasource scan output columns should
* be: grouping columns, aggregate columns (in the same order as the aggregate functions in
* the given Aggregation).
*
+ * @param aggregation Aggregation in SQL statement.
* @return true if the aggregation can be pushed down to datasource, false otherwise.
*/
boolean pushAggregation(Aggregation aggregation);
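
A hedged sketch of what an implementation could look like; the aggregate classes (Aggregation, Min, Max, CountStar) are assumed from org.apache.spark.sql.connector.expressions.aggregate, and the builder name is hypothetical. Complete push-down is only claimed when there is no grouping, so Spark can skip its own final aggregation:

    import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, CountStar, Max, Min}
    import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates}
    import org.apache.spark.sql.types.StructType

    class AggPushDownScanBuilder(schema: StructType)
        extends ScanBuilder with SupportsPushDownAggregates {
      private var pushed: Option[Aggregation] = None

      override def supportCompletePushDown(agg: Aggregation): Boolean =
        agg.groupByColumns().isEmpty && agg.aggregateExpressions().forall {
          case _: Min | _: Max | _: CountStar => true
          case _ => false
        }

      override def pushAggregation(agg: Aggregation): Boolean = {
        if (supportCompletePushDown(agg)) { pushed = Some(agg); true } else false
      }

      override def build(): Scan = new Scan {
        override def readSchema(): StructType = schema // would expose the aggregated columns
      }
    }
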
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
new file mode 100644
index 0000000000000..fa6447bc068d5
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownLimit.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down LIMIT. Please note that the combination of LIMIT with other operations
+ * such as AGGREGATE, GROUP BY, SORT BY, CLUSTER BY, DISTRIBUTE BY, etc. is NOT pushed down.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public interface SupportsPushDownLimit extends ScanBuilder {
+
+ /**
+ * Pushes down LIMIT to the data source.
+ */
+ boolean pushLimit(int limit);
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
new file mode 100644
index 0000000000000..3630feb4680ea
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTableSample.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down SAMPLE.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public interface SupportsPushDownTableSample extends ScanBuilder {
+
+ /**
+ * Pushes down SAMPLE to the data source.
+ */
+ boolean pushTableSample(
+ double lowerBound,
+ double upperBound,
+ boolean withReplacement,
+ long seed);
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java
new file mode 100644
index 0000000000000..cba1592c4fa14
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownTopN.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.SortOrder;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down top N (queries with ORDER BY ... LIMIT n). Please note that the combination of top N
+ * with other operations such as AGGREGATE, GROUP BY, CLUSTER BY, DISTRIBUTE BY, etc.
+ * is NOT pushed down.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public interface SupportsPushDownTopN extends ScanBuilder {
+
+ /**
+ * Pushes down top N to the data source.
+ */
+ boolean pushTopN(SortOrder[] orders, int limit);
+
+ /**
+ * Whether the top N is partially pushed or not. If it returns true, then Spark will do top N
+ * again. This method will only be called when {@link #pushTopN} returns true.
+ */
+ default boolean isPartiallyPushed() { return true; }
+}
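
A hedged sketch for a source that can fully sort and limit on its own, so isPartiallyPushed returns false and Spark skips its own top-N step; names are hypothetical:

    import org.apache.spark.sql.connector.expressions.SortOrder
    import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownTopN}
    import org.apache.spark.sql.types.StructType

    class TopNScanBuilder(schema: StructType) extends ScanBuilder with SupportsPushDownTopN {
      private var topN: Option[(Array[SortOrder], Int)] = None

      override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
        topN = Some((orders, limit)); true
      }

      // The source returns exactly the ordered, limited rows, so no Spark-side top N is needed.
      override def isPartiallyPushed(): Boolean = false

      override def build(): Scan = new Scan {
        override def readSchema(): StructType = schema
      }
    }
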
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java
new file mode 100644
index 0000000000000..1fec939aeb474
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/read/SupportsPushDownV2Filters.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.read;
+
+import org.apache.spark.annotation.Evolving;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+
+/**
+ * A mix-in interface for {@link ScanBuilder}. Data sources can implement this interface to
+ * push down V2 {@link Predicate} to the data source and reduce the size of the data to be read.
+ * Please note that this interface is preferred over {@link SupportsPushDownFilters}, which uses
+ * V1 {@link org.apache.spark.sql.sources.Filter} and is less efficient due to the
+ * internal -> external data conversion.
+ *
+ * @since 3.3.0
+ */
+@Evolving
+public interface SupportsPushDownV2Filters extends ScanBuilder {
+
+ /**
+ * Pushes down predicates, and returns predicates that need to be evaluated after scanning.
+ *
+ * Rows should be returned from the data source if and only if all of the predicates match.
+ * That is, predicates must be interpreted as ANDed together.
+ */
+ Predicate[] pushPredicates(Predicate[] predicates);
+
+ /**
+ * Returns the predicates that are pushed to the data source via
+ * {@link #pushPredicates(Predicate[])}.
+ *
+ * There are 3 kinds of predicates:
+ *
+ * 1. pushable predicates which don't need to be evaluated again after scanning.
+ * 2. pushable predicates which still need to be evaluated after scanning, e.g. parquet row
+ *    group predicates.
+ * 3. non-pushable predicates.
+ *
+ * Both case 1 and 2 should be considered as pushed predicates and should be returned
+ * by this method.
+ *
+ * It's possible that there are no predicates in the query and
+ * {@link #pushPredicates(Predicate[])} is never called;
+ * an empty array should be returned in this case.
+ */
+ Predicate[] pushedPredicates();
+}
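
A hedged sketch: the builder accepts only predicate names it understands, hands the rest back to Spark for post-scan evaluation, and reports what was pushed; the class name and supported set are hypothetical:

    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownV2Filters}
    import org.apache.spark.sql.types.StructType

    class V2FilterScanBuilder(schema: StructType)
        extends ScanBuilder with SupportsPushDownV2Filters {
      private val supportedNames = Set("=", ">", ">=", "<", "<=", "IS_NULL", "IS_NOT_NULL")
      private var pushed: Array[Predicate] = Array.empty

      override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
        val (supported, unsupported) = predicates.partition(p => supportedNames.contains(p.name()))
        pushed = supported
        unsupported // these still need to be evaluated by Spark after scanning
      }

      override def pushedPredicates(): Array[Predicate] = pushed

      override def build(): Scan = new Scan {
        override def readSchema(): StructType = schema
      }
    }
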
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
new file mode 100644
index 0000000000000..c9dfa2003e3c1
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
@@ -0,0 +1,231 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector.util;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import org.apache.spark.sql.connector.expressions.Cast;
+import org.apache.spark.sql.connector.expressions.Expression;
+import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
+import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.types.DataType;
+
+/**
+ * The builder to generate SQL from V2 expressions.
+ */
+public class V2ExpressionSQLBuilder {
+
+ public String build(Expression expr) {
+ if (expr instanceof Literal) {
+ return visitLiteral((Literal) expr);
+ } else if (expr instanceof NamedReference) {
+ return visitNamedReference((NamedReference) expr);
+ } else if (expr instanceof Cast) {
+ Cast cast = (Cast) expr;
+ return visitCast(build(cast.expression()), cast.dataType());
+ } else if (expr instanceof GeneralScalarExpression) {
+ GeneralScalarExpression e = (GeneralScalarExpression) expr;
+ String name = e.name();
+ switch (name) {
+ case "IN": {
+ List<String> children =
+ Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
+ return visitIn(children.get(0), children.subList(1, children.size()));
+ }
+ case "IS_NULL":
+ return visitIsNull(build(e.children()[0]));
+ case "IS_NOT_NULL":
+ return visitIsNotNull(build(e.children()[0]));
+ case "STARTS_WITH":
+ return visitStartsWith(build(e.children()[0]), build(e.children()[1]));
+ case "ENDS_WITH":
+ return visitEndsWith(build(e.children()[0]), build(e.children()[1]));
+ case "CONTAINS":
+ return visitContains(build(e.children()[0]), build(e.children()[1]));
+ case "=":
+ case "<>":
+ case "<=>":
+ case "<":
+ case "<=":
+ case ">":
+ case ">=":
+ return visitBinaryComparison(
+ name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
+ case "+":
+ case "*":
+ case "/":
+ case "%":
+ case "&":
+ case "|":
+ case "^":
+ return visitBinaryArithmetic(
+ name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
+ case "-":
+ if (e.children().length == 1) {
+ return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
+ } else {
+ return visitBinaryArithmetic(
+ name, inputToSQL(e.children()[0]), inputToSQL(e.children()[1]));
+ }
+ case "AND":
+ return visitAnd(name, build(e.children()[0]), build(e.children()[1]));
+ case "OR":
+ return visitOr(name, build(e.children()[0]), build(e.children()[1]));
+ case "NOT":
+ return visitNot(build(e.children()[0]));
+ case "~":
+ return visitUnaryArithmetic(name, inputToSQL(e.children()[0]));
+ case "ABS":
+ case "COALESCE":
+ case "LN":
+ case "EXP":
+ case "POWER":
+ case "SQRT":
+ case "FLOOR":
+ case "CEIL":
+ case "WIDTH_BUCKET":
+ return visitSQLFunction(name,
+ Arrays.stream(e.children()).map(c -> build(c)).toArray(String[]::new));
+ case "CASE_WHEN": {
+ List<String> children =
+ Arrays.stream(e.children()).map(c -> build(c)).collect(Collectors.toList());
+ return visitCaseWhen(children.toArray(new String[e.children().length]));
+ }
+ // TODO supports other expressions
+ default:
+ return visitUnexpectedExpr(expr);
+ }
+ } else {
+ return visitUnexpectedExpr(expr);
+ }
+ }
+
+ protected String visitLiteral(Literal literal) {
+ return literal.toString();
+ }
+
+ protected String visitNamedReference(NamedReference namedRef) {
+ return namedRef.toString();
+ }
+
+ protected String visitIn(String v, List<String> list) {
+ if (list.isEmpty()) {
+ return "CASE WHEN " + v + " IS NULL THEN NULL ELSE FALSE END";
+ }
+ return v + " IN (" + list.stream().collect(Collectors.joining(", ")) + ")";
+ }
+
+ protected String visitIsNull(String v) {
+ return v + " IS NULL";
+ }
+
+ protected String visitIsNotNull(String v) {
+ return v + " IS NOT NULL";
+ }
+
+ protected String visitStartsWith(String l, String r) {
+ // Remove quotes at the beginning and end.
+ // e.g. converts "'str'" to "str".
+ String value = r.substring(1, r.length() - 1);
+ return l + " LIKE '" + value + "%'";
+ }
+
+ protected String visitEndsWith(String l, String r) {
+ // Remove quotes at the beginning and end.
+ // e.g. converts "'str'" to "str".
+ String value = r.substring(1, r.length() - 1);
+ return l + " LIKE '%" + value + "'";
+ }
+
+ protected String visitContains(String l, String r) {
+ // Remove quotes at the beginning and end.
+ // e.g. converts "'str'" to "str".
+ String value = r.substring(1, r.length() - 1);
+ return l + " LIKE '%" + value + "%'";
+ }
+
+ private String inputToSQL(Expression input) {
+ if (input.children().length > 1) {
+ return "(" + build(input) + ")";
+ } else {
+ return build(input);
+ }
+ }
+
+ protected String visitBinaryComparison(String name, String l, String r) {
+ switch (name) {
+ case "<=>":
+ return "(" + l + " = " + r + ") OR (" + l + " IS NULL AND " + r + " IS NULL)";
+ default:
+ return l + " " + name + " " + r;
+ }
+ }
+
+ protected String visitBinaryArithmetic(String name, String l, String r) {
+ return l + " " + name + " " + r;
+ }
+
+ protected String visitCast(String l, DataType dataType) {
+ return "CAST(" + l + " AS " + dataType.typeName() + ")";
+ }
+
+ protected String visitAnd(String name, String l, String r) {
+ return "(" + l + ") " + name + " (" + r + ")";
+ }
+
+ protected String visitOr(String name, String l, String r) {
+ return "(" + l + ") " + name + " (" + r + ")";
+ }
+
+ protected String visitNot(String v) {
+ return "NOT (" + v + ")";
+ }
+
+ protected String visitUnaryArithmetic(String name, String v) { return name + v; }
+
+ protected String visitCaseWhen(String[] children) {
+ StringBuilder sb = new StringBuilder("CASE");
+ for (int i = 0; i < children.length; i += 2) {
+ String c = children[i];
+ int j = i + 1;
+ if (j < children.length) {
+ String v = children[j];
+ sb.append(" WHEN ");
+ sb.append(c);
+ sb.append(" THEN ");
+ sb.append(v);
+ } else {
+ sb.append(" ELSE ");
+ sb.append(c);
+ }
+ }
+ sb.append(" END");
+ return sb.toString();
+ }
+
+ protected String visitSQLFunction(String funcName, String[] inputs) {
+ return funcName + "(" + Arrays.stream(inputs).collect(Collectors.joining(", ")) + ")";
+ }
+
+ protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
+ throw new IllegalArgumentException("Unexpected V2 expression: " + expr);
+ }
+}
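
A minimal sketch of how a dialect might reuse this builder: the traversal stays as above, and only CAST rendering is overridden; `toDialectType` is an assumed helper, not part of the patch:

    import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
    import org.apache.spark.sql.types.DataType

    class DialectSQLBuilder(toDialectType: DataType => String) extends V2ExpressionSQLBuilder {
      // Render CAST with a dialect-specific type name instead of Spark's typeName.
      override protected def visitCast(l: String, dataType: DataType): String =
        s"CAST($l AS ${toDialectType(dataType)})"
    }

    // Usage: new DialectSQLBuilder(dt => dt.typeName.toUpperCase).build(predicate)
    // would yield, e.g., `name LIKE 'abc%'` for a STARTS_WITH predicate.
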
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
index 70f821d5f8af0..fb177251a7306 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala
@@ -78,3 +78,6 @@ class PartitionsAlreadyExistException(message: String) extends AnalysisException
class FunctionAlreadyExistsException(db: String, func: String)
extends AnalysisException(s"Function '$func' already exists in database '$db'")
+
+class IndexAlreadyExistsException(message: String, cause: Option[Throwable] = None)
+ extends AnalysisException(message, cause = cause)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
index ba5a9c618c650..8b0710b2c1f19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala
@@ -95,3 +95,6 @@ class NoSuchPartitionsException(message: String) extends AnalysisException(messa
class NoSuchTempFunctionException(func: String)
extends AnalysisException(s"Temporary function '$func' not found")
+
+class NoSuchIndexException(message: String, cause: Option[Throwable] = None)
+ extends AnalysisException(message, cause = cause)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala
new file mode 100644
index 0000000000000..f3ff28f74fcc3
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NonEmptyException.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
+
+
+/**
+ * Thrown by a catalog when a namespace cannot be dropped because it is not empty. The analyzer
+ * will rethrow the exception as an [[org.apache.spark.sql.AnalysisException]] with the correct
+ * position information.
+ */
+case class NonEmptyNamespaceException(
+ override val message: String,
+ override val cause: Option[Throwable] = None)
+ extends AnalysisException(message, cause = cause) {
+
+ def this(namespace: Array[String]) = {
+ this(s"Namespace '${namespace.quoted}' is non empty.")
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
index 0007d3868eda2..dea7ea0f144bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AliasHelper.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.MultiAlias
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Project}
+import org.apache.spark.sql.types.Metadata
/**
* Helper methods for collecting and replacing aliases.
@@ -86,10 +87,15 @@ trait AliasHelper {
protected def trimNonTopLevelAliases[T <: Expression](e: T): T = {
val res = e match {
case a: Alias =>
+ val metadata = if (a.metadata == Metadata.empty) {
+ None
+ } else {
+ Some(a.metadata)
+ }
a.copy(child = trimAliases(a.child))(
exprId = a.exprId,
qualifier = a.qualifier,
- explicitMetadata = Some(a.metadata),
+ explicitMetadata = metadata,
nonInheritableMetadataKeys = a.nonInheritableMetadataKeys)
case a: MultiAlias =>
a.copy(child = trimAliases(a.child))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 9714a096a69a2..533f7f20b2530 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -69,15 +69,15 @@ case class Average(
case _ => DoubleType
}
- private lazy val sumDataType = child.dataType match {
+ lazy val sumDataType = child.dataType match {
case _ @ DecimalType.Fixed(p, s) => DecimalType.bounded(p + 10, s)
case _: YearMonthIntervalType => YearMonthIntervalType()
case _: DayTimeIntervalType => DayTimeIntervalType()
case _ => DoubleType
}
- private lazy val sum = AttributeReference("sum", sumDataType)()
- private lazy val count = AttributeReference("count", LongType)()
+ lazy val sum = AttributeReference("sum", sumDataType)()
+ lazy val count = AttributeReference("count", LongType)()
override lazy val aggBufferAttributes = sum :: count :: Nil
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index eb040e23290c9..1a57ee83fa3ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -912,56 +912,69 @@ object ColumnPruning extends Rule[LogicalPlan] {
*/
object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
- def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
- _.containsPattern(PROJECT), ruleId) {
- case p1 @ Project(_, p2: Project) =>
- if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
- p1
- } else {
+ def apply(plan: LogicalPlan): LogicalPlan = {
+ val alwaysInline = conf.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
+ plan.transformUpWithPruning(_.containsPattern(PROJECT), ruleId) {
+ case p1 @ Project(_, p2: Project)
+ if canCollapseExpressions(p1.projectList, p2.projectList, alwaysInline) =>
p2.copy(projectList = buildCleanedProjectList(p1.projectList, p2.projectList))
- }
- case p @ Project(_, agg: Aggregate) =>
- if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions) ||
- !canCollapseAggregate(p, agg)) {
- p
- } else {
+ case p @ Project(_, agg: Aggregate)
+ if canCollapseExpressions(p.projectList, agg.aggregateExpressions, alwaysInline) =>
agg.copy(aggregateExpressions = buildCleanedProjectList(
p.projectList, agg.aggregateExpressions))
- }
- case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
+ case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
if isRenaming(l1, l2) =>
- val newProjectList = buildCleanedProjectList(l1, l2)
- g.copy(child = limit.copy(child = p2.copy(projectList = newProjectList)))
- case Project(l1, limit @ LocalLimit(_, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
- val newProjectList = buildCleanedProjectList(l1, l2)
- limit.copy(child = p2.copy(projectList = newProjectList))
- case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) =>
- r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, p.projectList)))
- case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
- s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
- }
-
- private def haveCommonNonDeterministicOutput(
- upper: Seq[NamedExpression], lower: Seq[NamedExpression]): Boolean = {
- val aliases = getAliasMap(lower)
+ val newProjectList = buildCleanedProjectList(l1, l2)
+ g.copy(child = limit.copy(child = p2.copy(projectList = newProjectList)))
+ case Project(l1, limit @ LocalLimit(_, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
+ val newProjectList = buildCleanedProjectList(l1, l2)
+ limit.copy(child = p2.copy(projectList = newProjectList))
+ case Project(l1, r @ Repartition(_, _, p @ Project(l2, _))) if isRenaming(l1, l2) =>
+ r.copy(child = p.copy(projectList = buildCleanedProjectList(l1, p.projectList)))
+ case Project(l1, s @ Sample(_, _, _, _, p2 @ Project(l2, _))) if isRenaming(l1, l2) =>
+ s.copy(child = p2.copy(projectList = buildCleanedProjectList(l1, p2.projectList)))
+ }
+ }
- // Collapse upper and lower Projects if and only if their overlapped expressions are all
- // deterministic.
- upper.exists(_.collect {
- case a: Attribute if aliases.contains(a) => aliases(a).child
- }.exists(!_.deterministic))
+ /**
+ * Check if we can collapse expressions safely.
+ */
+ def canCollapseExpressions(
+ consumers: Seq[Expression],
+ producers: Seq[NamedExpression],
+ alwaysInline: Boolean): Boolean = {
+ canCollapseExpressions(consumers, getAliasMap(producers), alwaysInline)
}
/**
- * A project cannot be collapsed with an aggregate when there are correlated scalar
- * subqueries in the project list, because currently we only allow correlated subqueries
- * in aggregate if they are also part of the grouping expressions. Otherwise the plan
- * after subquery rewrite will not be valid.
+ * Check if we can collapse expressions safely.
*/
- private def canCollapseAggregate(p: Project, a: Aggregate): Boolean = {
- p.projectList.forall(_.collect {
- case s: ScalarSubquery if s.outerAttrs.nonEmpty => s
- }.isEmpty)
+ def canCollapseExpressions(
+ consumers: Seq[Expression],
+ producerMap: Map[Attribute, Expression],
+ alwaysInline: Boolean = false): Boolean = {
+ // We can only collapse expressions if all input expressions meet the following criteria:
+ // - The input is deterministic.
+ // - The input is only consumed once OR the underlying input expression is cheap.
+ consumers.flatMap(collectReferences)
+ .groupBy(identity)
+ .mapValues(_.size)
+ .forall {
+ case (reference, count) =>
+ val producer = producerMap.getOrElse(reference, reference)
+ producer.deterministic && (count == 1 || alwaysInline || {
+ val relatedConsumers = consumers.filter(_.references.contains(reference))
+ val extractOnly = relatedConsumers.forall(isExtractOnly(_, reference))
+ shouldInline(producer, extractOnly)
+ })
+ }
+ }
+
+ private def isExtractOnly(expr: Expression, ref: Attribute): Boolean = expr match {
+ case a: Alias => isExtractOnly(a.child, ref)
+ case e: ExtractValue => isExtractOnly(e.children.head, ref)
+ case a: Attribute => a.semanticEquals(ref)
+ case _ => false
}
private def buildCleanedProjectList(
@@ -971,6 +984,34 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
upper.map(replaceAliasButKeepName(_, aliases))
}
+ /**
+ * Check if the given expression is cheap that we can inline it.
+ */
+ private def shouldInline(e: Expression, extractOnlyConsumer: Boolean): Boolean = e match {
+ case _: Attribute | _: OuterReference => true
+ case _ if e.foldable => true
+ // PythonUDF is handled by the rule ExtractPythonUDFs
+ case _: PythonUDF => true
+ // Alias and ExtractValue are very cheap.
+ case _: Alias | _: ExtractValue => e.children.forall(shouldInline(_, extractOnlyConsumer))
+ // These collection create functions are not cheap, but we have optimizer rules that can
+ // optimize them out if they are only consumed by ExtractValue, so we need to allow to inline
+ // them to avoid perf regression. As an example:
+ // Project(s.a, s.b, Project(create_struct(a, b, c) as s, child))
+ // We should collapse these two projects and eventually get Project(a, b, child)
+ case _: CreateNamedStruct | _: CreateArray | _: CreateMap | _: UpdateFields =>
+ extractOnlyConsumer
+ case _ => false
+ }
+
+ /**
+ * Return all the references of the given expression without deduplication, which is different
+ * from `Expression.references`.
+ */
+ private def collectReferences(e: Expression): Seq[Attribute] = e.collect {
+ case a: Attribute => a
+ }
+
private def isRenaming(list1: Seq[NamedExpression], list2: Seq[NamedExpression]): Boolean = {
list1.length == list2.length && list1.zip(list2).forall {
case (e1, e2) if e1.semanticEquals(e2) => true
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index fc12f48ec2a11..f33d137ffd607 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -26,46 +26,32 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
-trait OperationHelper {
- type ReturnType = (Seq[NamedExpression], Seq[Expression], LogicalPlan)
-
- protected def collectAliases(fields: Seq[Expression]): AttributeMap[Expression] =
- AttributeMap(fields.collect {
- case a: Alias => (a.toAttribute, a.child)
- })
-
- protected def substitute(aliases: AttributeMap[Expression])(expr: Expression): Expression = {
- // use transformUp instead of transformDown to avoid dead loop
- // in case of there's Alias whose exprId is the same as its child attribute.
- expr.transformUp {
- case a @ Alias(ref: AttributeReference, name) =>
- aliases.get(ref)
- .map(Alias(_, name)(a.exprId, a.qualifier))
- .getOrElse(a)
-
- case a: AttributeReference =>
- aliases.get(a)
- .map(Alias(_, a.name)(a.exprId, a.qualifier)).getOrElse(a)
- }
- }
-}
+trait OperationHelper extends AliasHelper with PredicateHelper {
+ import org.apache.spark.sql.catalyst.optimizer.CollapseProject.canCollapseExpressions
-/**
- * A pattern that matches any number of project or filter operations on top of another relational
- * operator. All filter operators are collected and their conditions are broken up and returned
- * together with the top project operator.
- * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if
- * necessary.
- */
-object PhysicalOperation extends OperationHelper with PredicateHelper {
+ type ReturnType =
+ (Seq[NamedExpression], Seq[Expression], LogicalPlan)
+ type IntermediateType =
+ (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Alias])
def unapply(plan: LogicalPlan): Option[ReturnType] = {
- val (fields, filters, child, _) = collectProjectsAndFilters(plan)
+ val alwaysInline = SQLConf.get.getConf(SQLConf.COLLAPSE_PROJECT_ALWAYS_INLINE)
+ val (fields, filters, child, _) = collectProjectsAndFilters(plan, alwaysInline)
Some((fields.getOrElse(child.output), filters, child))
}
/**
- * Collects all deterministic projects and filters, in-lining/substituting aliases if necessary.
+ * This legacy mode is for PhysicalOperation which has been there for years and we want to be
+ * extremely safe to not change its behavior. There are two differences when legacy mode is off:
+ * 1. We postpone the deterministic check to the very end (calling `canCollapseExpressions`),
+ * so that it's more likely to collect more projects and filters.
+ * 2. We follow CollapseProject and only collect adjacent projects if they don't produce
+ * repeated expensive expressions.
+ */
+ protected def legacyMode: Boolean
+
+ /**
+ * Collects all adjacent projects and filters, in-lining/substituting aliases if necessary.
* Here are two examples for alias in-lining/substitution.
* Before:
* {{{
@@ -78,25 +64,60 @@ object PhysicalOperation extends OperationHelper with PredicateHelper {
* SELECT key AS c2 FROM t1 WHERE key > 10
* }}}
*/
- private def collectProjectsAndFilters(plan: LogicalPlan):
- (Option[Seq[NamedExpression]], Seq[Expression], LogicalPlan, AttributeMap[Expression]) =
+ private def collectProjectsAndFilters(
+ plan: LogicalPlan,
+ alwaysInline: Boolean): IntermediateType = {
+ def empty: IntermediateType = (None, Nil, plan, AttributeMap.empty)
+
plan match {
- case Project(fields, child) if fields.forall(_.deterministic) =>
- val (_, filters, other, aliases) = collectProjectsAndFilters(child)
- val substitutedFields = fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
- (Some(substitutedFields), filters, other, collectAliases(substitutedFields))
+ case Project(fields, child) if !legacyMode || fields.forall(_.deterministic) =>
+ val (_, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline)
+ if (legacyMode || canCollapseExpressions(fields, aliases, alwaysInline)) {
+ val replaced = fields.map(replaceAliasButKeepName(_, aliases))
+ (Some(replaced), filters, other, getAliasMap(replaced))
+ } else {
+ empty
+ }
- case Filter(condition, child) if condition.deterministic =>
- val (fields, filters, other, aliases) = collectProjectsAndFilters(child)
- val substitutedCondition = substitute(aliases)(condition)
- (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases)
+ case Filter(condition, child) if !legacyMode || condition.deterministic =>
+ val (fields, filters, other, aliases) = collectProjectsAndFilters(child, alwaysInline)
+ val canIncludeThisFilter = if (legacyMode) {
+ true
+ } else {
+ // When collecting projects and filters, we effectively push down filters through
+ // projects. We need to meet the following conditions to do so:
+ // 1) no Project collected so far or the collected Projects are all deterministic
+ // 2) the collected filters and this filter are all deterministic, or this is the
+ // first collected filter.
+ // 3) this filter does not repeat any expensive expressions from the collected
+ // projects.
+ fields.forall(_.forall(_.deterministic)) && {
+ filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic)
+ } && canCollapseExpressions(Seq(condition), aliases, alwaysInline)
+ }
+ if (canIncludeThisFilter) {
+ val replaced = replaceAlias(condition, aliases)
+ (fields, filters ++ splitConjunctivePredicates(replaced), other, aliases)
+ } else {
+ empty
+ }
- case h: ResolvedHint =>
- collectProjectsAndFilters(h.child)
+ case h: ResolvedHint => collectProjectsAndFilters(h.child, alwaysInline)
- case other =>
- (None, Nil, other, AttributeMap(Seq()))
+ case _ => empty
}
+ }
+}
+
+/**
+ * A pattern that matches any number of project or filter operations on top of another relational
+ * operator. All filter operators are collected and their conditions are broken up and returned
+ * together with the top project operator.
+ * [[org.apache.spark.sql.catalyst.expressions.Alias Aliases]] are in-lined/substituted if
+ * necessary.
+ */
+object PhysicalOperation extends OperationHelper with PredicateHelper {
+ override protected def legacyMode: Boolean = true
}
/**
@@ -105,70 +126,7 @@ object PhysicalOperation extends OperationHelper with PredicateHelper {
* requirement of CollapseProject and CombineFilters.
*/
object ScanOperation extends OperationHelper with PredicateHelper {
- type ScanReturnType = Option[(Option[Seq[NamedExpression]],
- Seq[Expression], LogicalPlan, AttributeMap[Expression])]
-
- def unapply(plan: LogicalPlan): Option[ReturnType] = {
- collectProjectsAndFilters(plan) match {
- case Some((fields, filters, child, _)) =>
- Some((fields.getOrElse(child.output), filters, child))
- case None => None
- }
- }
-
- private def hasCommonNonDeterministic(
- expr: Seq[Expression],
- aliases: AttributeMap[Expression]): Boolean = {
- expr.exists(_.collect {
- case a: AttributeReference if aliases.contains(a) => aliases(a)
- }.exists(!_.deterministic))
- }
-
- private def collectProjectsAndFilters(plan: LogicalPlan): ScanReturnType = {
- plan match {
- case Project(fields, child) =>
- collectProjectsAndFilters(child) match {
- case Some((_, filters, other, aliases)) =>
- // Follow CollapseProject and only keep going if the collected Projects
- // do not have common non-deterministic expressions.
- if (!hasCommonNonDeterministic(fields, aliases)) {
- val substitutedFields =
- fields.map(substitute(aliases)).asInstanceOf[Seq[NamedExpression]]
- Some((Some(substitutedFields), filters, other, collectAliases(substitutedFields)))
- } else {
- None
- }
- case None => None
- }
-
- case Filter(condition, child) =>
- collectProjectsAndFilters(child) match {
- case Some((fields, filters, other, aliases)) =>
- // When collecting projects and filters, we effectively push down filters through
- // projects. We need to meet the following conditions to do so:
- // 1) no Project collected so far or the collected Projects are all deterministic
- // 2) the collected filters and this filter are all deterministic, or this is the
- // first collected filter.
- val canCombineFilters = fields.forall(_.forall(_.deterministic)) && {
- filters.isEmpty || (filters.forall(_.deterministic) && condition.deterministic)
- }
- val substitutedCondition = substitute(aliases)(condition)
- if (canCombineFilters && !hasCommonNonDeterministic(Seq(condition), aliases)) {
- Some((fields, filters ++ splitConjunctivePredicates(substitutedCondition),
- other, aliases))
- } else {
- None
- }
- case None => None
- }
-
- case h: ResolvedHint =>
- collectProjectsAndFilters(h.child)
-
- case other =>
- Some((None, Nil, other, AttributeMap(Seq())))
- }
- }
+ override protected def legacyMode: Boolean = false
}
/**
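
A minimal sketch of how these extractors are typically used in planning code (internal Catalyst API, shown only as an illustration); the helper name is hypothetical:

    import org.apache.spark.sql.catalyst.planning.ScanOperation
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

    object ScanShapeSketch {
      // Peels adjacent (collapsible) Projects/Filters off the top of a plan.
      def describe(plan: LogicalPlan): String = plan match {
        case ScanOperation(projects, filters, child) =>
          s"projects=${projects.size}, filters=${filters.size}, leaf=${child.nodeName}"
        case _ => "not a simple scan shape"
      }
    }
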
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
index 39642fd541706..185a1a2644e2f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
@@ -38,12 +38,13 @@ private[sql] object CatalogV2Implicits {
implicit class BucketSpecHelper(spec: BucketSpec) {
def asTransform: BucketTransform = {
+ val references = spec.bucketColumnNames.map(col => reference(Seq(col)))
if (spec.sortColumnNames.nonEmpty) {
- throw QueryCompilationErrors.cannotConvertBucketWithSortColumnsToTransformError(spec)
+ val sortedCol = spec.sortColumnNames.map(col => reference(Seq(col)))
+ bucket(spec.numBuckets, references.toArray, sortedCol.toArray)
+ } else {
+ bucket(spec.numBuckets, references.toArray)
}
-
- val references = spec.bucketColumnNames.map(col => reference(Seq(col)))
- bucket(spec.numBuckets, references.toArray)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
index 2863d94d198b2..e3eab6f6730f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
@@ -45,6 +45,12 @@ private[sql] object LogicalExpressions {
def bucket(numBuckets: Int, references: Array[NamedReference]): BucketTransform =
BucketTransform(literal(numBuckets, IntegerType), references)
+ def bucket(
+ numBuckets: Int,
+ references: Array[NamedReference],
+ sortedCols: Array[NamedReference]): BucketTransform =
+ BucketTransform(literal(numBuckets, IntegerType), references, sortedCols)
+
def identity(reference: NamedReference): IdentityTransform = IdentityTransform(reference)
def years(reference: NamedReference): YearsTransform = YearsTransform(reference)
@@ -82,9 +88,7 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R
override def arguments: Array[Expression] = Array(ref)
- override def describe: String = name + "(" + reference.describe + ")"
-
- override def toString: String = describe
+ override def toString: String = name + "(" + reference.describe + ")"
protected def withNewRef(ref: NamedReference): Transform
@@ -97,7 +101,8 @@ private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends R
private[sql] final case class BucketTransform(
numBuckets: Literal[Int],
- columns: Seq[NamedReference]) extends RewritableTransform {
+ columns: Seq[NamedReference],
+ sortedColumns: Seq[NamedReference] = Seq.empty[NamedReference]) extends RewritableTransform {
override val name: String = "bucket"
@@ -107,9 +112,13 @@ private[sql] final case class BucketTransform(
override def arguments: Array[Expression] = numBuckets +: columns.toArray
- override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})"
-
- override def toString: String = describe
+ override def toString: String =
+ if (sortedColumns.nonEmpty) {
+ s"bucket(${arguments.map(_.describe).mkString(", ")}," +
+ s" ${sortedColumns.map(_.describe).mkString(", ")})"
+ } else {
+ s"bucket(${arguments.map(_.describe).mkString(", ")})"
+ }
override def withReferences(newReferences: Seq[NamedReference]): Transform = {
this.copy(columns = newReferences)
@@ -117,11 +126,12 @@ private[sql] final case class BucketTransform(
}
private[sql] object BucketTransform {
- def unapply(expr: Expression): Option[(Int, FieldReference)] = expr match {
+ def unapply(expr: Expression): Option[(Int, FieldReference, FieldReference)] =
+ expr match {
case transform: Transform =>
transform match {
- case BucketTransform(n, FieldReference(parts)) =>
- Some((n, FieldReference(parts)))
+ case BucketTransform(n, FieldReference(parts), FieldReference(sortCols)) =>
+ Some((n, FieldReference(parts), FieldReference(sortCols)))
case _ =>
None
}
@@ -129,11 +139,17 @@ private[sql] object BucketTransform {
None
}
- def unapply(transform: Transform): Option[(Int, NamedReference)] = transform match {
+ def unapply(transform: Transform): Option[(Int, NamedReference, NamedReference)] =
+ transform match {
+ case NamedTransform("bucket", Seq(
+ Lit(value: Int, IntegerType),
+ Ref(partCols: Seq[String]),
+ Ref(sortCols: Seq[String]))) =>
+ Some((value, FieldReference(partCols), FieldReference(sortCols)))
case NamedTransform("bucket", Seq(
Lit(value: Int, IntegerType),
- Ref(seq: Seq[String]))) =>
- Some((value, FieldReference(seq)))
+ Ref(partCols: Seq[String]))) =>
+ Some((value, FieldReference(partCols), FieldReference(Seq.empty[String])))
case _ =>
None
}
@@ -149,9 +165,7 @@ private[sql] final case class ApplyTransform(
arguments.collect { case named: NamedReference => named }
}
- override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
-
- override def toString: String = describe
+ override def toString: String = s"$name(${arguments.map(_.describe).mkString(", ")})"
}
/**
@@ -318,21 +332,19 @@ private[sql] object HoursTransform {
}
private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] {
- override def describe: String = {
+ override def toString: String = {
if (dataType.isInstanceOf[StringType]) {
s"'$value'"
} else {
s"$value"
}
}
- override def toString: String = describe
}
private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper
override def fieldNames: Array[String] = parts.toArray
- override def describe: String = parts.quoted
- override def toString: String = describe
+ override def toString: String = parts.quoted
}
private[sql] object FieldReference {
@@ -346,7 +358,7 @@ private[sql] final case class SortValue(
direction: SortDirection,
nullOrdering: NullOrdering) extends SortOrder {
- override def describe(): String = s"$expression $direction $nullOrdering"
+ override def toString(): String = s"$expression $direction $nullOrdering"
}
private[sql] object SortValue {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index e7af006ad7023..0c7a1030fd434 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -24,7 +24,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, QualifiedTableName, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.{CannotReplaceMissingTableException, NamespaceAlreadyExistsException, NoSuchFunctionException, NoSuchNamespaceException, NoSuchPartitionException, NoSuchTableException, ResolvedNamespace, ResolvedTable, ResolvedView, TableAlreadyExistsException}
-import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, InvalidUDFClassException}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, InvalidUDFClassException}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CreateMap, Expression, GroupingID, NamedExpression, SpecifiedWindowFrame, WindowFrame, WindowFunction, WindowSpecDefinition}
import org.apache.spark.sql.catalyst.plans.JoinType
@@ -555,6 +555,11 @@ object QueryCompilationErrors {
new AnalysisException(s"Database $db is not empty. One or more $details exist.")
}
+ def cannotDropNonemptyNamespaceError(namespace: Seq[String]): Throwable = {
+ new AnalysisException(s"Cannot drop a non-empty namespace: ${namespace.quoted}. " +
+ "Use CASCADE option to drop a non-empty namespace.")
+ }
+
def invalidNameForTableOrDatabaseError(name: String): Throwable = {
new AnalysisException(s"`$name` is not a valid name for tables/databases. " +
"Valid names only contain alphabet characters, numbers and _.")
@@ -1371,11 +1376,6 @@ object QueryCompilationErrors {
new AnalysisException("Cannot use interval type in the table schema.")
}
- def cannotConvertBucketWithSortColumnsToTransformError(spec: BucketSpec): Throwable = {
- new AnalysisException(
- s"Cannot convert bucketing with sort columns to a transform: $spec")
- }
-
def cannotConvertTransformsToPartitionColumnsError(nonIdTransforms: Seq[Transform]): Throwable = {
new AnalysisException("Transforms cannot be converted to partition columns: " +
nonIdTransforms.map(_.describe).mkString(", "))
@@ -2371,4 +2371,8 @@ object QueryCompilationErrors {
messageParameters = Array(fieldName.quoted, path.quoted),
origin = context)
}
+
+ def noSuchFunctionError(database: String, funcInfo: String): Throwable = {
+ new AnalysisException(s"$database does not support function: $funcInfo")
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
index 7f77243af8a88..88ab9e530a1a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala
@@ -1804,4 +1804,16 @@ object QueryExecutionErrors {
def pivotNotAfterGroupByUnsupportedError(): Throwable = {
new UnsupportedOperationException("pivot is only supported after a groupBy")
}
+
+ def unsupportedCreateNamespaceCommentError(): Throwable = {
+ new SQLFeatureNotSupportedException("Create namespace comment is not supported")
+ }
+
+ def unsupportedRemoveNamespaceCommentError(): Throwable = {
+ new SQLFeatureNotSupportedException("Remove namespace comment is not supported")
+ }
+
+ def unsupportedDropNamespaceRestrictError(): Throwable = {
+ new SQLFeatureNotSupportedException("Drop namespace restrict is not supported")
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 15927a9ffdfbf..96ca754cad220 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -851,6 +851,14 @@ object SQLConf {
.checkValue(threshold => threshold >= 0, "The threshold must not be negative.")
.createWithDefault(10)
+ val PARQUET_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.aggregatePushdown")
+ .doc("If true, MAX/MIN/COUNT without filter and group by will be pushed" +
+ " down to Parquet for optimization. MAX/MIN/COUNT for complex types and timestamp" +
+ " can't be pushed down")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(false)
+
val PARQUET_WRITE_LEGACY_FORMAT = buildConf("spark.sql.parquet.writeLegacyFormat")
.doc("If true, data will be written in a way of Spark 1.4 and earlier. For example, decimal " +
"values will be written in Apache Parquet's fixed-length byte array format, which other " +
@@ -942,6 +950,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val ORC_AGGREGATE_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.aggregatePushdown")
+ .doc("If true, aggregates will be pushed down to ORC for optimization. Support MIN, MAX and " +
+ "COUNT as aggregate expression. For MIN/MAX, support boolean, integer, float and date " +
+ "type. For COUNT, support all data types.")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(false)
+
val ORC_SCHEMA_MERGING_ENABLED = buildConf("spark.sql.orc.mergeSchema")
.doc("When true, the Orc data source merges schemas collected from all data files, " +
"otherwise the schema is picked from a random data file.")
@@ -1852,6 +1868,13 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val COLLAPSE_PROJECT_ALWAYS_INLINE = buildConf("spark.sql.optimizer.collapseProjectAlwaysInline")
+ .doc("Whether to always collapse two adjacent projections and inline expressions even if " +
+ "it causes extra duplication.")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(false)
+
val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion")
.internal()
.doc("Whether to delete the expired log files in file stream sink.")
@@ -3679,8 +3702,12 @@ class SQLConf extends Serializable with Logging {
def parquetFilterPushDownInFilterThreshold: Int =
getConf(PARQUET_FILTER_PUSHDOWN_INFILTERTHRESHOLD)
+ def parquetAggregatePushDown: Boolean = getConf(PARQUET_AGGREGATE_PUSHDOWN_ENABLED)
+
def orcFilterPushDown: Boolean = getConf(ORC_FILTER_PUSHDOWN_ENABLED)
+ def orcAggregatePushDown: Boolean = getConf(ORC_AGGREGATE_PUSHDOWN_ENABLED)
+
def isOrcSchemaMergingEnabled: Boolean = getConf(ORC_SCHEMA_MERGING_ENABLED)
def verifyPartitionPath: Boolean = getConf(HIVE_VERIFY_PARTITION_PATH)
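
All three flags added above default to false; in a spark-shell session they could be enabled like this (a sketch, assuming a local session):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("pushdown-flags").master("local[*]").getOrCreate()
    spark.conf.set("spark.sql.parquet.aggregatePushdown", "true")
    spark.conf.set("spark.sql.orc.aggregatePushdown", "true")
    spark.conf.set("spark.sql.optimizer.collapseProjectAlwaysInline", "true")
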
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala
new file mode 100644
index 0000000000000..9c2a4ac78a24a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/SupportsPushDownCatalystFilters.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.internal.connector
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.sources.Filter
+
+/**
+ * A mix-in interface for {@link FileScanBuilder}. File sources can implement this interface to
+ * push down filters to the file source. The pushed down filters will be separated into partition
+ * filters and data filters. Partition filters are used for partition pruning and data filters are
+ * used to reduce the size of the data to be read.
+ */
+trait SupportsPushDownCatalystFilters {
+
+ /**
+ * Pushes down catalyst Expression filters (which will be separated into partition filters and
+ * data filters), and returns data filters that need to be evaluated after scanning.
+ */
+ def pushFilters(filters: Seq[Expression]): Seq[Expression]
+
+ /**
+ * Returns the data filters that are pushed to the data source via
+ * {@link #pushFilters(Expression[])}.
+ */
+ def pushedFilters: Array[Filter]
+}
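Editor's note: to make the contract concrete, here is a minimal, hypothetical sketch of a scan builder mixing in this trait. The class name, the partitionSchema field, and the use of DataSourceUtils.getPartitionFiltersAndDataFilters (added later in this patch) are assumptions; translation of the retained data filters to public source Filters is omitted.

// Hypothetical sketch only.
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

class ExampleScanBuilder(partitionSchema: StructType) extends SupportsPushDownCatalystFilters {
  private var partitionFilters: Seq[Expression] = Seq.empty
  private var dataFilters: Seq[Expression] = Seq.empty

  override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
    // Split the pushed catalyst filters into partition filters (for pruning) and data filters.
    val (partFilters, remaining) =
      DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filters)
    partitionFilters = partFilters
    dataFilters = remaining
    dataFilters // data filters still have to be re-evaluated after scanning
  }

  // Translation of dataFilters to source Filters is omitted in this sketch.
  override def pushedFilters: Array[Filter] = Array.empty
}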
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 80658f7cec2e3..e358ff0cb6677 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -18,7 +18,12 @@
package org.apache.spark.sql.sources
import org.apache.spark.annotation.{Evolving, Stable}
+import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
+import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse => V2AlwaysFalse, AlwaysTrue => V2AlwaysTrue, Predicate}
+import org.apache.spark.sql.types.StringType
+import org.apache.spark.unsafe.types.UTF8String
////////////////////////////////////////////////////////////////////////////////////////////////////
// This file defines all the filters that we can push down to the data sources.
@@ -64,6 +69,11 @@ sealed abstract class Filter {
private[sql] def containsNestedColumn: Boolean = {
this.v2references.exists(_.length > 1)
}
+
+ /**
+ * Converts this V1 filter to a V2 filter.
+ */
+ private[sql] def toV2: Predicate
}
/**
@@ -78,6 +88,11 @@ sealed abstract class Filter {
@Stable
case class EqualTo(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
+ override def toV2: Predicate = {
+ val literal = Literal(value)
+ new Predicate("=",
+ Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType)))
+ }
}
/**
@@ -93,6 +108,11 @@ case class EqualTo(attribute: String, value: Any) extends Filter {
@Stable
case class EqualNullSafe(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
+ override def toV2: Predicate = {
+ val literal = Literal(value)
+ new Predicate("<=>",
+ Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType)))
+ }
}
/**
@@ -107,6 +127,11 @@ case class EqualNullSafe(attribute: String, value: Any) extends Filter {
@Stable
case class GreaterThan(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
+ override def toV2: Predicate = {
+ val literal = Literal(value)
+ new Predicate(">",
+ Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType)))
+ }
}
/**
@@ -121,6 +146,11 @@ case class GreaterThan(attribute: String, value: Any) extends Filter {
@Stable
case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
+ override def toV2: Predicate = {
+ val literal = Literal(value)
+ new Predicate(">=",
+ Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType)))
+ }
}
/**
@@ -135,6 +165,11 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter {
@Stable
case class LessThan(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
+ override def toV2: Predicate = {
+ val literal = Literal(value)
+ new Predicate("<",
+ Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType)))
+ }
}
/**
@@ -149,6 +184,11 @@ case class LessThan(attribute: String, value: Any) extends Filter {
@Stable
case class LessThanOrEqual(attribute: String, value: Any) extends Filter {
override def references: Array[String] = Array(attribute) ++ findReferences(value)
+ override def toV2: Predicate = {
+ val literal = Literal(value)
+ new Predicate("<=",
+ Array(FieldReference(attribute), LiteralValue(literal.value, literal.dataType)))
+ }
}
/**
@@ -185,6 +225,13 @@ case class In(attribute: String, values: Array[Any]) extends Filter {
}
override def references: Array[String] = Array(attribute) ++ values.flatMap(findReferences)
+ override def toV2: Predicate = {
+ val literals = values.map { value =>
+ val literal = Literal(value)
+ LiteralValue(literal.value, literal.dataType)
+ }
+ new Predicate("IN", FieldReference(attribute) +: literals)
+ }
}
/**
@@ -198,6 +245,7 @@ case class In(attribute: String, values: Array[Any]) extends Filter {
@Stable
case class IsNull(attribute: String) extends Filter {
override def references: Array[String] = Array(attribute)
+ override def toV2: Predicate = new Predicate("IS_NULL", Array(FieldReference(attribute)))
}
/**
@@ -211,6 +259,7 @@ case class IsNull(attribute: String) extends Filter {
@Stable
case class IsNotNull(attribute: String) extends Filter {
override def references: Array[String] = Array(attribute)
+ override def toV2: Predicate = new Predicate("IS_NOT_NULL", Array(FieldReference(attribute)))
}
/**
@@ -221,6 +270,7 @@ case class IsNotNull(attribute: String) extends Filter {
@Stable
case class And(left: Filter, right: Filter) extends Filter {
override def references: Array[String] = left.references ++ right.references
+ override def toV2: Predicate = new Predicate("AND", Seq(left, right).map(_.toV2).toArray)
}
/**
@@ -231,6 +281,7 @@ case class And(left: Filter, right: Filter) extends Filter {
@Stable
case class Or(left: Filter, right: Filter) extends Filter {
override def references: Array[String] = left.references ++ right.references
+ override def toV2: Predicate = new Predicate("OR", Seq(left, right).map(_.toV2).toArray)
}
/**
@@ -241,6 +292,7 @@ case class Or(left: Filter, right: Filter) extends Filter {
@Stable
case class Not(child: Filter) extends Filter {
override def references: Array[String] = child.references
+ override def toV2: Predicate = new Predicate("NOT", Array(child.toV2))
}
/**
@@ -255,6 +307,8 @@ case class Not(child: Filter) extends Filter {
@Stable
case class StringStartsWith(attribute: String, value: String) extends Filter {
override def references: Array[String] = Array(attribute)
+ override def toV2: Predicate = new Predicate("STARTS_WITH",
+ Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType)))
}
/**
@@ -269,6 +323,8 @@ case class StringStartsWith(attribute: String, value: String) extends Filter {
@Stable
case class StringEndsWith(attribute: String, value: String) extends Filter {
override def references: Array[String] = Array(attribute)
+ override def toV2: Predicate = new Predicate("ENDS_WITH",
+ Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType)))
}
/**
@@ -283,6 +339,8 @@ case class StringEndsWith(attribute: String, value: String) extends Filter {
@Stable
case class StringContains(attribute: String, value: String) extends Filter {
override def references: Array[String] = Array(attribute)
+ override def toV2: Predicate = new Predicate("CONTAINS",
+ Array(FieldReference(attribute), LiteralValue(UTF8String.fromString(value), StringType)))
}
/**
@@ -293,6 +351,7 @@ case class StringContains(attribute: String, value: String) extends Filter {
@Evolving
case class AlwaysTrue() extends Filter {
override def references: Array[String] = Array.empty
+ override def toV2: Predicate = new V2AlwaysTrue()
}
@Evolving
@@ -307,6 +366,7 @@ object AlwaysTrue extends AlwaysTrue {
@Evolving
case class AlwaysFalse() extends Filter {
override def references: Array[String] = Array.empty
+ override def toV2: Predicate = new V2AlwaysFalse()
}
@Evolving
@@ -316,4 +376,9 @@ object AlwaysFalse extends AlwaysFalse {
@Evolving
case class Trivial(value: Boolean) extends Filter {
override def references: Array[String] = findReferences(value)
+ override def toV2: Predicate = {
+ val literal = Literal(value)
+ new Predicate("TRIVIAL",
+ Array(LiteralValue(literal.value, literal.dataType)))
+ }
}
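Editor's note: a small illustrative example of the new conversion. Since toV2 is private[sql], it is only callable from code inside the org.apache.spark.sql package; the shape of the result described in the comments follows from the implementations above.

// Illustration only.
import org.apache.spark.sql.sources.{And, EqualTo, GreaterThan}

val v1 = And(EqualTo("a", 1), GreaterThan("b", 2.0d))
val v2 = v1.toV2
// v2 is Predicate("AND", ...) whose children are
// Predicate("=", [FieldReference("a"), LiteralValue(1, IntegerType)]) and
// Predicate(">", [FieldReference("b"), LiteralValue(2.0, DoubleType)]).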
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
index 1e7f9b0edd91c..c1d13d14b05f7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseProjectSuite.scala
@@ -121,6 +121,16 @@ class CollapseProjectSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("SPARK-36718: do not collapse project if non-cheap expressions will be repeated") {
+ val query = testRelation
+ .select(('a + 1).as('a_plus_1))
+ .select(('a_plus_1 + 'a_plus_1).as('a_2_plus_2))
+ .analyze
+
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, query)
+ }
+
test("preserve top-level alias metadata while collapsing projects") {
def hasMetadata(logicalPlan: LogicalPlan): Boolean = {
logicalPlan.asInstanceOf[Project].projectList.exists(_.metadata.contains("key"))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala
index b1baeccbe94b9..eb3899c9187db 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/planning/ScanOperationSuite.scala
@@ -57,7 +57,14 @@ class ScanOperationSuite extends SparkFunSuite {
test("Project which has the same non-deterministic expression with its child Project") {
val project3 = Project(Seq(colA, colR), Project(Seq(colA, aliasR), relation))
- assert(ScanOperation.unapply(project3).isEmpty)
+ project3 match {
+ case ScanOperation(projects, filters, _: Project) =>
+ assert(projects.size === 2)
+ assert(projects(0) === colA)
+ assert(projects(1) === colR)
+ assert(filters.isEmpty)
+ case _ => assert(false)
+ }
}
test("Project which has different non-deterministic expressions with its child Project") {
@@ -73,13 +80,18 @@ class ScanOperationSuite extends SparkFunSuite {
test("Filter with non-deterministic Project") {
val filter1 = Filter(EqualTo(colA, Literal(1)), Project(Seq(colA, aliasR), relation))
- assert(ScanOperation.unapply(filter1).isEmpty)
+ filter1 match {
+ case ScanOperation(projects, filters, _: Filter) =>
+ assert(projects.size === 2)
+ assert(filters.isEmpty)
+ case _ => assert(false)
+ }
}
test("Non-deterministic Filter with deterministic Project") {
- val filter3 = Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)),
+ val filter2 = Filter(EqualTo(MonotonicallyIncreasingID(), Literal(1)),
Project(Seq(colA, colB), relation))
- filter3 match {
+ filter2 match {
case ScanOperation(projects, filters, _: LocalRelation) =>
assert(projects.size === 2)
assert(projects(0) === colA)
@@ -91,7 +103,11 @@ class ScanOperationSuite extends SparkFunSuite {
test("Deterministic filter which has a non-deterministic child Filter") {
- val filter4 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation))
- assert(ScanOperation.unapply(filter4).isEmpty)
+ val filter3 = Filter(EqualTo(colA, Literal(1)), Filter(EqualTo(aliasR, Literal(1)), relation))
+ filter3 match {
+ case ScanOperation(projects, filters, _: Filter) =>
+ assert(filters.isEmpty)
+ case _ => assert(false)
+ }
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
index 0cca1cc9bebf2..d00bc31e07f19 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala
@@ -820,7 +820,7 @@ class CatalogSuite extends SparkFunSuite {
assert(catalog.namespaceExists(testNs) === false)
- val ret = catalog.dropNamespace(testNs)
+ val ret = catalog.dropNamespace(testNs, cascade = false)
assert(ret === false)
}
@@ -833,7 +833,7 @@ class CatalogSuite extends SparkFunSuite {
assert(catalog.namespaceExists(testNs) === true)
assert(catalog.loadNamespaceMetadata(testNs).asScala === Map("property" -> "value"))
- val ret = catalog.dropNamespace(testNs)
+ val ret = catalog.dropNamespace(testNs, cascade = false)
assert(ret === true)
assert(catalog.namespaceExists(testNs) === false)
@@ -845,7 +845,7 @@ class CatalogSuite extends SparkFunSuite {
catalog.createNamespace(testNs, Map("property" -> "value").asJava)
catalog.createTable(testIdent, schema, Array.empty, emptyProps)
- assert(catalog.dropNamespace(testNs))
+ assert(catalog.dropNamespace(testNs, cascade = true))
assert(!catalog.namespaceExists(testNs))
intercept[NoSuchNamespaceException](catalog.listTables(testNs))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
index 2f3c5a38538c8..e0604576a94bc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTable.scala
@@ -161,7 +161,7 @@ class InMemoryTable(
case (v, t) =>
throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)")
}
- case BucketTransform(numBuckets, ref) =>
+ case BucketTransform(numBuckets, ref, _) =>
val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row)
val valueHashCode = if (value == null) 0 else value.hashCode
((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
index 0c403baca2113..41063a41b9719 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
-import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NonEmptyNamespaceException, NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
import org.apache.spark.sql.connector.distributions.{Distribution, Distributions}
import org.apache.spark.sql.connector.expressions.{SortOrder, Transform}
import org.apache.spark.sql.types.StructType
@@ -193,10 +193,16 @@ class InMemoryTableCatalog extends BasicInMemoryTableCatalog with SupportsNamesp
namespaces.put(namespace.toList, CatalogV2Util.applyNamespaceChanges(metadata, changes))
}
- override def dropNamespace(namespace: Array[String]): Boolean = {
- listNamespaces(namespace).foreach(dropNamespace)
+ override def dropNamespace(namespace: Array[String], cascade: Boolean): Boolean = {
try {
- listTables(namespace).foreach(dropTable)
+ if (!cascade) {
+ if (listTables(namespace).nonEmpty || listNamespaces(namespace).nonEmpty) {
+ throw new NonEmptyNamespaceException(namespace)
+ }
+ } else {
+ listNamespaces(namespace).foreach(namespace => dropNamespace(namespace, cascade))
+ listTables(namespace).foreach(dropTable)
+ }
} catch {
case _: NoSuchNamespaceException =>
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
index fbd6a886d011b..4a50e063bee68 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/expressions/TransformExtractorSuite.scala
@@ -28,7 +28,7 @@ class TransformExtractorSuite extends SparkFunSuite {
private def lit[T](literal: T): Literal[T] = new Literal[T] {
override def value: T = literal
override def dataType: DataType = catalyst.expressions.Literal(literal).dataType
- override def describe: String = literal.toString
+ override def toString: String = literal.toString
}
/**
@@ -36,7 +36,7 @@ class TransformExtractorSuite extends SparkFunSuite {
*/
private def ref(names: String*): NamedReference = new NamedReference {
override def fieldNames: Array[String] = names.toArray
- override def describe: String = names.mkString(".")
+ override def toString: String = names.mkString(".")
}
/**
@@ -44,9 +44,8 @@ class TransformExtractorSuite extends SparkFunSuite {
*/
private def transform(func: String, ref: NamedReference): Transform = new Transform {
override def name: String = func
- override def references: Array[NamedReference] = Array(ref)
override def arguments: Array[Expression] = Array(ref)
- override def describe: String = ref.describe
+ override def toString: String = ref.describe
}
test("Identity extractor") {
@@ -135,11 +134,11 @@ class TransformExtractorSuite extends SparkFunSuite {
override def name: String = "bucket"
override def references: Array[NamedReference] = Array(col)
override def arguments: Array[Expression] = Array(lit(16), col)
- override def describe: String = s"bucket(16, ${col.describe})"
+ override def toString: String = s"bucket(16, ${col.describe})"
}
bucketTransform match {
- case BucketTransform(numBuckets, FieldReference(seq)) =>
+ case BucketTransform(numBuckets, FieldReference(seq), _) =>
assert(numBuckets === 16)
assert(seq === Seq("a", "b"))
case _ =>
@@ -147,7 +146,7 @@ class TransformExtractorSuite extends SparkFunSuite {
}
transform("unknown", ref("a", "b")) match {
- case BucketTransform(_, _) =>
+ case BucketTransform(_, _, _) =>
fail("Matched unknown transform")
case _ =>
// expected
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 85bb234cf9a97..998de75018d4e 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.12
- 3.2.0-kylin-4.x-r60
+ 3.2.0-kylin-4.x-r61
../../pom.xml
@@ -136,7 +136,7 @@
com.h2database
h2
- 1.4.195
+ 2.0.204
test
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java
new file mode 100644
index 0000000000000..8adb9e8ca20be
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnStatistics.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.orc.ColumnStatistics;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Column statistics interface wrapping ORC {@link ColumnStatistics}s.
+ *
+ * Because ORC {@link ColumnStatistics}s are stored as a flattened array in the ORC file footer,
+ * this class is used to convert ORC {@link ColumnStatistics}s from that array into a nested tree
+ * structure, according to the data types. The flattened array stores all data types (including
+ * nested types) in tree pre-order. This is used for aggregate push down in ORC.
+ *
+ * For nested data types (array, map and struct), the sub-field statistics are stored recursively
+ * inside parent column's children field. Here is an example of {@link OrcColumnStatistics}:
+ *
+ * Data schema:
+ * c1: int
+ * c2: struct<f1: int, f2: float>
+ * c3: map<int, string>
+ * c4: array<int>
+ *
+ * OrcColumnStatistics
+ * | (children)
+ * ---------------------------------------------
+ * / | \ \
+ * c1 c2 c3 c4
+ * (integer) (struct) (map) (array)
+ * (min:1, | (children) | (children) | (children)
+ * max:10) ----- ----- element
+ * / \ / \ (integer)
+ * c2.f1 c2.f2 key value
+ * (integer) (float) (integer) (string)
+ * (min:0.1, (min:"a",
+ * max:100.5) max:"zzz")
+ */
+public class OrcColumnStatistics {
+ private final ColumnStatistics statistics;
+ private final List<OrcColumnStatistics> children;
+
+ public OrcColumnStatistics(ColumnStatistics statistics) {
+ this.statistics = statistics;
+ this.children = new ArrayList<>();
+ }
+
+ public ColumnStatistics getStatistics() {
+ return statistics;
+ }
+
+ public OrcColumnStatistics get(int ordinal) {
+ if (ordinal < 0 || ordinal >= children.size()) {
+ throw new IndexOutOfBoundsException(
+ String.format("Ordinal %d out of bounds of statistics size %d", ordinal, children.size()));
+ }
+ return children.get(ordinal);
+ }
+
+ public void add(OrcColumnStatistics newChild) {
+ children.add(newChild);
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java
new file mode 100644
index 0000000000000..546b048648844
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcFooterReader.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.orc.ColumnStatistics;
+import org.apache.orc.Reader;
+import org.apache.orc.TypeDescription;
+import org.apache.spark.sql.types.*;
+
+import java.util.Arrays;
+import java.util.LinkedList;
+import java.util.Queue;
+
+/**
+ * {@link OrcFooterReader} is a utility class that encapsulates the helper
+ * methods for reading the ORC file footer.
+ */
+public class OrcFooterReader {
+
+ /**
+ * Read the columns statistics from ORC file footer.
+ *
+ * @param orcReader the reader to read ORC file footer.
+ * @return Statistics for all columns in the file.
+ */
+ public static OrcColumnStatistics readStatistics(Reader orcReader) {
+ TypeDescription orcSchema = orcReader.getSchema();
+ ColumnStatistics[] orcStatistics = orcReader.getStatistics();
+ StructType sparkSchema = OrcUtils.toCatalystSchema(orcSchema);
+ return convertStatistics(sparkSchema, new LinkedList<>(Arrays.asList(orcStatistics)));
+ }
+
+ /**
+ * Convert a queue of ORC {@link ColumnStatistics}s into Spark {@link OrcColumnStatistics}.
+ * The queue of ORC {@link ColumnStatistics}s is assumed to be ordered in tree pre-order.
+ */
+ private static OrcColumnStatistics convertStatistics(
+ DataType sparkSchema, Queue<ColumnStatistics> orcStatistics) {
+ OrcColumnStatistics statistics = new OrcColumnStatistics(orcStatistics.remove());
+ if (sparkSchema instanceof StructType) {
+ for (StructField field : ((StructType) sparkSchema).fields()) {
+ statistics.add(convertStatistics(field.dataType(), orcStatistics));
+ }
+ } else if (sparkSchema instanceof MapType) {
+ statistics.add(convertStatistics(((MapType) sparkSchema).keyType(), orcStatistics));
+ statistics.add(convertStatistics(((MapType) sparkSchema).valueType(), orcStatistics));
+ } else if (sparkSchema instanceof ArrayType) {
+ statistics.add(convertStatistics(((ArrayType) sparkSchema).elementType(), orcStatistics));
+ }
+ return statistics;
+ }
+}
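Editor's note: a hedged sketch of how the two new ORC helpers are meant to be used together: read the footer statistics once per file, then walk the resulting tree in the same pre-order as the Spark schema. The ORC reader setup and the file path are assumptions; only readStatistics, get and getStatistics come from this patch.

// Sketch only: OrcFile reader setup and the path are assumptions for illustration.
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.orc.OrcFile
import org.apache.spark.sql.execution.datasources.orc.OrcFooterReader

val reader = OrcFile.createReader(new Path("/tmp/data.orc"), OrcFile.readerOptions(new Configuration()))
val stats = OrcFooterReader.readStatistics(reader)
// For the schema in the class comment above (c1: int, c2: struct<f1: int, f2: float>, ...):
val c1Stats = stats.get(0).getStatistics          // statistics of c1
val c2f2Stats = stats.get(1).get(1).getStatistics // statistics of c2.f2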
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
new file mode 100644
index 0000000000000..b9847d48b2e17
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.sql.catalyst.expressions.{Abs, Add, And, BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen, Cast, Ceil, Coalesce, Contains, Divide, EndsWith, EqualTo, Exp, Expression, Floor, In, InSet, IsNotNull, IsNull, Literal, Log, Multiply, Not, Or, Pow, Predicate, Remainder, Sqrt, StartsWith, StringPredicate, Subtract, UnaryMinus, WidthBucket}
+import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}
+import org.apache.spark.sql.execution.datasources.PushableColumn
+import org.apache.spark.sql.types.BooleanType
+
+/**
+ * The builder to generate V2 expressions from catalyst expressions.
+ */
+class V2ExpressionBuilder(
+ e: Expression, nestedPredicatePushdownEnabled: Boolean = false, isPredicate: Boolean = false) {
+
+ val pushableColumn = PushableColumn(nestedPredicatePushdownEnabled)
+
+ def build(): Option[V2Expression] = generateExpression(e, isPredicate)
+
+ private def canTranslate(b: BinaryOperator) = b match {
+ case _: And | _: Or => true
+ case _: BinaryComparison => true
+ case _: BitwiseAnd | _: BitwiseOr | _: BitwiseXor => true
+ case add: Add => add.failOnError
+ case sub: Subtract => sub.failOnError
+ case mul: Multiply => mul.failOnError
+ case div: Divide => div.failOnError
+ case r: Remainder => r.failOnError
+ case _ => false
+ }
+
+ private def generateExpression(
+ expr: Expression, isPredicate: Boolean = false): Option[V2Expression] = expr match {
+ case Literal(true, BooleanType) => Some(new AlwaysTrue())
+ case Literal(false, BooleanType) => Some(new AlwaysFalse())
+ case Literal(value, dataType) => Some(LiteralValue(value, dataType))
+ case col @ pushableColumn(name) if nestedPredicatePushdownEnabled =>
+ if (isPredicate && col.dataType.isInstanceOf[BooleanType]) {
+ Some(new V2Predicate("=", Array(FieldReference(name), LiteralValue(true, BooleanType))))
+ } else {
+ Some(FieldReference(name))
+ }
+ case pushableColumn(name) if !nestedPredicatePushdownEnabled =>
+ Some(FieldReference(name))
+ case in @ InSet(child, hset) =>
+ generateExpression(child).map { v =>
+ val children =
+ (v +: hset.toSeq.map(elem => LiteralValue(elem, in.dataType))).toArray[V2Expression]
+ new V2Predicate("IN", children)
+ }
+ // The Optimizer converts In to InSet only when the value list exceeds a certain size, so it
+ // is still possible to get an In expression here that needs to be pushed down.
+ case In(value, list) =>
+ val v = generateExpression(value)
+ val listExpressions = list.flatMap(generateExpression(_))
+ if (v.isDefined && list.length == listExpressions.length) {
+ val children = (v.get +: listExpressions).toArray[V2Expression]
+ // The children look like [expr, value1, ..., valueN]
+ Some(new V2Predicate("IN", children))
+ } else {
+ None
+ }
+ case IsNull(col) => generateExpression(col)
+ .map(c => new V2Predicate("IS_NULL", Array[V2Expression](c)))
+ case IsNotNull(col) => generateExpression(col)
+ .map(c => new V2Predicate("IS_NOT_NULL", Array[V2Expression](c)))
+ case p: StringPredicate =>
+ val left = generateExpression(p.left)
+ val right = generateExpression(p.right)
+ if (left.isDefined && right.isDefined) {
+ val name = p match {
+ case _: StartsWith => "STARTS_WITH"
+ case _: EndsWith => "ENDS_WITH"
+ case _: Contains => "CONTAINS"
+ }
+ Some(new V2Predicate(name, Array[V2Expression](left.get, right.get)))
+ } else {
+ None
+ }
+ case Cast(child, dataType, _, true) =>
+ generateExpression(child).map(v => new V2Cast(v, dataType))
+ case Abs(child, true) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("ABS", Array[V2Expression](v)))
+ case Coalesce(children) =>
+ val childrenExpressions = children.flatMap(generateExpression(_))
+ if (children.length == childrenExpressions.length) {
+ Some(new GeneralScalarExpression("COALESCE", childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
+ case Log(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("LN", Array[V2Expression](v)))
+ case Exp(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("EXP", Array[V2Expression](v)))
+ case Pow(left, right) =>
+ val l = generateExpression(left)
+ val r = generateExpression(right)
+ if (l.isDefined && r.isDefined) {
+ Some(new GeneralScalarExpression("POWER", Array[V2Expression](l.get, r.get)))
+ } else {
+ None
+ }
+ case Sqrt(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("SQRT", Array[V2Expression](v)))
+ case Floor(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("FLOOR", Array[V2Expression](v)))
+ case Ceil(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("CEIL", Array[V2Expression](v)))
+ case wb: WidthBucket =>
+ val childrenExpressions = wb.children.flatMap(generateExpression(_))
+ if (childrenExpressions.length == wb.children.length) {
+ Some(new GeneralScalarExpression("WIDTH_BUCKET",
+ childrenExpressions.toArray[V2Expression]))
+ } else {
+ None
+ }
+ case and: And =>
+ // AND expects predicate
+ val l = generateExpression(and.left, true)
+ val r = generateExpression(and.right, true)
+ if (l.isDefined && r.isDefined) {
+ assert(l.get.isInstanceOf[V2Predicate] && r.get.isInstanceOf[V2Predicate])
+ Some(new V2And(l.get.asInstanceOf[V2Predicate], r.get.asInstanceOf[V2Predicate]))
+ } else {
+ None
+ }
+ case or: Or =>
+ // OR expects predicate
+ val l = generateExpression(or.left, true)
+ val r = generateExpression(or.right, true)
+ if (l.isDefined && r.isDefined) {
+ assert(l.get.isInstanceOf[V2Predicate] && r.get.isInstanceOf[V2Predicate])
+ Some(new V2Or(l.get.asInstanceOf[V2Predicate], r.get.asInstanceOf[V2Predicate]))
+ } else {
+ None
+ }
+ case b: BinaryOperator if canTranslate(b) =>
+ val l = generateExpression(b.left)
+ val r = generateExpression(b.right)
+ if (l.isDefined && r.isDefined) {
+ b match {
+ case _: Predicate =>
+ Some(new V2Predicate(b.sqlOperator, Array[V2Expression](l.get, r.get)))
+ case _ =>
+ Some(new GeneralScalarExpression(b.sqlOperator, Array[V2Expression](l.get, r.get)))
+ }
+ } else {
+ None
+ }
+ case Not(eq: EqualTo) =>
+ val left = generateExpression(eq.left)
+ val right = generateExpression(eq.right)
+ if (left.isDefined && right.isDefined) {
+ Some(new V2Predicate("<>", Array[V2Expression](left.get, right.get)))
+ } else {
+ None
+ }
+ case Not(child) => generateExpression(child, true) // NOT expects predicate
+ .map { v =>
+ assert(v.isInstanceOf[V2Predicate])
+ new V2Not(v.asInstanceOf[V2Predicate])
+ }
+ case UnaryMinus(child, true) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("-", Array[V2Expression](v)))
+ case BitwiseNot(child) => generateExpression(child)
+ .map(v => new GeneralScalarExpression("~", Array[V2Expression](v)))
+ case CaseWhen(branches, elseValue) =>
+ val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
+ val values = branches.map(_._2).flatMap(generateExpression(_, true))
+ if (conditions.length == branches.length && values.length == branches.length) {
+ val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
+ Seq[V2Expression](c, v)
+ }
+ if (elseValue.isDefined) {
+ elseValue.flatMap(generateExpression(_)).map { v =>
+ val children = (branchExpressions :+ v).toArray[V2Expression]
+ // The children look like [condition1, value1, ..., conditionN, valueN, elseValue]
+ new V2Predicate("CASE_WHEN", children)
+ }
+ } else {
+ // The children look like [condition1, value1, ..., conditionN, valueN]
+ Some(new V2Predicate("CASE_WHEN", branchExpressions.toArray[V2Expression]))
+ }
+ } else {
+ None
+ }
+ // TODO supports other expressions
+ case _ => None
+ }
+}
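Editor's note: a short, hedged usage sketch of the builder, translating a simple catalyst comparison into a V2 predicate. The AttributeReference is an assumption made for the example; build() returns None whenever some sub-expression has no V2 counterpart.

// Illustration only.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GreaterThan, Literal}
import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.types.IntegerType

val a = AttributeReference("a", IntegerType)()
val builder = new V2ExpressionBuilder(GreaterThan(a, Literal(10)),
  nestedPredicatePushdownEnabled = true, isPredicate = true)
builder.build()
// Some(Predicate(">", [FieldReference(a), LiteralValue(10, IntegerType)])); None if untranslatable.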
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index efc459c8241fa..432775c9045ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
+import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.{BaseRelation, Filter}
@@ -103,7 +103,7 @@ case class RowDataSourceScanExec(
requiredSchema: StructType,
filters: Set[Filter],
handledFilters: Set[Filter],
- aggregation: Option[Aggregation],
+ pushedDownOperators: PushedDownOperators,
rdd: RDD[InternalRow],
@transient relation: BaseRelation,
tableIdentifier: Option[TableIdentifier])
@@ -134,13 +134,6 @@ case class RowDataSourceScanExec(
def seqToString(seq: Seq[Any]): String = seq.mkString("[", ", ", "]")
- val (aggString, groupByString) = if (aggregation.nonEmpty) {
- (seqToString(aggregation.get.aggregateExpressions),
- seqToString(aggregation.get.groupByColumns))
- } else {
- ("[]", "[]")
- }
-
val markedFilters = if (filters.nonEmpty) {
for (filter <- filters) yield {
if (handledFilters.contains(filter)) s"*$filter" else s"$filter"
@@ -149,11 +142,31 @@ case class RowDataSourceScanExec(
handledFilters
}
- Map(
- "ReadSchema" -> requiredSchema.catalogString,
- "PushedFilters" -> seqToString(markedFilters.toSeq),
- "PushedAggregates" -> aggString,
- "PushedGroupby" -> groupByString)
+ val topNOrLimitInfo =
+ if (pushedDownOperators.limit.isDefined && pushedDownOperators.sortValues.nonEmpty) {
+ val pushedTopN =
+ s"ORDER BY ${seqToString(pushedDownOperators.sortValues.map(_.describe()))}" +
+ s" LIMIT ${pushedDownOperators.limit.get}"
+ Some("PushedTopN" -> pushedTopN)
+ } else {
+ pushedDownOperators.limit.map(value => "PushedLimit" -> s"LIMIT $value")
+ }
+
+ val pushedFilters = if (pushedDownOperators.pushedPredicates.nonEmpty) {
+ seqToString(pushedDownOperators.pushedPredicates.map(_.describe()))
+ } else {
+ seqToString(markedFilters.toSeq)
+ }
+
+ Map("ReadSchema" -> requiredSchema.catalogString,
+ "PushedFilters" -> pushedFilters) ++
+ pushedDownOperators.aggregation.fold(Map[String, String]()) { v =>
+ Map("PushedAggregates" -> seqToString(v.aggregateExpressions.map(_.describe())),
+ "PushedGroupByColumns" -> seqToString(v.groupByColumns.map(_.describe())))} ++
+ topNOrLimitInfo ++
+ pushedDownOperators.sample.map(v => "PushedSample" ->
+ s"SAMPLE (${(v.upperBound - v.lowerBound) * 100}) ${v.withReplacement} SEED(${v.seed})"
+ )
}
// Don't care about `rdd` and `tableIdentifier` when canonicalizing.
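Editor's note: the metadata built above ends up in the scan node's explain output. A hedged illustration of the kind of entries that can appear follows; the values are made up for the example and are not the output of any specific query.

// Illustration only: example metadata entries with invented values.
val metadata = Map(
  "ReadSchema"    -> "struct<NAME:string,SALARY:decimal(20,2)>",
  "PushedFilters" -> "[SALARY IS NOT NULL, SALARY > 100.00]",
  "PushedTopN"    -> "ORDER BY [SALARY ASC NULLS FIRST] LIMIT 3",
  "PushedSample"  -> "SAMPLE (50.0) false SEED(123)"
)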
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
new file mode 100644
index 0000000000000..6d8cae544f23e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/AggregatePushDownUtils.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.execution.RowToColumnConverter
+import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
+import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
+import org.apache.spark.sql.types.{BooleanType, ByteType, DateType, DoubleType, FloatType, IntegerType, LongType, ShortType, StructField, StructType}
+import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+
+/**
+ * Utility class for aggregate push down to Parquet and ORC.
+ */
+object AggregatePushDownUtils {
+
+ /**
+ * Get the data schema for the aggregation to be pushed down; returns None if it cannot be pushed down.
+ */
+ def getSchemaForPushedAggregation(
+ aggregation: Aggregation,
+ schema: StructType,
+ partitionNames: Set[String],
+ dataFilters: Seq[Expression]): Option[StructType] = {
+
+ var finalSchema = new StructType()
+
+ def getStructFieldForCol(colName: String): StructField = {
+ schema.apply(colName)
+ }
+
+ def isPartitionCol(colName: String) = {
+ partitionNames.contains(colName)
+ }
+
+ def processMinOrMax(agg: AggregateFunc): Boolean = {
+ val (columnName, aggType) = agg match {
+ case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined =>
+ (V2ColumnUtils.extractV2Column(max.column).get, "max")
+ case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined =>
+ (V2ColumnUtils.extractV2Column(min.column).get, "min")
+ case _ => return false
+ }
+
+ if (isPartitionCol(columnName)) {
+ // don't push down partition column, footer doesn't have max/min for partition column
+ return false
+ }
+ val structField = getStructFieldForCol(columnName)
+
+ structField.dataType match {
+ // not push down complex type
+ // not push down Timestamp because INT96 sort order is undefined,
+ // Parquet doesn't return statistics for INT96
+ // not push down Parquet Binary because min/max could be truncated
+ // (https://issues.apache.org/jira/browse/PARQUET-1685), Parquet Binary
+ // could be Spark StringType, BinaryType or DecimalType.
+ // not push down for ORC with same reason.
+ case BooleanType | ByteType | ShortType | IntegerType
+ | LongType | FloatType | DoubleType | DateType =>
+ finalSchema = finalSchema.add(structField.copy(s"$aggType(" + structField.name + ")"))
+ true
+ case _ =>
+ false
+ }
+ }
+
+ if (aggregation.groupByColumns.nonEmpty || dataFilters.nonEmpty) {
+ // Parquet/ORC footer has max/min/count for columns
+ // e.g. SELECT COUNT(col1) FROM t
+ // but footer doesn't have max/min/count for a column if max/min/count
+ // are combined with filter or group by
+ // e.g. SELECT COUNT(col1) FROM t WHERE col2 = 8
+ // SELECT COUNT(col1) FROM t GROUP BY col2
+ // However, if the filter is on partition column, max/min/count can still be pushed down
+ // Todo: add support if groupby column is partition col
+ // (https://issues.apache.org/jira/browse/SPARK-36646)
+ return None
+ }
+ aggregation.groupByColumns.foreach { col =>
+ // don't push down if the group by columns are not the same as the partition columns (order
+ // doesn't matter because reordering can be done at the data source layer)
+ if (col.fieldNames.length != 1 || !isPartitionCol(col.fieldNames.head)) return None
+ finalSchema = finalSchema.add(getStructFieldForCol(col.fieldNames.head))
+ }
+
+ aggregation.aggregateExpressions.foreach {
+ case max: Max =>
+ if (!processMinOrMax(max)) return None
+ case min: Min =>
+ if (!processMinOrMax(min)) return None
+ case count: Count
+ if V2ColumnUtils.extractV2Column(count.column).isDefined && !count.isDistinct =>
+ val columnName = V2ColumnUtils.extractV2Column(count.column).get
+ finalSchema = finalSchema.add(StructField(s"count($columnName)", LongType))
+ case _: CountStar =>
+ finalSchema = finalSchema.add(StructField("count(*)", LongType))
+ case _ =>
+ return None
+ }
+
+ Some(finalSchema)
+ }
+
+ /**
+ * Check whether two Aggregations `a` and `b` are equal.
+ */
+ def equivalentAggregations(a: Aggregation, b: Aggregation): Boolean = {
+ a.aggregateExpressions.sortBy(_.hashCode())
+ .sameElements(b.aggregateExpressions.sortBy(_.hashCode())) &&
+ a.groupByColumns.sortBy(_.hashCode()).sameElements(b.groupByColumns.sortBy(_.hashCode()))
+ }
+
+ /**
+ * Convert the aggregate results from `InternalRow` to `ColumnarBatch`.
+ * This is used by the columnar reader.
+ */
+ def convertAggregatesRowToBatch(
+ aggregatesAsRow: InternalRow,
+ aggregatesSchema: StructType,
+ offHeap: Boolean): ColumnarBatch = {
+ val converter = new RowToColumnConverter(aggregatesSchema)
+ val columnVectors = if (offHeap) {
+ OffHeapColumnVector.allocateColumns(1, aggregatesSchema)
+ } else {
+ OnHeapColumnVector.allocateColumns(1, aggregatesSchema)
+ }
+ converter.convert(aggregatesAsRow, columnVectors.toArray)
+ new ColumnarBatch(columnVectors.asInstanceOf[Array[ColumnVector]], 1)
+ }
+}
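Editor's note: a hedged sketch of getSchemaForPushedAggregation for a query like SELECT MIN(c1), COUNT(*) FROM t with no filters or grouping. The Aggregation is built with the same V2 classes used above; the exact behavior of V2ColumnUtils.extractV2Column is assumed.

// Sketch only: illustrates the aggregate-result schema computed by the utility above.
import org.apache.spark.sql.connector.expressions.FieldReference
import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, CountStar, Min}
import org.apache.spark.sql.execution.datasources.AggregatePushDownUtils
import org.apache.spark.sql.types.{IntegerType, StructType}

val dataSchema = new StructType().add("c1", IntegerType).add("p", IntegerType)
val agg = new Aggregation(Array[AggregateFunc](new Min(FieldReference("c1")), new CountStar()), Array.empty)
val pushedSchema = AggregatePushDownUtils.getSchemaForPushedAggregation(
  agg, dataSchema, partitionNames = Set("p"), dataFilters = Seq.empty)
// Some(StructType(min(c1): int, count(*): bigint)); None means the aggregation cannot be pushed down.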
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index a53665fe2f0e4..408da524cbb04 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -38,13 +38,15 @@ import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoStatement, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2
+import org.apache.spark.sql.catalyst.util.V2ExpressionBuilder
import org.apache.spark.sql.connector.catalog.SupportsRead
import org.apache.spark.sql.connector.catalog.TableCapability._
-import org.apache.spark.sql.connector.expressions.FieldReference
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.{InSubqueryExec, RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
+import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
import org.apache.spark.sql.execution.streaming.StreamingRelation
import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy
import org.apache.spark.sql.sources._
@@ -335,7 +337,7 @@ object DataSourceStrategy
l.output.toStructType,
Set.empty,
Set.empty,
- None,
+ PushedDownOperators(None, None, None, Seq.empty, Seq.empty),
toCatalystRDD(l, baseRelation.buildScan()),
baseRelation,
None) :: Nil
@@ -409,7 +411,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
- None,
+ PushedDownOperators(None, None, None, Seq.empty, Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -432,7 +434,7 @@ object DataSourceStrategy
requestedColumns.toStructType,
pushedFilters.toSet,
handledFilters,
- None,
+ PushedDownOperators(None, None, None, Seq.empty, Seq.empty),
scanBuilder(requestedColumns, candidatePredicates, pushedFilters),
relation.relation,
relation.catalogTable.map(_.identifier))
@@ -698,23 +700,44 @@ object DataSourceStrategy
(nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters)
}
- protected[sql] def translateAggregate(aggregates: AggregateExpression): Option[AggregateFunc] = {
- if (aggregates.filter.isEmpty) {
- aggregates.aggregateFunction match {
- case aggregate.Min(PushableColumnWithoutNestedColumn(name)) =>
- Some(new Min(FieldReference(name)))
- case aggregate.Max(PushableColumnWithoutNestedColumn(name)) =>
- Some(new Max(FieldReference(name)))
+ protected[sql] def translateAggregate(agg: AggregateExpression): Option[AggregateFunc] = {
+ if (agg.filter.isEmpty) {
+ agg.aggregateFunction match {
+ case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr))
+ case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr))
case count: aggregate.Count if count.children.length == 1 =>
count.children.head match {
- // SELECT COUNT(*) FROM table is translated to SELECT 1 FROM table
+ // COUNT(any literal) is the same as COUNT(*)
case Literal(_, _) => Some(new CountStar())
- case PushableColumnWithoutNestedColumn(name) =>
- Some(new Count(FieldReference(name), aggregates.isDistinct))
+ case PushableExpression(expr) => Some(new Count(expr, agg.isDistinct))
case _ => None
}
- case sum @ aggregate.Sum(PushableColumnWithoutNestedColumn(name), _) =>
- Some(new Sum(FieldReference(name), aggregates.isDistinct))
+ case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, agg.isDistinct))
+ case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, agg.isDistinct))
+ case aggregate.VariancePop(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new GeneralAggregateFunc(
+ "VAR_POP", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.VarianceSamp(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new GeneralAggregateFunc(
+ "VAR_SAMP", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.StddevPop(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new GeneralAggregateFunc(
+ "STDDEV_POP", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.StddevSamp(PushableColumnWithoutNestedColumn(name), _) =>
+ Some(new GeneralAggregateFunc(
+ "STDDEV_SAMP", agg.isDistinct, Array(FieldReference(name))))
+ case aggregate.CovPopulation(PushableColumnWithoutNestedColumn(left),
+ PushableColumnWithoutNestedColumn(right), _) =>
+ Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct,
+ Array(FieldReference(left), FieldReference(right))))
+ case aggregate.CovSample(PushableColumnWithoutNestedColumn(left),
+ PushableColumnWithoutNestedColumn(right), _) =>
+ Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct,
+ Array(FieldReference(left), FieldReference(right))))
+ case aggregate.Corr(PushableColumnWithoutNestedColumn(left),
+ PushableColumnWithoutNestedColumn(right), _) =>
+ Some(new GeneralAggregateFunc("CORR", agg.isDistinct,
+ Array(FieldReference(left), FieldReference(right))))
case _ => None
}
} else {
@@ -722,6 +745,49 @@ object DataSourceStrategy
}
}
+ /**
+ * Translate aggregate expressions and group by expressions.
+ *
+ * @return translated aggregation.
+ */
+ protected[sql] def translateAggregation(
+ aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = {
+
+ def columnAsString(e: Expression): Option[FieldReference] = e match {
+ case PushableColumnWithoutNestedColumn(name) =>
+ Some(FieldReference(name).asInstanceOf[FieldReference])
+ case _ => None
+ }
+
+ val translatedAggregates = aggregates.flatMap(translateAggregate)
+ val translatedGroupBys = groupBy.flatMap(columnAsString)
+
+ if (translatedAggregates.length != aggregates.length ||
+ translatedGroupBys.length != groupBy.length) {
+ return None
+ }
+
+ Some(new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray))
+ }
+
+ protected[sql] def translateSortOrders(sortOrders: Seq[SortOrder]): Seq[V2SortOrder] = {
+ def translateSortOrder(sortOrder: SortOrder): Option[V2SortOrder] = sortOrder match {
+ case SortOrder(PushableColumnWithoutNestedColumn(name), directionV1, nullOrderingV1, _) =>
+ val directionV2 = directionV1 match {
+ case Ascending => SortDirection.ASCENDING
+ case Descending => SortDirection.DESCENDING
+ }
+ val nullOrderingV2 = nullOrderingV1 match {
+ case NullsFirst => NullOrdering.NULLS_FIRST
+ case NullsLast => NullOrdering.NULLS_LAST
+ }
+ Some(SortValue(FieldReference(name), directionV2, nullOrderingV2))
+ case _ => None
+ }
+
+ sortOrders.flatMap(translateSortOrder)
+ }
+
/**
* Convert RDD of Row into RDD of InternalRow with objects in catalyst types
*/
@@ -787,3 +853,10 @@ object PushableColumnAndNestedColumn extends PushableColumnBase {
object PushableColumnWithoutNestedColumn extends PushableColumnBase {
override val nestedPredicatePushdownEnabled = false
}
+
+/**
+ * Extractor that builds the DS V2 expression representing a catalyst expression that can be pushed down.
+ */
+object PushableExpression {
+ def unapply(e: Expression): Option[V2Expression] = new V2ExpressionBuilder(e).build()
+}
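Editor's note: a hedged example of the translation helpers added above. Both are protected[sql], so the call site must live in the org.apache.spark.sql package; the attribute is an assumption made for the example.

// Illustration only.
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.expressions.aggregate.Min
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.types.IntegerType

val c = AttributeReference("c", IntegerType)()
DataSourceStrategy.translateAggregate(Min(c).toAggregateExpression())
// Some(Min(FieldReference(c)))
DataSourceStrategy.translateSortOrders(Seq(SortOrder(c, Ascending)))
// Seq(SortValue(FieldReference(c), ASCENDING, NULLS_FIRST))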
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index fcd95a27bf8ca..67d03998a2a24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -28,6 +28,7 @@ import org.json4s.jackson.Serialization
import org.apache.spark.SparkUpgradeException
import org.apache.spark.sql.{SPARK_LEGACY_DATETIME, SPARK_LEGACY_INT96, SPARK_VERSION_METADATA_KEY}
import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
import org.apache.spark.sql.catalyst.util.RebaseDateTime
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions
@@ -39,7 +40,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
-object DataSourceUtils {
+object DataSourceUtils extends PredicateHelper {
/**
* The key to use for storing partitionBy columns as options.
*/
@@ -242,4 +243,22 @@ object DataSourceUtils {
options
}
}
+
+ def getPartitionFiltersAndDataFilters(
+ partitionSchema: StructType,
+ normalizedFilters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
+ val partitionColumns = normalizedFilters.flatMap { expr =>
+ expr.collect {
+ case attr: AttributeReference if partitionSchema.names.contains(attr.name) =>
+ attr
+ }
+ }
+ val partitionSet = AttributeSet(partitionColumns)
+ val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
+ f.references.subsetOf(partitionSet)
+ )
+ val extraPartitionFilter =
+ dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
+ (ExpressionSet(partitionFilters ++ extraPartitionFilter).toSeq, dataFilters)
+ }
}
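Editor's note: a hedged sketch of the new getPartitionFiltersAndDataFilters helper splitting normalized filters by partition columns. The attributes are assumptions; the filters passed in are expected to be already normalized so that attribute names match the partition schema.

// Sketch only.
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, GreaterThan, Literal}
import org.apache.spark.sql.execution.datasources.DataSourceUtils
import org.apache.spark.sql.types.{IntegerType, StructType}

val p = AttributeReference("p", IntegerType)()   // partition column
val c = AttributeReference("c", IntegerType)()   // data column
val partitionSchema = new StructType().add("p", IntegerType)
val (partitionFilters, dataFilters) = DataSourceUtils.getPartitionFiltersAndDataFilters(
  partitionSchema, Seq(EqualTo(p, Literal(1)), GreaterThan(c, Literal(0))))
// partitionFilters: Seq(p = 1); dataFilters: Seq(c > 0)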
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
index 0927027bee0bc..2e8e5426d47be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitions.scala
@@ -17,52 +17,24 @@
package org.apache.spark.sql.execution.datasources
-import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.CatalogStatistics
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LeafNode, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.FilterEstimation
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, FileScan}
-import org.apache.spark.sql.types.StructType
/**
* Prune the partitions of file source based table using partition filters. Currently, this rule
- * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]] and [[DataSourceV2ScanRelation]]
- * with [[FileScan]].
+ * is applied to [[HadoopFsRelation]] with [[CatalogFileIndex]].
*
* For [[HadoopFsRelation]], the location will be replaced by pruned file index, and corresponding
* statistics will be updated. And the partition filters will be kept in the filters of returned
* logical plan.
- *
- * For [[DataSourceV2ScanRelation]], both partition filters and data filters will be added to
- * its underlying [[FileScan]]. And the partition filters will be removed in the filters of
- * returned logical plan.
*/
private[sql] object PruneFileSourcePartitions
extends Rule[LogicalPlan] with PredicateHelper {
- private def getPartitionKeyFiltersAndDataFilters(
- sparkSession: SparkSession,
- relation: LeafNode,
- partitionSchema: StructType,
- filters: Seq[Expression],
- output: Seq[AttributeReference]): (ExpressionSet, Seq[Expression]) = {
- val normalizedFilters = DataSourceStrategy.normalizeExprs(
- filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)), output)
- val partitionColumns =
- relation.resolve(partitionSchema, sparkSession.sessionState.analyzer.resolver)
- val partitionSet = AttributeSet(partitionColumns)
- val (partitionFilters, dataFilters) = normalizedFilters.partition(f =>
- f.references.subsetOf(partitionSet)
- )
- val extraPartitionFilter =
- dataFilters.flatMap(extractPredicatesWithinOutputSet(_, partitionSet))
-
- (ExpressionSet(partitionFilters ++ extraPartitionFilter), dataFilters)
- }
-
private def rebuildPhysicalOperation(
projects: Seq[NamedExpression],
filters: Seq[Expression],
@@ -91,12 +63,14 @@ private[sql] object PruneFileSourcePartitions
_,
_))
if filters.nonEmpty && fsRelation.partitionSchemaOption.isDefined =>
- val (partitionKeyFilters, _) = getPartitionKeyFiltersAndDataFilters(
- fsRelation.sparkSession, logicalRelation, partitionSchema, filters,
+ val normalizedFilters = DataSourceStrategy.normalizeExprs(
+ filters.filter(f => f.deterministic && !SubqueryExpression.hasSubquery(f)),
logicalRelation.output)
+ val (partitionKeyFilters, _) = DataSourceUtils
+ .getPartitionFiltersAndDataFilters(partitionSchema, normalizedFilters)
if (partitionKeyFilters.nonEmpty) {
- val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters.toSeq)
+ val prunedFileIndex = catalogFileIndex.filterPartitions(partitionKeyFilters)
val prunedFsRelation =
fsRelation.copy(location = prunedFileIndex)(fsRelation.sparkSession)
// Change table stats based on the sizeInBytes of pruned files
@@ -117,23 +91,5 @@ private[sql] object PruneFileSourcePartitions
} else {
op
}
-
- case op @ PhysicalOperation(projects, filters,
- v2Relation @ DataSourceV2ScanRelation(_, scan: FileScan, output))
- if filters.nonEmpty =>
- val (partitionKeyFilters, dataFilters) =
- getPartitionKeyFiltersAndDataFilters(scan.sparkSession, v2Relation,
- scan.readPartitionSchema, filters, output)
- // The dataFilters are pushed down only once
- if (partitionKeyFilters.nonEmpty || (dataFilters.nonEmpty && scan.dataFilters.isEmpty)) {
- val prunedV2Relation =
- v2Relation.copy(scan = scan.withFilters(partitionKeyFilters.toSeq, dataFilters))
- // The pushed down partition filters don't need to be reevaluated.
- val afterScanFilters =
- ExpressionSet(filters) -- partitionKeyFilters.filter(_.references.nonEmpty)
- rebuildPhysicalOperation(projects, afterScanFilters.toSeq, prunedV2Relation)
- } else {
- op
- }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
index 8b2ae2beb6d4a..8e047d7f7c7d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala
@@ -191,6 +191,14 @@ class JDBCOptions(
// An option to allow/disallow pushing down aggregate into JDBC data source
val pushDownAggregate = parameters.getOrElse(JDBC_PUSHDOWN_AGGREGATE, "false").toBoolean
+ // An option to allow/disallow pushing down LIMIT into V2 JDBC data source
+ // This only applies to Data Source V2 JDBC
+ val pushDownLimit = parameters.getOrElse(JDBC_PUSHDOWN_LIMIT, "false").toBoolean
+
+ // An option to allow/disallow pushing down TABLESAMPLE into JDBC data source
+ // This only applies to Data Source V2 JDBC
+ val pushDownTableSample = parameters.getOrElse(JDBC_PUSHDOWN_TABLESAMPLE, "false").toBoolean
+
// The local path of user's keytab file, which is assumed to be pre-uploaded to all nodes either
// by --files option of spark-submit or manually
val keytab = {
@@ -263,6 +271,8 @@ object JDBCOptions {
val JDBC_SESSION_INIT_STATEMENT = newOption("sessionInitStatement")
val JDBC_PUSHDOWN_PREDICATE = newOption("pushDownPredicate")
val JDBC_PUSHDOWN_AGGREGATE = newOption("pushDownAggregate")
+ val JDBC_PUSHDOWN_LIMIT = newOption("pushDownLimit")
+ val JDBC_PUSHDOWN_TABLESAMPLE = newOption("pushDownTableSample")
val JDBC_KEYTAB = newOption("keytab")
val JDBC_PRINCIPAL = newOption("principal")
val JDBC_TABLE_COMMENT = newOption("tableComment")
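
A usage sketch for the two new options (URL and table name below are placeholders): both are plain boolean reader options that default to "false" and, per the comments above, only take effect on the Data Source V2 JDBC path.

    // Assumes an existing SparkSession `spark` and a reachable JDBC endpoint.
    val df = spark.read
      .format("jdbc")
      .option("url", "jdbc:postgresql://localhost:5432/test")  // placeholder
      .option("dbtable", "people")                             // placeholder
      .option("pushDownLimit", "true")
      .option("pushDownTableSample", "true")
      .load()
    df.sample(0.5).limit(10).show()
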
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index e024e4bb02102..b30b460ac67db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -25,9 +25,10 @@ import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskCon
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}
-import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.CompletionIterator
@@ -59,7 +60,7 @@ object JDBCRDD extends Logging {
def getQueryOutputSchema(
query: String, options: JDBCOptions, dialect: JdbcDialect): StructType = {
- val conn: Connection = JdbcUtils.createConnectionFactory(options)()
+ val conn: Connection = dialect.createConnectionFactory(options)(-1)
try {
val statement = conn.prepareStatement(query)
try {
@@ -91,106 +92,38 @@ object JDBCRDD extends Logging {
new StructType(columns.map(name => fieldMap(name)))
}
- /**
- * Turns a single Filter into a String representing a SQL expression.
- * Returns None for an unhandled filter.
- */
- def compileFilter(f: Filter, dialect: JdbcDialect): Option[String] = {
- def quote(colName: String): String = dialect.quoteIdentifier(colName)
-
- Option(f match {
- case EqualTo(attr, value) => s"${quote(attr)} = ${dialect.compileValue(value)}"
- case EqualNullSafe(attr, value) =>
- val col = quote(attr)
- s"(NOT ($col != ${dialect.compileValue(value)} OR $col IS NULL OR " +
- s"${dialect.compileValue(value)} IS NULL) OR " +
- s"($col IS NULL AND ${dialect.compileValue(value)} IS NULL))"
- case LessThan(attr, value) => s"${quote(attr)} < ${dialect.compileValue(value)}"
- case GreaterThan(attr, value) => s"${quote(attr)} > ${dialect.compileValue(value)}"
- case LessThanOrEqual(attr, value) => s"${quote(attr)} <= ${dialect.compileValue(value)}"
- case GreaterThanOrEqual(attr, value) => s"${quote(attr)} >= ${dialect.compileValue(value)}"
- case IsNull(attr) => s"${quote(attr)} IS NULL"
- case IsNotNull(attr) => s"${quote(attr)} IS NOT NULL"
- case StringStartsWith(attr, value) => s"${quote(attr)} LIKE '${value}%'"
- case StringEndsWith(attr, value) => s"${quote(attr)} LIKE '%${value}'"
- case StringContains(attr, value) => s"${quote(attr)} LIKE '%${value}%'"
- case In(attr, value) if value.isEmpty =>
- s"CASE WHEN ${quote(attr)} IS NULL THEN NULL ELSE FALSE END"
- case In(attr, value) => s"${quote(attr)} IN (${dialect.compileValue(value)})"
- case Not(f) => compileFilter(f, dialect).map(p => s"(NOT ($p))").getOrElse(null)
- case Or(f1, f2) =>
- // We can't compile Or filter unless both sub-filters are compiled successfully.
- // It applies too for the following And filter.
- // If we can make sure compileFilter supports all filters, we can remove this check.
- val or = Seq(f1, f2).flatMap(compileFilter(_, dialect))
- if (or.size == 2) {
- or.map(p => s"($p)").mkString(" OR ")
- } else {
- null
- }
- case And(f1, f2) =>
- val and = Seq(f1, f2).flatMap(compileFilter(_, dialect))
- if (and.size == 2) {
- and.map(p => s"($p)").mkString(" AND ")
- } else {
- null
- }
- case _ => null
- })
- }
-
- def compileAggregates(
- aggregates: Seq[AggregateFunc],
- dialect: JdbcDialect): Option[Seq[String]] = {
- def quote(colName: String): String = dialect.quoteIdentifier(colName)
-
- Some(aggregates.map {
- case min: Min =>
- if (min.column.fieldNames.length != 1) return None
- s"MIN(${quote(min.column.fieldNames.head)})"
- case max: Max =>
- if (max.column.fieldNames.length != 1) return None
- s"MAX(${quote(max.column.fieldNames.head)})"
- case count: Count =>
- if (count.column.fieldNames.length != 1) return None
- val distinct = if (count.isDistinct) "DISTINCT " else ""
- val column = quote(count.column.fieldNames.head)
- s"COUNT($distinct$column)"
- case sum: Sum =>
- if (sum.column.fieldNames.length != 1) return None
- val distinct = if (sum.isDistinct) "DISTINCT " else ""
- val column = quote(sum.column.fieldNames.head)
- s"SUM($distinct$column)"
- case _: CountStar =>
- s"COUNT(*)"
- case _ => return None
- })
- }
-
/**
* Build and return JDBCRDD from the given information.
*
* @param sc - Your SparkContext.
* @param schema - The Catalyst schema of the underlying database table.
* @param requiredColumns - The names of the columns or aggregate columns to SELECT.
- * @param filters - The filters to include in all WHERE clauses.
+ * @param predicates - The predicates to include in all WHERE clauses.
* @param parts - An array of JDBCPartitions specifying partition ids and
* per-partition WHERE clauses.
* @param options - JDBC options that contains url, table and other information.
* @param outputSchema - The schema of the columns or aggregate columns to SELECT.
* @param groupByColumns - The pushed down group by columns.
+ * @param sample - The pushed down tableSample.
+ * @param limit - The pushed down limit. A value of 0 means there is no limit, or
+ * the limit was not pushed down.
+ * @param sortOrders - The sort orders that cooperate with the limit to realize top N.
*
* @return An RDD representing "SELECT requiredColumns FROM fqTable".
*/
+ // scalastyle:off argcount
def scanTable(
sc: SparkContext,
schema: StructType,
requiredColumns: Array[String],
- filters: Array[Filter],
+ predicates: Array[Predicate],
parts: Array[Partition],
options: JDBCOptions,
outputSchema: Option[StructType] = None,
- groupByColumns: Option[Array[String]] = None): RDD[InternalRow] = {
+ groupByColumns: Option[Array[String]] = None,
+ sample: Option[TableSampleInfo] = None,
+ limit: Int = 0,
+ sortOrders: Array[SortOrder] = Array.empty[SortOrder]): RDD[InternalRow] = {
val url = options.url
val dialect = JdbcDialects.get(url)
val quotedColumns = if (groupByColumns.isEmpty) {
@@ -201,15 +134,19 @@ object JDBCRDD extends Logging {
}
new JDBCRDD(
sc,
- JdbcUtils.createConnectionFactory(options),
+ dialect.createConnectionFactory(options),
outputSchema.getOrElse(pruneSchema(schema, requiredColumns)),
quotedColumns,
- filters,
+ predicates,
parts,
url,
options,
- groupByColumns)
+ groupByColumns,
+ sample,
+ limit,
+ sortOrders)
}
+ // scalastyle:on argcount
}
/**
@@ -219,14 +156,17 @@ object JDBCRDD extends Logging {
*/
private[jdbc] class JDBCRDD(
sc: SparkContext,
- getConnection: () => Connection,
+ getConnection: Int => Connection,
schema: StructType,
columns: Array[String],
- filters: Array[Filter],
+ predicates: Array[Predicate],
partitions: Array[Partition],
url: String,
options: JDBCOptions,
- groupByColumns: Option[Array[String]])
+ groupByColumns: Option[Array[String]],
+ sample: Option[TableSampleInfo],
+ limit: Int,
+ sortOrders: Array[SortOrder])
extends RDD[InternalRow](sc, Nil) {
/**
@@ -242,10 +182,10 @@ private[jdbc] class JDBCRDD(
/**
* `filters`, but as a WHERE clause suitable for injection into a SQL query.
*/
- private val filterWhereClause: String =
- filters
- .flatMap(JDBCRDD.compileFilter(_, JdbcDialects.get(url)))
- .map(p => s"($p)").mkString(" AND ")
+ private val filterWhereClause: String = {
+ val dialect = JdbcDialects.get(url)
+ predicates.flatMap(dialect.compileExpression(_)).map(p => s"($p)").mkString(" AND ")
+ }
/**
* A WHERE clause representing both `filters`, if any, and the current partition.
@@ -274,6 +214,14 @@ private[jdbc] class JDBCRDD(
}
}
+ private def getOrderByClause: String = {
+ if (sortOrders.nonEmpty) {
+ s" ORDER BY ${sortOrders.map(_.describe()).mkString(", ")}"
+ } else {
+ ""
+ }
+ }
+
/**
* Runs the SQL query against the JDBC driver.
*
@@ -322,7 +270,7 @@ private[jdbc] class JDBCRDD(
val inputMetrics = context.taskMetrics().inputMetrics
val part = thePart.asInstanceOf[JDBCPartition]
- conn = getConnection()
+ conn = getConnection(part.idx)
val dialect = JdbcDialects.get(url)
import scala.collection.JavaConverters._
dialect.beforeFetch(conn, options.asProperties.asScala.toMap)
@@ -349,8 +297,16 @@ private[jdbc] class JDBCRDD(
val myWhereClause = getWhereClause(part)
- val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myWhereClause" +
- s" $getGroupByClause"
+ val myTableSampleClause: String = if (sample.nonEmpty) {
+ JdbcDialects.get(url).getTableSample(sample.get)
+ } else {
+ ""
+ }
+
+ val myLimitClause: String = dialect.getLimitClause(limit)
+
+ val sqlText = s"SELECT $columnList FROM ${options.tableOrQuery} $myTableSampleClause" +
+ s" $myWhereClause $getGroupByClause $getOrderByClause $myLimitClause"
stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
stmt.setFetchSize(options.fetchSize)
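
For orientation, with every pushdown active the interpolation above produces a statement of roughly the following shape; each clause may be empty, and the exact TABLESAMPLE and LIMIT spellings come from the JdbcDialect, so the values below are illustrative only.

    val exampleSqlText = "SELECT \"NAME\",\"ID\" FROM people" +
      " TABLESAMPLE (50 PERCENT)" +  // myTableSampleClause (dialect-specific, may be empty)
      " WHERE (\"ID\" > 1)" +        // myWhereClause
      "" +                           // getGroupByClause (empty when no aggregate is pushed)
      " ORDER BY \"ID\" DESC" +      // getOrderByClause
      " LIMIT 10"                    // myLimitClause (dialect-specific)
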
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
index 8098fa0b83a95..0f1a1b6dc667b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala
@@ -27,7 +27,10 @@ import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp}
+import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources._
@@ -268,10 +271,11 @@ private[sql] case class JDBCRelation(
override val needConversion: Boolean = false
- // Check if JDBCRDD.compileFilter can accept input filters
+ // Check if JdbcDialect can compile input filters
override def unhandledFilters(filters: Array[Filter]): Array[Filter] = {
if (jdbcOptions.pushDownPredicate) {
- filters.filter(JDBCRDD.compileFilter(_, JdbcDialects.get(jdbcOptions.url)).isEmpty)
+ val dialect = JdbcDialects.get(jdbcOptions.url)
+ filters.filter(f => dialect.compileExpression(f.toV2).isEmpty)
} else {
filters
}
@@ -279,17 +283,17 @@ private[sql] case class JDBCRelation(
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
// When pushDownPredicate is false, all Filters that need to be pushed down should be ignored
- val pushedFilters = if (jdbcOptions.pushDownPredicate) {
- filters
+ val pushedPredicates = if (jdbcOptions.pushDownPredicate) {
+ filters.map(_.toV2)
} else {
- Array.empty[Filter]
+ Array.empty[Predicate]
}
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
schema,
requiredColumns,
- pushedFilters,
+ pushedPredicates,
parts,
jdbcOptions).asInstanceOf[RDD[Row]]
}
@@ -297,18 +301,24 @@ private[sql] case class JDBCRelation(
def buildScan(
requiredColumns: Array[String],
finalSchema: StructType,
- filters: Array[Filter],
- groupByColumns: Option[Array[String]]): RDD[Row] = {
+ predicates: Array[Predicate],
+ groupByColumns: Option[Array[String]],
+ tableSample: Option[TableSampleInfo],
+ limit: Int,
+ sortOrders: Array[SortOrder]): RDD[Row] = {
// Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row]
JDBCRDD.scanTable(
sparkSession.sparkContext,
schema,
requiredColumns,
- filters,
+ predicates,
parts,
jdbcOptions,
Some(finalSchema),
- groupByColumns).asInstanceOf[RDD[Row]]
+ groupByColumns,
+ tableSample,
+ limit,
+ sortOrders).asInstanceOf[RDD[Row]]
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
index d953ba45cc2fb..2760c7ac3019c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcRelationProvider.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc
import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils._
+import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, RelationProvider}
class JdbcRelationProvider extends CreatableRelationProvider
@@ -45,8 +46,8 @@ class JdbcRelationProvider extends CreatableRelationProvider
df: DataFrame): BaseRelation = {
val options = new JdbcOptionsInWrite(parameters)
val isCaseSensitive = sqlContext.conf.caseSensitiveAnalysis
-
- val conn = JdbcUtils.createConnectionFactory(options)()
+ val dialect = JdbcDialects.get(options.url)
+ val conn = dialect.createConnectionFactory(options)(-1)
try {
val tableExists = JdbcUtils.tableExists(conn, options)
if (tableExists) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 60fcaf94e1986..2d0cbcff8ecc2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -17,11 +17,14 @@
package org.apache.spark.sql.execution.datasources.jdbc
-import java.sql.{Connection, Driver, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
+import java.sql.{Connection, JDBCType, PreparedStatement, ResultSet, ResultSetMetaData, SQLException}
import java.time.{Instant, LocalDate}
+import java.util
import java.util.Locale
import java.util.concurrent.TimeUnit
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.util.Try
import scala.util.control.NonFatal
@@ -37,8 +40,9 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateToDays, toJavaDate, toJavaTimestamp}
import org.apache.spark.sql.connector.catalog.TableChange
+import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
+import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
-import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType}
import org.apache.spark.sql.types._
@@ -50,23 +54,6 @@ import org.apache.spark.util.NextIterator
* Util functions for JDBC tables.
*/
object JdbcUtils extends Logging {
- /**
- * Returns a factory for creating connections to the given JDBC URL.
- *
- * @param options - JDBC options that contains url, table and other information.
- * @throws IllegalArgumentException if the driver could not open a JDBC connection.
- */
- def createConnectionFactory(options: JDBCOptions): () => Connection = {
- val driverClass: String = options.driverClass
- () => {
- DriverRegistry.register(driverClass)
- val driver: Driver = DriverRegistry.get(driverClass)
- val connection = ConnectionProvider.create(driver, options.parameters)
- require(connection != null,
- s"The driver could not open a JDBC connection. Check the URL: ${options.url}")
- connection
- }
- }
/**
* Returns true if the table already exists in the JDBC database.
@@ -651,7 +638,6 @@ object JdbcUtils extends Logging {
* updated even with error if it doesn't support transaction, as there're dirty outputs.
*/
def savePartition(
- getConnection: () => Connection,
table: String,
iterator: Iterator[Row],
rddSchema: StructType,
@@ -662,7 +648,7 @@ object JdbcUtils extends Logging {
options: JDBCOptions): Unit = {
val outMetrics = TaskContext.get().taskMetrics().outputMetrics
- val conn = getConnection()
+ val conn = dialect.createConnectionFactory(options)(-1)
var committed = false
var finalIsolationLevel = Connection.TRANSACTION_NONE
@@ -874,7 +860,6 @@ object JdbcUtils extends Logging {
val table = options.table
val dialect = JdbcDialects.get(url)
val rddSchema = df.schema
- val getConnection: () => Connection = createConnectionFactory(options)
val batchSize = options.batchSize
val isolationLevel = options.isolationLevel
@@ -886,8 +871,7 @@ object JdbcUtils extends Logging {
case _ => df
}
repartitionedDF.rdd.foreachPartition { iterator => savePartition(
- getConnection, table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel,
- options)
+ table, iterator, rddSchema, insertStmt, batchSize, dialect, isolationLevel, options)
}
}
@@ -971,52 +955,108 @@ object JdbcUtils extends Logging {
}
/**
- * Creates a namespace.
+ * Creates a schema.
*/
- def createNamespace(
+ def createSchema(
conn: Connection,
options: JDBCOptions,
- namespace: String,
+ schema: String,
comment: String): Unit = {
+ val statement = conn.createStatement
+ try {
+ statement.setQueryTimeout(options.queryTimeout)
+ val dialect = JdbcDialects.get(options.url)
+ dialect.createSchema(statement, schema, comment)
+ } finally {
+ statement.close()
+ }
+ }
+
+ def schemaExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = {
val dialect = JdbcDialects.get(options.url)
- executeStatement(conn, options, s"CREATE SCHEMA ${dialect.quoteIdentifier(namespace)}")
- if (!comment.isEmpty) createNamespaceComment(conn, options, namespace, comment)
+ dialect.schemasExists(conn, options, schema)
}
- def createNamespaceComment(
+ def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = {
+ val dialect = JdbcDialects.get(options.url)
+ dialect.listSchemas(conn, options)
+ }
+
+ def alterSchemaComment(
conn: Connection,
options: JDBCOptions,
- namespace: String,
+ schema: String,
comment: String): Unit = {
val dialect = JdbcDialects.get(options.url)
- try {
- executeStatement(
- conn, options, dialect.getSchemaCommentQuery(namespace, comment))
- } catch {
- case e: Exception =>
- logWarning("Cannot create JDBC catalog comment. The catalog comment will be ignored.")
- }
+ executeStatement(conn, options, dialect.getSchemaCommentQuery(schema, comment))
}
- def removeNamespaceComment(
+ def removeSchemaComment(
conn: Connection,
options: JDBCOptions,
- namespace: String): Unit = {
+ schema: String): Unit = {
val dialect = JdbcDialects.get(options.url)
- try {
- executeStatement(conn, options, dialect.removeSchemaCommentQuery(namespace))
- } catch {
- case e: Exception =>
- logWarning("Cannot drop JDBC catalog comment.")
- }
+ executeStatement(conn, options, dialect.removeSchemaCommentQuery(schema))
+ }
+
+ /**
+ * Drops a schema from the JDBC database.
+ */
+ def dropSchema(
+ conn: Connection, options: JDBCOptions, schema: String, cascade: Boolean): Unit = {
+ val dialect = JdbcDialects.get(options.url)
+ executeStatement(conn, options, dialect.dropSchema(schema, cascade))
+ }
+
+ /**
+ * Create an index.
+ */
+ def createIndex(
+ conn: Connection,
+ indexName: String,
+ tableName: String,
+ columns: Array[NamedReference],
+ columnsProperties: util.Map[NamedReference, util.Map[String, String]],
+ properties: util.Map[String, String],
+ options: JDBCOptions): Unit = {
+ val dialect = JdbcDialects.get(options.url)
+ executeStatement(conn, options,
+ dialect.createIndex(indexName, tableName, columns, columnsProperties, properties))
+ }
+
+ /**
+ * Check if an index exists
+ */
+ def indexExists(
+ conn: Connection,
+ indexName: String,
+ tableName: String,
+ options: JDBCOptions): Boolean = {
+ val dialect = JdbcDialects.get(options.url)
+ dialect.indexExists(conn, indexName, tableName, options)
}
/**
- * Drops a namespace from the JDBC database.
+ * Drop an index.
*/
- def dropNamespace(conn: Connection, options: JDBCOptions, namespace: String): Unit = {
+ def dropIndex(
+ conn: Connection,
+ indexName: String,
+ tableName: String,
+ options: JDBCOptions): Unit = {
+ val dialect = JdbcDialects.get(options.url)
+ executeStatement(conn, options, dialect.dropIndex(indexName, tableName))
+ }
+
+ /**
+ * List all the indexes in a table.
+ */
+ def listIndexes(
+ conn: Connection,
+ tableName: String,
+ options: JDBCOptions): Array[TableIndex] = {
val dialect = JdbcDialects.get(options.url)
- executeStatement(conn, options, s"DROP SCHEMA ${dialect.quoteIdentifier(namespace)}")
+ dialect.listIndexes(conn, tableName, options)
}
private def executeStatement(conn: Connection, options: JDBCOptions, sql: String): Unit = {
@@ -1028,4 +1068,105 @@ object JdbcUtils extends Logging {
statement.close()
}
}
+
+ /**
+ * Check if an index exists in a table.
+ */
+ def checkIfIndexExists(
+ conn: Connection,
+ sql: String,
+ options: JDBCOptions): Boolean = {
+ val statement = conn.createStatement
+ try {
+ statement.setQueryTimeout(options.queryTimeout)
+ val rs = statement.executeQuery(sql)
+ rs.next
+ } catch {
+ case _: Exception =>
+ logWarning("Cannot retrieved index info.")
+ false
+ } finally {
+ statement.close()
+ }
+ }
+
+ /**
+ * Process index properties and return a tuple of the index type and a list of the
+ * remaining index properties.
+ */
+ def processIndexProperties(
+ properties: util.Map[String, String],
+ catalogName: String): (String, Array[String]) = {
+ var indexType = ""
+ val indexPropertyList: ArrayBuffer[String] = ArrayBuffer[String]()
+ val supportedIndexTypeList = getSupportedIndexTypeList(catalogName)
+
+ if (!properties.isEmpty) {
+ properties.asScala.foreach { case (k, v) =>
+ if (k.equals(SupportsIndex.PROP_TYPE)) {
+ if (containsIndexTypeIgnoreCase(supportedIndexTypeList, v)) {
+ indexType = s"USING $v"
+ } else {
+ throw new UnsupportedOperationException(s"Index Type $v is not supported." +
+ s" The supported Index Types are: ${supportedIndexTypeList.mkString(" AND ")}")
+ }
+ } else {
+ indexPropertyList.append(s"$k = $v")
+ }
+ }
+ }
+ (indexType, indexPropertyList.toArray)
+ }
+
+ def containsIndexTypeIgnoreCase(supportedIndexTypeList: Array[String], value: String): Boolean = {
+ if (supportedIndexTypeList.isEmpty) {
+ throw new UnsupportedOperationException(
+ "Cannot specify 'USING index_type' in 'CREATE INDEX'")
+ }
+ for (indexType <- supportedIndexTypeList) {
+ if (value.equalsIgnoreCase(indexType)) return true
+ }
+ false
+ }
+
+ def getSupportedIndexTypeList(catalogName: String): Array[String] = {
+ catalogName match {
+ case "mysql" => Array("BTREE", "HASH")
+ case "postgresql" => Array("BTREE", "HASH", "BRIN")
+ case _ => Array.empty
+ }
+ }
+
+ def executeQuery(conn: Connection, options: JDBCOptions, sql: String)(
+ f: ResultSet => Unit): Unit = {
+ val statement = conn.createStatement
+ try {
+ statement.setQueryTimeout(options.queryTimeout)
+ val rs = statement.executeQuery(sql)
+ try {
+ f(rs)
+ } finally {
+ rs.close()
+ }
+ } finally {
+ statement.close()
+ }
+ }
+
+ def classifyException[T](message: String, dialect: JdbcDialect)(f: => T): T = {
+ try {
+ f
+ } catch {
+ case e: Throwable => throw dialect.classifyException(message, e)
+ }
+ }
+
+ def withConnection[T](options: JDBCOptions)(f: Connection => T): T = {
+ val dialect = JdbcDialects.get(options.url)
+ val conn = dialect.createConnectionFactory(options)(-1)
+ try {
+ f(conn)
+ } finally {
+ conn.close()
+ }
+ }
}
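
A usage sketch for the new withConnection helper, assuming `options` points at a reachable database and a hypothetical `people` table; the connection comes from the dialect's factory and is closed even if the body throws.

    import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}

    def countPeople(options: JDBCOptions): Long = JdbcUtils.withConnection(options) { conn =>
      val stmt = conn.createStatement()
      try {
        val rs = stmt.executeQuery("SELECT COUNT(*) FROM people")  // hypothetical table
        if (rs.next()) rs.getLong(1) else 0L
      } finally {
        stmt.close()
      }
    }
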
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
index fbc69704f1479..ed8398f265848 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/connection/ConnectionProvider.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.jdbc.JdbcConnectionProvider
import org.apache.spark.util.Utils
-private[jdbc] object ConnectionProvider extends Logging {
+protected abstract class ConnectionProviderBase extends Logging {
private val providers = loadProviders()
def loadProviders(): Seq[JdbcConnectionProvider] = {
@@ -73,3 +73,5 @@ private[jdbc] object ConnectionProvider extends Logging {
}
}
}
+
+private[sql] object ConnectionProvider extends ConnectionProviderBase
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
index fa8977f239164..59a52b318622b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala
@@ -68,6 +68,22 @@ class OrcDeserializer(
resultRow
}
+ def deserializeFromValues(orcValues: Seq[WritableComparable[_]]): InternalRow = {
+ var targetColumnIndex = 0
+ while (targetColumnIndex < fieldWriters.length) {
+ if (fieldWriters(targetColumnIndex) != null) {
+ val value = orcValues(requestedColIds(targetColumnIndex))
+ if (value == null) {
+ resultRow.setNullAt(targetColumnIndex)
+ } else {
+ fieldWriters(targetColumnIndex)(value)
+ }
+ }
+ targetColumnIndex += 1
+ }
+ resultRow
+ }
+
/**
* Creates a writer to write ORC values to Catalyst data structure at the given ordinal.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index a8647726fe022..7758d6a515b51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -24,17 +24,22 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}
-import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription, Writer}
+import org.apache.hadoop.hive.serde2.io.DateWritable
+import org.apache.hadoop.io.{BooleanWritable, ByteWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, ShortWritable, WritableComparable}
+import org.apache.orc.{BooleanColumnStatistics, ColumnStatistics, DateColumnStatistics, DoubleColumnStatistics, IntegerColumnStatistics, OrcConf, OrcFile, Reader, TypeDescription, Writer}
-import org.apache.spark.SPARK_VERSION_SHORT
+import org.apache.spark.{SPARK_VERSION_SHORT, SparkException}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{SPARK_VERSION_METADATA_KEY, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, CharVarcharUtils}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.datasources.SchemaMergeUtils
+import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, Utils}
@@ -84,7 +89,7 @@ object OrcUtils extends Logging {
}
}
- private def toCatalystSchema(schema: TypeDescription): StructType = {
+ def toCatalystSchema(schema: TypeDescription): StructType = {
// The Spark query engine has not completely supported CHAR/VARCHAR type yet, and here we
// replace the orc CHAR/VARCHAR with STRING type.
CharVarcharUtils.replaceCharVarcharWithStringInSchema(
@@ -259,4 +264,139 @@ object OrcUtils extends Logging {
OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
resultSchemaString
}
+
+ /**
+ * Checks if `dataType` supports columnar reads.
+ *
+ * @param dataType Data type of the ORC files.
+ * @param nestedColumnEnabled True if columnar reads are enabled for nested column types.
+ * @return True if the data type supports columnar reads.
+ */
+ def supportColumnarReads(
+ dataType: DataType,
+ nestedColumnEnabled: Boolean): Boolean = {
+ dataType match {
+ case _: AtomicType => true
+ case st: StructType if nestedColumnEnabled =>
+ st.forall(f => supportColumnarReads(f.dataType, nestedColumnEnabled))
+ case ArrayType(elementType, _) if nestedColumnEnabled =>
+ supportColumnarReads(elementType, nestedColumnEnabled)
+ case MapType(keyType, valueType, _) if nestedColumnEnabled =>
+ supportColumnarReads(keyType, nestedColumnEnabled) &&
+ supportColumnarReads(valueType, nestedColumnEnabled)
+ case _ => false
+ }
+ }
+
+ /**
+ * When partial aggregates (Max/Min/Count) are pushed down to ORC, we don't need to read data
+ * from ORC and aggregate at the Spark layer. Instead we compute the partial aggregate
+ * (Max/Min/Count) results from the statistics in the ORC file footer and then construct
+ * an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ def createAggInternalRowFromFooter(
+ reader: Reader,
+ filePath: String,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType): InternalRow = {
+ require(aggregation.groupByColumns.length == 0,
+ s"aggregate $aggregation with group-by column shouldn't be pushed down")
+ var columnsStatistics: OrcColumnStatistics = null
+ try {
+ columnsStatistics = OrcFooterReader.readStatistics(reader)
+ } catch { case e: Exception =>
+ throw new SparkException(
+ s"Cannot read columns statistics in file: $filePath. Please consider disabling " +
+ s"ORC aggregate push down by setting 'spark.sql.orc.aggregatePushdown' to false.", e)
+ }
+
+ // Get column statistics with column name.
+ def getColumnStatistics(columnName: String): ColumnStatistics = {
+ val columnIndex = dataSchema.fieldNames.indexOf(columnName)
+ columnsStatistics.get(columnIndex).getStatistics
+ }
+
+ // Get Min/Max statistics and store as ORC `WritableComparable` format.
+ // Return null if number of non-null values is zero.
+ def getMinMaxFromColumnStatistics(
+ statistics: ColumnStatistics,
+ dataType: DataType,
+ isMax: Boolean): WritableComparable[_] = {
+ if (statistics.getNumberOfValues == 0) {
+ return null
+ }
+
+ statistics match {
+ case s: BooleanColumnStatistics =>
+ val value = if (isMax) s.getTrueCount > 0 else !(s.getFalseCount > 0)
+ new BooleanWritable(value)
+ case s: IntegerColumnStatistics =>
+ val value = if (isMax) s.getMaximum else s.getMinimum
+ dataType match {
+ case ByteType => new ByteWritable(value.toByte)
+ case ShortType => new ShortWritable(value.toShort)
+ case IntegerType => new IntWritable(value.toInt)
+ case LongType => new LongWritable(value)
+ case _ => throw new IllegalArgumentException(
+ s"getMinMaxFromColumnStatistics should not take type $dataType " +
+ "for IntegerColumnStatistics")
+ }
+ case s: DoubleColumnStatistics =>
+ val value = if (isMax) s.getMaximum else s.getMinimum
+ dataType match {
+ case FloatType => new FloatWritable(value.toFloat)
+ case DoubleType => new DoubleWritable(value)
+ case _ => throw new IllegalArgumentException(
+ s"getMinMaxFromColumnStatistics should not take type $dataType " +
+ "for DoubleColumnStatistics")
+ }
+ case s: DateColumnStatistics =>
+ new DateWritable(
+ if (isMax) s.getMaximumDayOfEpoch.toInt else s.getMinimumDayOfEpoch.toInt)
+ case _ => throw new IllegalArgumentException(
+ s"getMinMaxFromColumnStatistics should not take ${statistics.getClass.getName}: " +
+ s"$statistics as the ORC column statistics")
+ }
+ }
+
+ val aggORCValues: Seq[WritableComparable[_]] =
+ aggregation.aggregateExpressions.zipWithIndex.map {
+ case (max: Max, index) if V2ColumnUtils.extractV2Column(max.column).isDefined =>
+ val columnName = V2ColumnUtils.extractV2Column(max.column).get
+ val statistics = getColumnStatistics(columnName)
+ val dataType = aggSchema(index).dataType
+ getMinMaxFromColumnStatistics(statistics, dataType, isMax = true)
+ case (min: Min, index) if V2ColumnUtils.extractV2Column(min.column).isDefined =>
+ val columnName = V2ColumnUtils.extractV2Column(min.column).get
+ val statistics = getColumnStatistics(columnName)
+ val dataType = aggSchema.apply(index).dataType
+ getMinMaxFromColumnStatistics(statistics, dataType, isMax = false)
+ case (count: Count, _) if V2ColumnUtils.extractV2Column(count.column).isDefined =>
+ val columnName = V2ColumnUtils.extractV2Column(count.column).get
+ val isPartitionColumn = partitionSchema.fields.map(_.name).contains(columnName)
+ // NOTE: Count(columnName) doesn't include null values.
+ // org.apache.orc.ColumnStatistics.getNumberOfValues() returns number of non-null values
+ // for ColumnStatistics of individual column. In addition to this, ORC also stores number
+ // of all values (null and non-null) separately.
+ val nonNullRowsCount = if (isPartitionColumn) {
+ columnsStatistics.getStatistics.getNumberOfValues
+ } else {
+ getColumnStatistics(columnName).getNumberOfValues
+ }
+ new LongWritable(nonNullRowsCount)
+ case (_: CountStar, _) =>
+ // Count(*) includes both null and non-null values.
+ new LongWritable(columnsStatistics.getStatistics.getNumberOfValues)
+ case (x, _) =>
+ throw new IllegalArgumentException(
+ s"createAggInternalRowFromFooter should not take $x as the aggregate expression")
+ }
+
+ val orcValuesDeserializer = new OrcDeserializer(aggSchema, (0 until aggSchema.length).toArray)
+ orcValuesDeserializer.deserializeFromValues(aggORCValues)
+ }
}
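
The idea behind createAggInternalRowFromFooter, sketched directly against the ORC reader API (file path and column layout below are hypothetical): MIN/MAX/COUNT can be answered from footer statistics without scanning any rows.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.orc.{IntegerColumnStatistics, OrcFile}

    val reader = OrcFile.createReader(new Path("/data/part-0.orc"),
      OrcFile.readerOptions(new Configuration()))
    val stats = reader.getStatistics            // index 0 is the root struct, 1.. are the columns
    val idStats = stats(1).asInstanceOf[IntegerColumnStatistics]  // assumes column 1 is integral
    val minId = idStats.getMinimum              // MIN(id)
    val maxId = idStats.getMaximum              // MAX(id)
    val countId = idStats.getNumberOfValues     // COUNT(id): non-null values only
    val countStar = stats(0).getNumberOfValues  // COUNT(*): all rows, null or not
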
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
index b91d75c55c513..f3836ab8b5ae4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetUtils.scala
@@ -16,10 +16,24 @@
*/
package org.apache.spark.sql.execution.datasources.parquet
+import java.util
+
+import scala.collection.mutable
+import scala.language.existentials
+
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.parquet.hadoop.ParquetFileWriter
+import org.apache.parquet.hadoop.metadata.{ColumnChunkMetaData, ParquetMetadata}
+import org.apache.parquet.io.api.Binary
+import org.apache.parquet.schema.{PrimitiveType, Types}
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.spark.SparkException
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Count, CountStar, Max, Min}
+import org.apache.spark.sql.execution.datasources.v2.V2ColumnUtils
+import org.apache.spark.sql.internal.SQLConf.{LegacyBehaviorPolicy, PARQUET_AGGREGATE_PUSHDOWN_ENABLED}
import org.apache.spark.sql.types.StructType
object ParquetUtils {
@@ -127,4 +141,176 @@ object ParquetUtils {
file.getName == ParquetFileWriter.PARQUET_COMMON_METADATA_FILE ||
file.getName == ParquetFileWriter.PARQUET_METADATA_FILE
}
+
+ /**
+ * When partial aggregates (Max/Min/Count) are pushed down to Parquet, we don't need to call
+ * createRowBaseReader to read data from Parquet and aggregate at the Spark layer. Instead we
+ * compute the partial aggregate (Max/Min/Count) results from the statistics in the Parquet
+ * file footer and then construct an InternalRow from these aggregate results.
+ *
+ * @return Aggregate results in the format of InternalRow
+ */
+ private[sql] def createAggInternalRowFromFooter(
+ footer: ParquetMetadata,
+ filePath: String,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ aggSchema: StructType,
+ datetimeRebaseMode: LegacyBehaviorPolicy.Value,
+ isCaseSensitive: Boolean): InternalRow = {
+ val (primitiveTypes, values) = getPushedDownAggResult(
+ footer, filePath, dataSchema, partitionSchema, aggregation, isCaseSensitive)
+
+ val builder = Types.buildMessage
+ primitiveTypes.foreach(t => builder.addField(t))
+ val parquetSchema = builder.named("root")
+
+ val schemaConverter = new ParquetToSparkSchemaConverter
+ val converter = new ParquetRowConverter(schemaConverter, parquetSchema, aggSchema,
+ None, datetimeRebaseMode, LegacyBehaviorPolicy.CORRECTED, NoopUpdater)
+ val primitiveTypeNames = primitiveTypes.map(_.getPrimitiveTypeName)
+ primitiveTypeNames.zipWithIndex.foreach {
+ case (PrimitiveType.PrimitiveTypeName.BOOLEAN, i) =>
+ val v = values(i).asInstanceOf[Boolean]
+ converter.getConverter(i).asPrimitiveConverter.addBoolean(v)
+ case (PrimitiveType.PrimitiveTypeName.INT32, i) =>
+ val v = values(i).asInstanceOf[Integer]
+ converter.getConverter(i).asPrimitiveConverter.addInt(v)
+ case (PrimitiveType.PrimitiveTypeName.INT64, i) =>
+ val v = values(i).asInstanceOf[Long]
+ converter.getConverter(i).asPrimitiveConverter.addLong(v)
+ case (PrimitiveType.PrimitiveTypeName.FLOAT, i) =>
+ val v = values(i).asInstanceOf[Float]
+ converter.getConverter(i).asPrimitiveConverter.addFloat(v)
+ case (PrimitiveType.PrimitiveTypeName.DOUBLE, i) =>
+ val v = values(i).asInstanceOf[Double]
+ converter.getConverter(i).asPrimitiveConverter.addDouble(v)
+ case (PrimitiveType.PrimitiveTypeName.BINARY, i) =>
+ val v = values(i).asInstanceOf[Binary]
+ converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+ case (PrimitiveType.PrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, i) =>
+ val v = values(i).asInstanceOf[Binary]
+ converter.getConverter(i).asPrimitiveConverter.addBinary(v)
+ case (_, i) =>
+ throw new SparkException("Unexpected parquet type name: " + primitiveTypeNames(i))
+ }
+ converter.currentRecord
+ }
+
+ /**
+ * Calculate the pushed down aggregates (Max/Min/Count) result using the statistics
+ * information from Parquet footer file.
+ *
+ * @return A tuple of `Array[PrimitiveType]` and `Array[Any]`.
+ * The first element is the Parquet PrimitiveType of the aggregate column,
+ * and the second element is the aggregated value.
+ */
+ private[sql] def getPushedDownAggResult(
+ footer: ParquetMetadata,
+ filePath: String,
+ dataSchema: StructType,
+ partitionSchema: StructType,
+ aggregation: Aggregation,
+ isCaseSensitive: Boolean)
+ : (Array[PrimitiveType], Array[Any]) = {
+ val footerFileMetaData = footer.getFileMetaData
+ val fields = footerFileMetaData.getSchema.getFields
+ val blocks = footer.getBlocks
+ val primitiveTypeBuilder = mutable.ArrayBuilder.make[PrimitiveType]
+ val valuesBuilder = mutable.ArrayBuilder.make[Any]
+
+ assert(aggregation.groupByColumns.length == 0, "group by shouldn't be pushed down")
+ aggregation.aggregateExpressions.foreach { agg =>
+ var value: Any = None
+ var rowCount = 0L
+ var isCount = false
+ var index = 0
+ var schemaName = ""
+ blocks.forEach { block =>
+ val blockMetaData = block.getColumns
+ agg match {
+ case max: Max if V2ColumnUtils.extractV2Column(max.column).isDefined =>
+ val colName = V2ColumnUtils.extractV2Column(max.column).get
+ index = dataSchema.fieldNames.toList.indexOf(colName)
+ schemaName = "max(" + colName + ")"
+ val currentMax = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, true)
+ if (value == None || currentMax.asInstanceOf[Comparable[Any]].compareTo(value) > 0) {
+ value = currentMax
+ }
+ case min: Min if V2ColumnUtils.extractV2Column(min.column).isDefined =>
+ val colName = V2ColumnUtils.extractV2Column(min.column).get
+ index = dataSchema.fieldNames.toList.indexOf(colName)
+ schemaName = "min(" + colName + ")"
+ val currentMin = getCurrentBlockMaxOrMin(filePath, blockMetaData, index, false)
+ if (value == None || currentMin.asInstanceOf[Comparable[Any]].compareTo(value) < 0) {
+ value = currentMin
+ }
+ case count: Count if V2ColumnUtils.extractV2Column(count.column).isDefined =>
+ val colName = V2ColumnUtils.extractV2Column(count.column).get
+ schemaName = "count(" + colName + ")"
+ rowCount += block.getRowCount
+ var isPartitionCol = false
+ if (partitionSchema.fields.map(_.name).toSet.contains(colName)) {
+ isPartitionCol = true
+ }
+ isCount = true
+ if (!isPartitionCol) {
+ index = dataSchema.fieldNames.toList.indexOf(colName)
+ // Count(*) includes the null values, but Count(colName) doesn't.
+ rowCount -= getNumNulls(filePath, blockMetaData, index)
+ }
+ case _: CountStar =>
+ schemaName = "count(*)"
+ rowCount += block.getRowCount
+ isCount = true
+ case _ =>
+ }
+ }
+ if (isCount) {
+ valuesBuilder += rowCount
+ primitiveTypeBuilder += Types.required(PrimitiveTypeName.INT64).named(schemaName)
+ } else {
+ valuesBuilder += value
+ val field = fields.get(index)
+ primitiveTypeBuilder += Types.required(field.asPrimitiveType.getPrimitiveTypeName)
+ .as(field.getLogicalTypeAnnotation)
+ .length(field.asPrimitiveType.getTypeLength)
+ .named(schemaName)
+ }
+ }
+ (primitiveTypeBuilder.result, valuesBuilder.result)
+ }
+
+ /**
+ * Get the Max or Min value of the i-th column in the current block.
+ *
+ * @return the Max or Min value
+ */
+ private def getCurrentBlockMaxOrMin(
+ filePath: String,
+ columnChunkMetaData: util.List[ColumnChunkMetaData],
+ i: Int,
+ isMax: Boolean): Any = {
+ val statistics = columnChunkMetaData.get(i).getStatistics
+ if (!statistics.hasNonNullValue) {
+ throw new UnsupportedOperationException(s"No min/max found for Parquet file $filePath. " +
+ s"Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute again")
+ } else {
+ if (isMax) statistics.genericGetMax else statistics.genericGetMin
+ }
+ }
+
+ private def getNumNulls(
+ filePath: String,
+ columnChunkMetaData: util.List[ColumnChunkMetaData],
+ i: Int): Long = {
+ val statistics = columnChunkMetaData.get(i).getStatistics
+ if (!statistics.isNumNullsSet) {
+ throw new UnsupportedOperationException(s"Number of nulls not set for Parquet file" +
+ s" $filePath. Set SQLConf ${PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key} to false and execute" +
+ s" again")
+ }
+ statistics.getNumNulls
+ }
}
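
The same idea for Parquet, sketched at the footer level with only APIs already used above (getBlocks, getRowCount, getColumns, getStatistics); the file path is hypothetical, and MIN/MAX would come from genericGetMin/genericGetMax in the same way.

    import scala.collection.JavaConverters._

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.parquet.hadoop.ParquetFileReader
    import org.apache.parquet.hadoop.util.HadoopInputFile

    val reader = ParquetFileReader.open(
      HadoopInputFile.fromPath(new Path("/data/part-0.parquet"), new Configuration()))
    try {
      val footer = reader.getFooter
      // COUNT(*) is the sum of the per-row-group row counts.
      val countStar = footer.getBlocks.asScala.map(_.getRowCount).sum
      // COUNT(col) additionally subtracts the null count recorded for that column chunk.
      val nullsInFirstCol = footer.getBlocks.asScala
        .map(_.getColumns.get(0).getStatistics.getNumNulls)
        .sum
      val countFirstCol = countStar - nullsInFirstCol
    } finally {
      reader.close()
    }
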
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 1a50c320ea3e3..f267a03cbe218 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -18,14 +18,17 @@
package org.apache.spark.sql.execution.datasources.v2
import scala.collection.JavaConverters._
+import scala.collection.mutable
import org.apache.spark.sql.{SparkSession, Strategy}
import org.apache.spark.sql.catalyst.analysis.{ResolvedNamespace, ResolvedPartitionSpec, ResolvedTable}
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, PredicateHelper, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning, Expression, NamedExpression, Not, Or, PredicateHelper, SubqueryExpression}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.toPrettySQL
+import org.apache.spark.sql.catalyst.util.{toPrettySQL, V2ExpressionBuilder}
import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, StagingTableCatalog, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog}
+import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not => V2Not, Or => V2Or, Predicate}
import org.apache.spark.sql.connector.read.LocalScan
import org.apache.spark.sql.connector.read.streaming.{ContinuousStream, MicroBatchStream}
import org.apache.spark.sql.connector.write.V1Write
@@ -86,8 +89,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case PhysicalOperation(project, filters,
- DataSourceV2ScanRelation(_, V1ScanWrapper(scan, pushed, aggregate), output)) =>
+ case PhysicalOperation(project, filters, DataSourceV2ScanRelation(
+ _, V1ScanWrapper(scan, pushed, pushedDownOperators), output)) =>
val v1Relation = scan.toV1TableScan[BaseRelation with TableScan](session.sqlContext)
if (v1Relation.schema != scan.readSchema()) {
throw QueryExecutionErrors.fallbackV1RelationReportsInconsistentSchemaError(
@@ -95,12 +98,13 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
}
val rdd = v1Relation.buildScan()
val unsafeRowRDD = DataSourceStrategy.toCatalystRDD(v1Relation, output, rdd)
+
val dsScan = RowDataSourceScanExec(
output,
output.toStructType,
Set.empty,
pushed.toSet,
- aggregate,
+ pushedDownOperators,
unsafeRowRDD,
v1Relation,
tableIdentifier = None)
@@ -427,3 +431,112 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case _ => Nil
}
}
+
+private[sql] object DataSourceV2Strategy {
+
+ private def translateLeafNodeFilterV2(
+ predicate: Expression,
+ supportNestedPredicatePushdown: Boolean): Option[Predicate] = {
+ val pushablePredicate = PushablePredicate(supportNestedPredicatePushdown)
+ predicate match {
+ case pushablePredicate(expr) => Some(expr)
+ case _ => None
+ }
+ }
+
+ /**
+ * Tries to translate a Catalyst [[Expression]] into a data source V2 [[Predicate]].
+ *
+ * @return a `Some[Predicate]` if the input [[Expression]] is convertible, otherwise a `None`.
+ */
+ protected[sql] def translateFilterV2(
+ predicate: Expression,
+ supportNestedPredicatePushdown: Boolean): Option[Predicate] = {
+ translateFilterV2WithMapping(predicate, None, supportNestedPredicatePushdown)
+ }
+
+ /**
+ * Tries to translate a Catalyst [[Expression]] into a data source V2 [[Predicate]].
+ *
+ * @param predicate The input [[Expression]] to be translated as a [[Predicate]]
+ * @param translatedFilterToExpr An optional map from leaf node filter expressions to their
+ * translated [[Predicate]]s. The map is used for rebuilding
+ * [[Expression]]s from [[Predicate]]s.
+ * @return a `Some[Predicate]` if the input [[Expression]] is convertible, otherwise a `None`.
+ */
+ protected[sql] def translateFilterV2WithMapping(
+ predicate: Expression,
+ translatedFilterToExpr: Option[mutable.HashMap[Predicate, Expression]],
+ nestedPredicatePushdownEnabled: Boolean)
+ : Option[Predicate] = {
+ predicate match {
+ case And(left, right) =>
+ // See SPARK-12218 for detailed discussion
+ // It is not safe to just convert one side if we do not understand the
+ // other side. Here is an example used to explain the reason.
+ // Let's say we have (a = 2 AND trim(b) = 'blah') OR (c > 0)
+ // and we do not understand how to convert trim(b) = 'blah'.
+ // If we only convert a = 2, we will end up with
+ // (a = 2) OR (c > 0), which will generate wrong results.
+ // Pushing one leg of AND down is only safe to do at the top level.
+ // You can see ParquetFilters' createFilter for more details.
+ for {
+ leftFilter <- translateFilterV2WithMapping(
+ left, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+ rightFilter <- translateFilterV2WithMapping(
+ right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+ } yield new V2And(leftFilter, rightFilter)
+
+ case Or(left, right) =>
+ for {
+ leftFilter <- translateFilterV2WithMapping(
+ left, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+ rightFilter <- translateFilterV2WithMapping(
+ right, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+ } yield new V2Or(leftFilter, rightFilter)
+
+ case Not(child) =>
+ translateFilterV2WithMapping(child, translatedFilterToExpr, nestedPredicatePushdownEnabled)
+ .map(new V2Not(_))
+
+ case other =>
+ val filter = translateLeafNodeFilterV2(other, nestedPredicatePushdownEnabled)
+ if (filter.isDefined && translatedFilterToExpr.isDefined) {
+ translatedFilterToExpr.get(filter.get) = predicate
+ }
+ filter
+ }
+ }
+
+ protected[sql] def rebuildExpressionFromFilter(
+ predicate: Predicate,
+ translatedFilterToExpr: mutable.HashMap[Predicate, Expression]): Expression = {
+ predicate match {
+ case and: V2And =>
+ expressions.And(
+ rebuildExpressionFromFilter(and.left(), translatedFilterToExpr),
+ rebuildExpressionFromFilter(and.right(), translatedFilterToExpr))
+ case or: V2Or =>
+ expressions.Or(
+ rebuildExpressionFromFilter(or.left(), translatedFilterToExpr),
+ rebuildExpressionFromFilter(or.right(), translatedFilterToExpr))
+ case not: V2Not =>
+ expressions.Not(rebuildExpressionFromFilter(not.child(), translatedFilterToExpr))
+ case _ =>
+ translatedFilterToExpr.getOrElse(predicate,
+ throw new IllegalStateException("Failed to rebuild Expression for filter: " + predicate))
+ }
+ }
+}
+
+/**
+ * Extracts the DS V2 [[Predicate]] representing a Catalyst predicate that can be pushed down.
+ */
+case class PushablePredicate(nestedPredicatePushdownEnabled: Boolean) {
+
+ def unapply(e: Expression): Option[Predicate] =
+ new V2ExpressionBuilder(e, nestedPredicatePushdownEnabled, true).build().map { v =>
+ assert(v.isInstanceOf[Predicate])
+ v.asInstanceOf[Predicate]
+ }
+}
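
A minimal sketch of the round trip added above, for a hypothetical integer attribute `a`; because the object and these methods are private/protected to the sql package, the snippet is assumed to live somewhere under org.apache.spark.sql.

    import scala.collection.mutable

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, GreaterThan, Literal}
    import org.apache.spark.sql.connector.expressions.filter.Predicate
    import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Strategy
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()
    val catalystPred: Expression = GreaterThan(a, Literal(1))

    // Remember the leaf mapping so the Expression can be rebuilt from the V2 Predicate.
    val leafMapping = mutable.HashMap.empty[Predicate, Expression]
    val v2Pred = DataSourceV2Strategy.translateFilterV2WithMapping(
      catalystPred, Some(leafMapping), nestedPredicatePushdownEnabled = true)

    // If translation succeeded, the original Catalyst predicate can be recovered.
    val rebuilt = v2Pred.map(DataSourceV2Strategy.rebuildExpressionFromFilter(_, leafMapping))
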
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala
index dbd5cbd874945..5d302055e7d91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropNamespaceExec.scala
@@ -18,9 +18,10 @@
package org.apache.spark.sql.execution.datasources.v2
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.connector.catalog.CatalogPlugin
-import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
+import org.apache.spark.sql.errors.QueryCompilationErrors
/**
* Physical plan node for dropping a namespace.
@@ -37,17 +38,11 @@ case class DropNamespaceExec(
val nsCatalog = catalog.asNamespaceCatalog
val ns = namespace.toArray
if (nsCatalog.namespaceExists(ns)) {
- // The default behavior of `SupportsNamespace.dropNamespace()` is cascading,
- // so make sure the namespace to drop is empty.
- if (!cascade) {
- if (catalog.asTableCatalog.listTables(ns).nonEmpty
- || nsCatalog.listNamespaces(ns).nonEmpty) {
- throw QueryExecutionErrors.cannotDropNonemptyNamespaceError(namespace)
- }
- }
-
- if (!nsCatalog.dropNamespace(ns)) {
- throw QueryExecutionErrors.cannotDropNonemptyNamespaceError(namespace)
+ try {
+ nsCatalog.dropNamespace(ns, cascade)
+ } catch {
+ case _: NonEmptyNamespaceException =>
+ throw QueryCompilationErrors.cannotDropNonemptyNamespaceError(namespace)
}
} else if (!ifExists) {
throw QueryCompilationErrors.noSuchNamespaceError(ns)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
index 4506bd3d49b5b..8b0328cabc5a8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala
@@ -49,6 +49,8 @@ trait FileScan extends Scan
def fileIndex: PartitioningAwareFileIndex
+ def dataSchema: StructType
+
/**
* Returns the required data schema
*/
@@ -69,12 +71,6 @@ trait FileScan extends Scan
*/
def dataFilters: Seq[Expression]
- /**
- * Create a new `FileScan` instance from the current one
- * with different `partitionFilters` and `dataFilters`
- */
- def withFilters(partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan
-
/**
* If a file with `path` is unsplittable, return the unsplittable reason,
* otherwise return `None`.
@@ -187,7 +183,10 @@ trait FileScan extends Scan
new Statistics {
override def sizeInBytes(): OptionalLong = {
val compressionFactor = sparkSession.sessionState.conf.fileCompressionFactor
- val size = (compressionFactor * fileIndex.sizeInBytes).toLong
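+        // Scale the raw file size by the ratio of the read (pruned) schema's default size to
+        // the full data + partition schema's default size, so that column pruning is reflected
+        // in the estimate. For example, if the full schema's defaultSize is 100 bytes per row
+        // and the read schema's is 20, the estimate is roughly one fifth of
+        // compressionFactor * fileIndex.sizeInBytes.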
+ val size = (compressionFactor * fileIndex.sizeInBytes /
+ (dataSchema.defaultSize + fileIndex.partitionSchema.defaultSize) *
+ (readDataSchema.defaultSize + readPartitionSchema.defaultSize)).toLong
+
OptionalLong.of(size)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
index 97874e8f4932e..2dc4137d6f9a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScanBuilder.scala
@@ -16,19 +16,30 @@
*/
package org.apache.spark.sql.execution.datasources.v2
-import org.apache.spark.sql.SparkSession
+import scala.collection.mutable
+
+import org.apache.spark.sql.{sources, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.read.{ScanBuilder, SupportsPushDownRequiredColumns}
-import org.apache.spark.sql.execution.datasources.{PartitioningAwareFileIndex, PartitioningUtils}
+import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, DataSourceUtils, PartitioningAwareFileIndex, PartitioningUtils}
+import org.apache.spark.sql.internal.connector.SupportsPushDownCatalystFilters
+import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
abstract class FileScanBuilder(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
- dataSchema: StructType) extends ScanBuilder with SupportsPushDownRequiredColumns {
+ dataSchema: StructType)
+ extends ScanBuilder
+ with SupportsPushDownRequiredColumns
+ with SupportsPushDownCatalystFilters {
private val partitionSchema = fileIndex.partitionSchema
private val isCaseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis
protected val supportsNestedSchemaPruning = false
protected var requiredSchema = StructType(dataSchema.fields ++ partitionSchema.fields)
+ protected var partitionFilters = Seq.empty[Expression]
+ protected var dataFilters = Seq.empty[Expression]
+ protected var pushedDataFilters = Array.empty[Filter]
override def pruneColumns(requiredSchema: StructType): Unit = {
// [SPARK-30107] While `requiredSchema` might have pruned nested columns,
@@ -48,7 +59,7 @@ abstract class FileScanBuilder(
StructType(fields)
}
- protected def readPartitionSchema(): StructType = {
+ def readPartitionSchema(): StructType = {
val requiredNameSet = createRequiredNameSet()
val fields = partitionSchema.fields.filter { field =>
val colName = PartitioningUtils.getColName(field, isCaseSensitive)
@@ -57,9 +68,34 @@ abstract class FileScanBuilder(
StructType(fields)
}
+ override def pushFilters(filters: Seq[Expression]): Seq[Expression] = {
+ val (partitionFilters, dataFilters) =
+ DataSourceUtils.getPartitionFiltersAndDataFilters(partitionSchema, filters)
+ this.partitionFilters = partitionFilters
+ this.dataFilters = dataFilters
+ val translatedFilters = mutable.ArrayBuffer.empty[sources.Filter]
+ for (filterExpr <- dataFilters) {
+ val translated = DataSourceStrategy.translateFilter(filterExpr, true)
+ if (translated.nonEmpty) {
+ translatedFilters += translated.get
+ }
+ }
+ pushedDataFilters = pushDataFilters(translatedFilters.toArray)
+ dataFilters
+ }
+
+ override def pushedFilters: Array[Filter] = pushedDataFilters
+
+ /*
+ * Push down data filters to the file source, so the data filters can be evaluated there to
+ * reduce the size of the data to be read. By default, data filters are not pushed down.
+   * A file source needs to override this method to push down its data filters.
+ */
+ protected def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = Array.empty[Filter]
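+
+  // A sketch of a typical override (names such as `formatFilterPushDownEnabled` are
+  // placeholders for a format-specific config, e.g. `csvFilterPushDown` below):
+  //
+  //   override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] =
+  //     if (formatFilterPushDownEnabled) {
+  //       StructFilters.pushedFilters(dataFilters, dataSchema)
+  //     } else {
+  //       Array.empty[Filter]
+  //     }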
+
private def createRequiredNameSet(): Set[String] =
requiredSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
- private val partitionNameSet: Set[String] =
+ val partitionNameSet: Set[String] =
partitionSchema.fields.map(PartitioningUtils.getColName(_, isCaseSensitive)).toSet
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
index acc645741819e..2adbd5cf007e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala
@@ -20,14 +20,11 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning}
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.expressions.FieldReference
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
-import org.apache.spark.sql.execution.datasources.PushableColumnWithoutNestedColumn
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.StructType
@@ -38,9 +35,8 @@ object PushDownUtils extends PredicateHelper {
*
* @return pushed filter and post-scan filters.
*/
- def pushFilters(
- scanBuilder: ScanBuilder,
- filters: Seq[Expression]): (Seq[sources.Filter], Seq[Expression]) = {
+ def pushFilters(scanBuilder: ScanBuilder, filters: Seq[Expression])
+ : (Either[Seq[sources.Filter], Seq[Predicate]], Seq[Expression]) = {
scanBuilder match {
case r: SupportsPushDownFilters =>
// A map from translated data source leaf node filters to original catalyst filter
@@ -69,41 +65,79 @@ object PushDownUtils extends PredicateHelper {
val postScanFilters = r.pushFilters(translatedFilters.toArray).map { filter =>
DataSourceStrategy.rebuildExpressionFromFilter(filter, translatedFilterToExpr)
}
- (r.pushedFilters(), (untranslatableExprs ++ postScanFilters).toSeq)
+ (Left(r.pushedFilters()), (untranslatableExprs ++ postScanFilters).toSeq)
+
+ case r: SupportsPushDownV2Filters =>
+ // A map from translated data source leaf node filters to original catalyst filter
+      // expressions. For an `And`/`Or` predicate, it is possible that the predicate is partially
+      // pushed down. This map can be used to construct a catalyst filter expression from the
+      // input predicate, or a superset (partial push-down predicate) of the input predicate.
+ val translatedFilterToExpr = mutable.HashMap.empty[Predicate, Expression]
+ val translatedFilters = mutable.ArrayBuffer.empty[Predicate]
+      // Catalyst filter expressions that can't be translated to data source predicates.
+ val untranslatableExprs = mutable.ArrayBuffer.empty[Expression]
+
+ for (filterExpr <- filters) {
+ val translated =
+ DataSourceV2Strategy.translateFilterV2WithMapping(
+ filterExpr, Some(translatedFilterToExpr), nestedPredicatePushdownEnabled = true)
+ if (translated.isEmpty) {
+ untranslatableExprs += filterExpr
+ } else {
+ translatedFilters += translated.get
+ }
+ }
+
+      // Data source filters that need to be evaluated again after scanning, which means
+      // the data source cannot guarantee the rows returned can pass these filters.
+      // As a result we must return them so Spark can plan an extra filter operator.
+ val postScanFilters = r.pushPredicates(translatedFilters.toArray).map { predicate =>
+ DataSourceV2Strategy.rebuildExpressionFromFilter(predicate, translatedFilterToExpr)
+ }
+ (Right(r.pushedPredicates), (untranslatableExprs ++ postScanFilters).toSeq)
- case _ => (Nil, filters)
+ case f: FileScanBuilder =>
+ val postScanFilters = f.pushFilters(filters)
+ (Left(f.pushedFilters), postScanFilters)
+
+ case _ => (Left(Nil), filters)
}
}
/**
- * Pushes down aggregates to the data source reader
- *
- * @return pushed aggregation.
+ * Pushes down TableSample to the data source Scan
*/
- def pushAggregates(
- scanBuilder: ScanBuilder,
- aggregates: Seq[AggregateExpression],
- groupBy: Seq[Expression]): Option[Aggregation] = {
-
- def columnAsString(e: Expression): Option[FieldReference] = e match {
- case PushableColumnWithoutNestedColumn(name) =>
- Some(FieldReference(name).asInstanceOf[FieldReference])
- case _ => None
+ def pushTableSample(scanBuilder: ScanBuilder, sample: TableSampleInfo): Boolean = {
+ scanBuilder match {
+ case s: SupportsPushDownTableSample =>
+ s.pushTableSample(
+ sample.lowerBound, sample.upperBound, sample.withReplacement, sample.seed)
+ case _ => false
}
+ }
+ /**
+ * Pushes down LIMIT to the data source Scan
+ */
+ def pushLimit(scanBuilder: ScanBuilder, limit: Int): Boolean = {
scanBuilder match {
- case r: SupportsPushDownAggregates if aggregates.nonEmpty =>
- val translatedAggregates = aggregates.flatMap(DataSourceStrategy.translateAggregate)
- val translatedGroupBys = groupBy.flatMap(columnAsString)
-
- if (translatedAggregates.length != aggregates.length ||
- translatedGroupBys.length != groupBy.length) {
- return None
- }
+ case s: SupportsPushDownLimit =>
+ s.pushLimit(limit)
+ case _ => false
+ }
+ }
- val agg = new Aggregation(translatedAggregates.toArray, translatedGroupBys.toArray)
- Some(agg).filter(r.pushAggregation)
- case _ => None
+ /**
+ * Pushes down top N to the data source Scan
+ */
+ def pushTopN(
+ scanBuilder: ScanBuilder,
+ order: Array[SortOrder],
+ limit: Int): (Boolean, Boolean) = {
+ scanBuilder match {
+ case s: SupportsPushDownTopN if s.pushTopN(order, limit) =>
+ (true, s.isPartiallyPushed)
+ case _ => (false, false)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
new file mode 100644
index 0000000000000..a95b4593fc397
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushedDownOperators.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+
+/**
+ * Pushed down operators
+ */
+case class PushedDownOperators(
+ aggregation: Option[Aggregation],
+ sample: Option[TableSampleInfo],
+ limit: Option[Int],
+ sortValues: Seq[SortOrder],
+ pushedPredicates: Seq[Predicate]) {
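+  // Sort orders are only produced by top-N push-down, so they must always be accompanied by
+  // a pushed limit.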
+ assert((limit.isEmpty && sortValues.isEmpty) || limit.isDefined)
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala
new file mode 100644
index 0000000000000..cb4fb9eb0809a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/TableSampleInfo.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+case class TableSampleInfo(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala
new file mode 100644
index 0000000000000..9fc220f440bc1
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ColumnUtils.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.connector.expressions.{Expression, NamedReference}
+
+object V2ColumnUtils {
+ def extractV2Column(expr: Expression): Option[String] = expr match {
+    case r: NamedReference if r.fieldNames.length == 1 => Some(r.fieldNames.head)
+ case _ => None
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
index 046155b55cc2d..cdcae15ef4e24 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala
@@ -19,24 +19,28 @@ package org.apache.spark.sql.execution.datasources.v2
import scala.collection.mutable
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, Expression, NamedExpression, PredicateHelper, ProjectionOverSchema, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, AliasHelper, And, Attribute, AttributeReference, Cast, Expression, IntegerLiteral, NamedExpression, PredicateHelper, ProjectionOverSchema, SortOrder, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.optimizer.CollapseProject
import org.apache.spark.sql.catalyst.planning.ScanOperation
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LeafNode, Limit, LocalLimit, LogicalPlan, Project, Sample, Sort}
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.expressions.{SortOrder => V2SortOrder}
+import org.apache.spark.sql.connector.expressions.aggregate.{Aggregation, Avg, Count, GeneralAggregateFunc, Sum}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, V1Scan}
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
import org.apache.spark.sql.sources
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
import org.apache.spark.sql.util.SchemaUtils._
-object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
+object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper with AliasHelper {
import DataSourceV2Implicits._
def apply(plan: LogicalPlan): LogicalPlan = {
- applyColumnPruning(pushDownAggregates(pushDownFilters(createScanBuilder(plan))))
+ applyColumnPruning(
+ applyLimit(pushDownAggregates(pushDownFilters(pushDownSample(createScanBuilder(plan))))))
}
private def createScanBuilder(plan: LogicalPlan) = plan.transform {
@@ -58,12 +62,19 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
// `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter.
val (pushedFilters, postScanFiltersWithoutSubquery) = PushDownUtils.pushFilters(
sHolder.builder, normalizedFiltersWithoutSubquery)
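+      // `pushedFilters` is Left(V1 source filters) for SupportsPushDownFilters / FileScanBuilder
+      // and Right(V2 predicates) for SupportsPushDownV2Filters. The V2 predicates are kept on
+      // the holder so they can be reported through the wrapped scan later.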
+ val pushedFiltersStr = if (pushedFilters.isLeft) {
+ pushedFilters.left.get.mkString(", ")
+ } else {
+ sHolder.pushedPredicates = pushedFilters.right.get
+ pushedFilters.right.get.mkString(", ")
+ }
+
val postScanFilters = postScanFiltersWithoutSubquery ++ normalizedFiltersWithSubquery
logInfo(
s"""
|Pushing operators to ${sHolder.relation.name}
- |Pushed Filters: ${pushedFilters.mkString(", ")}
+ |Pushed Filters: $pushedFiltersStr
|Post-Scan Filters: ${postScanFilters.mkString(",")}
""".stripMargin)
@@ -76,103 +87,168 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case aggNode @ Aggregate(groupingExpressions, resultExpressions, child) =>
child match {
case ScanOperation(project, filters, sHolder: ScanBuilderHolder)
- if filters.isEmpty && project.forall(_.isInstanceOf[AttributeReference]) =>
+ if filters.isEmpty && CollapseProject.canCollapseExpressions(
+ resultExpressions, project, alwaysInline = true) =>
sHolder.builder match {
- case _: SupportsPushDownAggregates =>
+ case r: SupportsPushDownAggregates =>
+ val aliasMap = getAliasMap(project)
+ val actualResultExprs = resultExpressions.map(replaceAliasButKeepName(_, aliasMap))
+ val actualGroupExprs = groupingExpressions.map(replaceAlias(_, aliasMap))
+
val aggExprToOutputOrdinal = mutable.HashMap.empty[Expression, Int]
- var ordinal = 0
- val aggregates = resultExpressions.flatMap { expr =>
- expr.collect {
- // Do not push down duplicated aggregate expressions. For example,
- // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
- // `max(a)` to the data source.
- case agg: AggregateExpression
- if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
- aggExprToOutputOrdinal(agg.canonicalized) = ordinal
- ordinal += 1
- agg
- }
- }
+ val aggregates = collectAggregates(actualResultExprs, aggExprToOutputOrdinal)
val normalizedAggregates = DataSourceStrategy.normalizeExprs(
aggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
val normalizedGroupingExpressions = DataSourceStrategy.normalizeExprs(
- groupingExpressions, sHolder.relation.output)
- val pushedAggregates = PushDownUtils.pushAggregates(
- sHolder.builder, normalizedAggregates, normalizedGroupingExpressions)
- if (pushedAggregates.isEmpty) {
+ actualGroupExprs, sHolder.relation.output)
+ val translatedAggregates = DataSourceStrategy.translateAggregation(
+ normalizedAggregates, normalizedGroupingExpressions)
+ val (finalResultExpressions, finalAggregates, finalTranslatedAggregates) = {
+ if (translatedAggregates.isEmpty ||
+ r.supportCompletePushDown(translatedAggregates.get) ||
+ translatedAggregates.get.aggregateExpressions().forall(!_.isInstanceOf[Avg])) {
+ (actualResultExprs, aggregates, translatedAggregates)
+ } else {
+ // scalastyle:off
+ // The data source doesn't support the complete push-down of this aggregation.
+ // Here we translate `AVG` to `SUM / COUNT`, so that it's more likely to be
+ // pushed, completely or partially.
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT avg(c1) FROM t GROUP BY c2;
+ // The original logical plan is
+ // Aggregate [c2#10],[avg(c1#9) AS avg(c1)#19]
+ // +- ScanOperation[...]
+ //
+ // After convert avg(c1#9) to sum(c1#9)/count(c1#9)
+ // we have the following
+ // Aggregate [c2#10],[sum(c1#9)/count(c1#9) AS avg(c1)#19]
+ // +- ScanOperation[...]
+ // scalastyle:on
+ val newResultExpressions = actualResultExprs.map { expr =>
+ expr.transform {
+ case AggregateExpression(avg: aggregate.Average, _, isDistinct, _, _) =>
+ val sum = aggregate.Sum(avg.child).toAggregateExpression(isDistinct)
+ val count = aggregate.Count(avg.child).toAggregateExpression(isDistinct)
+ avg.evaluateExpression transform {
+ case a: Attribute if a.semanticEquals(avg.sum) =>
+ addCastIfNeeded(sum, avg.sum.dataType)
+ case a: Attribute if a.semanticEquals(avg.count) =>
+ addCastIfNeeded(count, avg.count.dataType)
+ }
+ }
+ }.asInstanceOf[Seq[NamedExpression]]
+ // Because aggregate expressions changed, translate them again.
+ aggExprToOutputOrdinal.clear()
+ val newAggregates =
+ collectAggregates(newResultExpressions, aggExprToOutputOrdinal)
+ val newNormalizedAggregates = DataSourceStrategy.normalizeExprs(
+ newAggregates, sHolder.relation.output).asInstanceOf[Seq[AggregateExpression]]
+ (newResultExpressions, newAggregates, DataSourceStrategy.translateAggregation(
+ newNormalizedAggregates, normalizedGroupingExpressions))
+ }
+ }
+
+ if (finalTranslatedAggregates.isEmpty) {
+ aggNode // return original plan node
+ } else if (!r.supportCompletePushDown(finalTranslatedAggregates.get) &&
+ !supportPartialAggPushDown(finalTranslatedAggregates.get)) {
aggNode // return original plan node
} else {
- // No need to do column pruning because only the aggregate columns are used as
- // DataSourceV2ScanRelation output columns. All the other columns are not
- // included in the output.
- val scan = sHolder.builder.build()
-
- // scalastyle:off
- // use the group by columns and aggregate columns as the output columns
- // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
- // SELECT min(c1), max(c1) FROM t GROUP BY c2;
- // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation
- // We want to have the following logical plan:
- // == Optimized Logical Plan ==
- // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
- // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
- // scalastyle:on
- val newOutput = scan.readSchema().toAttributes
- assert(newOutput.length == groupingExpressions.length + aggregates.length)
- val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
- case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
- case (_, b) => b
- }
- val output = groupAttrs ++ newOutput.drop(groupAttrs.length)
-
- logInfo(
- s"""
- |Pushing operators to ${sHolder.relation.name}
- |Pushed Aggregate Functions:
- | ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
- |Pushed Group by:
- | ${pushedAggregates.get.groupByColumns.mkString(", ")}
- |Output: ${output.mkString(", ")}
+ val pushedAggregates = finalTranslatedAggregates.filter(r.pushAggregation)
+ if (pushedAggregates.isEmpty) {
+ aggNode // return original plan node
+ } else {
+ // No need to do column pruning because only the aggregate columns are used as
+ // DataSourceV2ScanRelation output columns. All the other columns are not
+ // included in the output.
+ val scan = sHolder.builder.build()
+
+ // scalastyle:off
+ // use the group by columns and aggregate columns as the output columns
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+ // Use c2, min(c1), max(c1) as output for DataSourceV2ScanRelation
+ // We want to have the following logical plan:
+ // == Optimized Logical Plan ==
+ // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22]
+ // scalastyle:on
+ val newOutput = scan.readSchema().toAttributes
+ assert(newOutput.length == groupingExpressions.length + finalAggregates.length)
+ val groupAttrs = normalizedGroupingExpressions.zip(newOutput).map {
+ case (a: Attribute, b: Attribute) => b.withExprId(a.exprId)
+ case (_, b) => b
+ }
+ val aggOutput = newOutput.drop(groupAttrs.length)
+ val output = groupAttrs ++ aggOutput
+
+ logInfo(
+ s"""
+ |Pushing operators to ${sHolder.relation.name}
+ |Pushed Aggregate Functions:
+ | ${pushedAggregates.get.aggregateExpressions.mkString(", ")}
+ |Pushed Group by:
+ | ${pushedAggregates.get.groupByColumns.mkString(", ")}
+ |Output: ${output.mkString(", ")}
""".stripMargin)
- val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
-
- val scanRelation = DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
-
- val plan = Aggregate(
- output.take(groupingExpressions.length), resultExpressions, scanRelation)
-
- // scalastyle:off
- // Change the optimized logical plan to reflect the pushed down aggregate
- // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
- // SELECT min(c1), max(c1) FROM t GROUP BY c2;
- // The original logical plan is
- // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
- // +- RelationV2[c1#9, c2#10] ...
- //
- // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
- // we have the following
- // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
- // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
- //
- // We want to change it to
- // == Optimized Logical Plan ==
- // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
- // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
- // scalastyle:on
- val aggOutput = output.drop(groupAttrs.length)
- plan.transformExpressions {
- case agg: AggregateExpression =>
- val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
- val aggFunction: aggregate.AggregateFunction =
- agg.aggregateFunction match {
- case max: aggregate.Max => max.copy(child = aggOutput(ordinal))
- case min: aggregate.Min => min.copy(child = aggOutput(ordinal))
- case sum: aggregate.Sum => sum.copy(child = aggOutput(ordinal))
- case _: aggregate.Count => aggregate.Sum(aggOutput(ordinal))
- case other => other
+ val wrappedScan = getWrappedScan(scan, sHolder, pushedAggregates)
+ val scanRelation =
+ DataSourceV2ScanRelation(sHolder.relation, wrappedScan, output)
+ if (r.supportCompletePushDown(pushedAggregates.get)) {
+ val projectExpressions = finalResultExpressions.map { expr =>
+                // TODO: at present, only group-by attributes are pushed down here. In the
+                // future, more attribute conversions (e.g. GetStructField) can be supported.
+ expr.transform {
+ case agg: AggregateExpression =>
+ val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+ val child =
+ addCastIfNeeded(aggOutput(ordinal), agg.resultAttribute.dataType)
+ Alias(child, agg.resultAttribute.name)(agg.resultAttribute.exprId)
}
- agg.copy(aggregateFunction = aggFunction)
+ }.asInstanceOf[Seq[NamedExpression]]
+ Project(projectExpressions, scanRelation)
+ } else {
+ val plan = Aggregate(output.take(groupingExpressions.length),
+ finalResultExpressions, scanRelation)
+
+ // scalastyle:off
+ // Change the optimized logical plan to reflect the pushed down aggregate
+ // e.g. TABLE t (c1 INT, c2 INT, c3 INT)
+ // SELECT min(c1), max(c1) FROM t GROUP BY c2;
+ // The original logical plan is
+ // Aggregate [c2#10],[min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+ // +- RelationV2[c1#9, c2#10] ...
+ //
+ // After change the V2ScanRelation output to [c2#10, min(c1)#21, max(c1)#22]
+ // we have the following
+ // !Aggregate [c2#10], [min(c1#9) AS min(c1)#17, max(c1#9) AS max(c1)#18]
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+ //
+ // We want to change it to
+ // == Optimized Logical Plan ==
+ // Aggregate [c2#10], [min(min(c1)#21) AS min(c1)#17, max(max(c1)#22) AS max(c1)#18]
+ // +- RelationV2[c2#10, min(c1)#21, max(c1)#22] ...
+ // scalastyle:on
+ plan.transformExpressions {
+ case agg: AggregateExpression =>
+ val ordinal = aggExprToOutputOrdinal(agg.canonicalized)
+ val aggAttribute = aggOutput(ordinal)
+ val aggFunction: aggregate.AggregateFunction =
+ agg.aggregateFunction match {
+ case max: aggregate.Max =>
+ max.copy(child = addCastIfNeeded(aggAttribute, max.child.dataType))
+ case min: aggregate.Min =>
+ min.copy(child = addCastIfNeeded(aggAttribute, min.child.dataType))
+ case sum: aggregate.Sum =>
+ sum.copy(child = addCastIfNeeded(aggAttribute, sum.child.dataType))
+ case _: aggregate.Count =>
+ aggregate.Sum(addCastIfNeeded(aggAttribute, LongType))
+ case other => other
+ }
+ agg.copy(aggregateFunction = aggFunction)
+ }
+ }
}
}
case _ => aggNode
@@ -181,6 +257,42 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
}
}
+ private def collectAggregates(resultExpressions: Seq[NamedExpression],
+ aggExprToOutputOrdinal: mutable.HashMap[Expression, Int]): Seq[AggregateExpression] = {
+ var ordinal = 0
+ resultExpressions.flatMap { expr =>
+ expr.collect {
+ // Do not push down duplicated aggregate expressions. For example,
+ // `SELECT max(a) + 1, max(a) + 2 FROM ...`, we should only push down one
+ // `max(a)` to the data source.
+ case agg: AggregateExpression
+ if !aggExprToOutputOrdinal.contains(agg.canonicalized) =>
+ aggExprToOutputOrdinal(agg.canonicalized) = ordinal
+ ordinal += 1
+ agg
+ }
+ }
+ }
+
+ private def supportPartialAggPushDown(agg: Aggregation): Boolean = {
+    // We don't know the agg buffer of `GeneralAggregateFunc`, so we can't do partial agg push-down.
+    // `Sum`, `Count` and `Avg` with DISTINCT can't be partially pushed down either.
+ agg.aggregateExpressions().exists {
+ case sum: Sum => !sum.isDistinct
+ case count: Count => !count.isDistinct
+ case avg: Avg => !avg.isDistinct
+ case _: GeneralAggregateFunc => false
+ case _ => true
+ }
+ }
+
+ private def addCastIfNeeded(expression: Expression, expectedDataType: DataType) =
+ if (expression.dataType == expectedDataType) {
+ expression
+ } else {
+ Cast(expression, expectedDataType)
+ }
+
def applyColumnPruning(plan: LogicalPlan): LogicalPlan = plan.transform {
case ScanOperation(project, filters, sHolder: ScanBuilderHolder) =>
// column pruning
@@ -219,6 +331,69 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
withProjection
}
+ def pushDownSample(plan: LogicalPlan): LogicalPlan = plan.transform {
+ case sample: Sample => sample.child match {
+ case ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
+ val tableSample = TableSampleInfo(
+ sample.lowerBound,
+ sample.upperBound,
+ sample.withReplacement,
+ sample.seed)
+ val pushed = PushDownUtils.pushTableSample(sHolder.builder, tableSample)
+ if (pushed) {
+ sHolder.pushedSample = Some(tableSample)
+ sample.child
+ } else {
+ sample
+ }
+
+ case _ => sample
+ }
+ }
+
+ private def pushDownLimit(plan: LogicalPlan, limit: Int): LogicalPlan = plan match {
+ case operation @ ScanOperation(_, filter, sHolder: ScanBuilderHolder) if filter.isEmpty =>
+ val limitPushed = PushDownUtils.pushLimit(sHolder.builder, limit)
+ if (limitPushed) {
+ sHolder.pushedLimit = Some(limit)
+ }
+ operation
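+    // Sort + Limit (top-N): if the source fully handles the top-N, the Sort node can be
+    // removed; if it is only partially pushed (e.g. per partition), the Sort is kept so that
+    // Spark still enforces the global ordering.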
+ case s @ Sort(order, _, operation @ ScanOperation(project, filter, sHolder: ScanBuilderHolder))
+ if filter.isEmpty && CollapseProject.canCollapseExpressions(
+ order, project, alwaysInline = true) =>
+ val aliasMap = getAliasMap(project)
+ val newOrder = order.map(replaceAlias(_, aliasMap)).asInstanceOf[Seq[SortOrder]]
+ val orders = DataSourceStrategy.translateSortOrders(newOrder)
+ if (orders.length == order.length) {
+ val (isPushed, isPartiallyPushed) =
+ PushDownUtils.pushTopN(sHolder.builder, orders.toArray, limit)
+ if (isPushed) {
+ sHolder.pushedLimit = Some(limit)
+ sHolder.sortOrders = orders
+ if (isPartiallyPushed) {
+ s
+ } else {
+ operation
+ }
+ } else {
+ s
+ }
+ } else {
+ s
+ }
+ case p: Project =>
+ val newChild = pushDownLimit(p.child, limit)
+ p.withNewChildren(Seq(newChild))
+ case other => other
+ }
+
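+  // `Limit(IntegerLiteral(n), child)` matches a GlobalLimit over a LocalLimit with the same
+  // literal, so both limit nodes are rebuilt on top of the (possibly rewritten) child. The
+  // limit operators are kept even when the push-down succeeds, since the source may only
+  // apply the limit to what it reads (e.g. per partition).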
+ def applyLimit(plan: LogicalPlan): LogicalPlan = plan.transform {
+ case globalLimit @ Limit(IntegerLiteral(limitValue), child) =>
+ val newChild = pushDownLimit(child, limitValue)
+ val newLocalLimit = globalLimit.child.asInstanceOf[LocalLimit].withNewChildren(Seq(newChild))
+ globalLimit.withNewChildren(Seq(newLocalLimit))
+ }
+
private def getWrappedScan(
scan: Scan,
sHolder: ScanBuilderHolder,
@@ -230,7 +405,9 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
f.pushedFilters()
case _ => Array.empty[sources.Filter]
}
- V1ScanWrapper(v1, pushedFilters, aggregation)
+ val pushedDownOperators = PushedDownOperators(aggregation, sHolder.pushedSample,
+ sHolder.pushedLimit, sHolder.sortOrders, sHolder.pushedPredicates)
+ V1ScanWrapper(v1, pushedFilters, pushedDownOperators)
case _ => scan
}
}
@@ -239,13 +416,22 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
case class ScanBuilderHolder(
output: Seq[AttributeReference],
relation: DataSourceV2Relation,
- builder: ScanBuilder) extends LeafNode
+ builder: ScanBuilder) extends LeafNode {
+ var pushedLimit: Option[Int] = None
+
+ var sortOrders: Seq[V2SortOrder] = Seq.empty[V2SortOrder]
+
+ var pushedSample: Option[TableSampleInfo] = None
+
+ var pushedPredicates: Seq[Predicate] = Seq.empty[Predicate]
+}
+
-// A wrapper for v1 scan to carry the translated filters and the handled ones. This is required by
-// the physical v1 scan node.
+// A wrapper for v1 scan to carry the translated filters and the handled ones, along with
+// other pushed down operators. This is required by the physical v1 scan node.
case class V1ScanWrapper(
v1Scan: V1Scan,
handledFilters: Seq[sources.Filter],
- pushedAggregate: Option[Aggregation]) extends Scan {
+ pushedDownOperators: PushedDownOperators) extends Scan {
override def readSchema(): StructType = v1Scan.readSchema()
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
index 33b8f22e3f88a..fe91cc486967b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalog.scala
@@ -261,12 +261,11 @@ class V2SessionCatalog(catalog: SessionCatalog)
}
}
- override def dropNamespace(namespace: Array[String]): Boolean = namespace match {
+ override def dropNamespace(
+ namespace: Array[String],
+ cascade: Boolean): Boolean = namespace match {
case Array(db) if catalog.databaseExists(db) =>
- if (catalog.listTables(db).nonEmpty) {
- throw QueryExecutionErrors.namespaceNotEmptyError(namespace)
- }
- catalog.dropDatabase(db, ignoreIfNotExists = false, cascade = false)
+ catalog.dropDatabase(db, ignoreIfNotExists = false, cascade)
true
case Array(_) =>
@@ -293,8 +292,8 @@ private[sql] object V2SessionCatalog {
case IdentityTransform(FieldReference(Seq(col))) =>
identityCols += col
- case BucketTransform(numBuckets, FieldReference(Seq(col))) =>
- bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, Nil))
+ case BucketTransform(numBuckets, FieldReference(Seq(col)), FieldReference(Seq(sortCol))) =>
+ bucketSpec = Some(BucketSpec(numBuckets, col :: Nil, sortCol :: Nil))
case transform =>
throw QueryExecutionErrors.unsupportedPartitionTransformError(transform)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
index 3f77b2147f9ca..cc3c146106670 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScan.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.csv.CSVDataSource
-import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan}
+import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -84,10 +84,6 @@ case class CSVScan(
dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters)
}
- override def withFilters(
- partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
- this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
-
override def equals(obj: Any): Boolean = obj match {
case c: CSVScan => super.equals(c) && dataSchema == c.dataSchema && options == c.options &&
equivalentFilters(pushedFilters, c.pushedFilters)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala
index f7a79bf31948e..2b6edd4f357ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/csv/CSVScanBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2.csv
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.StructFilters
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
+import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
@@ -32,7 +32,7 @@ case class CSVScanBuilder(
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap)
- extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters {
+ extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
override def build(): Scan = {
CSVScan(
@@ -42,17 +42,16 @@ case class CSVScanBuilder(
readDataSchema(),
readPartitionSchema(),
options,
- pushedFilters())
+ pushedDataFilters,
+ partitionFilters,
+ dataFilters)
}
- private var _pushedFilters: Array[Filter] = Array.empty
-
- override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.csvFilterPushDown) {
- _pushedFilters = StructFilters.pushedFilters(filters, dataSchema)
+ StructFilters.pushedFilters(dataFilters, dataSchema)
+ } else {
+ Array.empty[Filter]
}
- filters
}
-
- override def pushedFilters(): Array[Filter] = _pushedFilters
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
index ef42691e5ca94..f68f78d51fd96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScan.scala
@@ -18,17 +18,23 @@ package org.apache.spark.sql.execution.datasources.v2.jdbc
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.connector.expressions.SortOrder
+import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.V1Scan
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRelation
-import org.apache.spark.sql.sources.{BaseRelation, Filter, TableScan}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType
case class JDBCScan(
relation: JDBCRelation,
prunedSchema: StructType,
- pushedFilters: Array[Filter],
+ pushedPredicates: Array[Predicate],
pushedAggregateColumn: Array[String] = Array(),
- groupByColumns: Option[Array[String]]) extends V1Scan {
+ groupByColumns: Option[Array[String]],
+ tableSample: Option[TableSampleInfo],
+ pushedLimit: Int,
+ sortOrders: Array[SortOrder]) extends V1Scan {
override def readSchema(): StructType = prunedSchema
@@ -43,7 +49,8 @@ case class JDBCScan(
} else {
pushedAggregateColumn
}
- relation.buildScan(columnList, prunedSchema, pushedFilters, groupByColumns)
+ relation.buildScan(columnList, prunedSchema, pushedPredicates, groupByColumns, tableSample,
+ pushedLimit, sortOrders)
}
}.asInstanceOf[T]
}
@@ -57,7 +64,7 @@ case class JDBCScan(
("[]", "[]")
}
super.description() + ", prunedSchema: " + seqToString(prunedSchema) +
- ", PushedFilters: " + seqToString(pushedFilters) +
+ ", PushedPredicates: " + seqToString(pushedPredicates) +
", PushedAggregates: " + aggString + ", PushedGroupBy: " + groupByString
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
index b0de7c015c91a..0a1542a42956d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCScanBuilder.scala
@@ -20,12 +20,14 @@ import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connector.expressions.SortOrder
import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
-import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownFilters, SupportsPushDownRequiredColumns}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownAggregates, SupportsPushDownLimit, SupportsPushDownRequiredColumns, SupportsPushDownTableSample, SupportsPushDownTopN, SupportsPushDownV2Filters}
import org.apache.spark.sql.execution.datasources.PartitioningUtils
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCRDD, JDBCRelation}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.jdbc.JdbcDialects
-import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
case class JDBCScanBuilder(
@@ -33,40 +35,56 @@ case class JDBCScanBuilder(
schema: StructType,
jdbcOptions: JDBCOptions)
extends ScanBuilder
- with SupportsPushDownFilters
+ with SupportsPushDownV2Filters
with SupportsPushDownRequiredColumns
with SupportsPushDownAggregates
+ with SupportsPushDownLimit
+ with SupportsPushDownTableSample
+ with SupportsPushDownTopN
with Logging {
private val isCaseSensitive = session.sessionState.conf.caseSensitiveAnalysis
- private var pushedFilter = Array.empty[Filter]
+ private var pushedPredicate = Array.empty[Predicate]
private var finalSchema = schema
- override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ private var tableSample: Option[TableSampleInfo] = None
+
+ private var pushedLimit = 0
+
+ private var sortOrders: Array[SortOrder] = Array.empty[SortOrder]
+
+ override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
if (jdbcOptions.pushDownPredicate) {
val dialect = JdbcDialects.get(jdbcOptions.url)
- val (pushed, unSupported) = filters.partition(JDBCRDD.compileFilter(_, dialect).isDefined)
- this.pushedFilter = pushed
+ val (pushed, unSupported) = predicates.partition(dialect.compileExpression(_).isDefined)
+ this.pushedPredicate = pushed
unSupported
} else {
- filters
+ predicates
}
}
- override def pushedFilters(): Array[Filter] = pushedFilter
+ override def pushedPredicates(): Array[Predicate] = pushedPredicate
private var pushedAggregateList: Array[String] = Array()
private var pushedGroupByCols: Option[Array[String]] = None
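+  // Complete aggregate push-down is only safe when the JDBC relation is read as a single
+  // partition, or when there is exactly one group-by column and it equals the partition
+  // column (each group then lives entirely in one partition). `fieldNames` is lazy so it is
+  // only evaluated when a group-by column exists.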
+ override def supportCompletePushDown(aggregation: Aggregation): Boolean = {
+ lazy val fieldNames = aggregation.groupByColumns()(0).fieldNames()
+ jdbcOptions.numPartitions.map(_ == 1).getOrElse(true) ||
+ (aggregation.groupByColumns().length == 1 && fieldNames.length == 1 &&
+ jdbcOptions.partitionColumn.exists(fieldNames(0).equalsIgnoreCase(_)))
+ }
+
override def pushAggregation(aggregation: Aggregation): Boolean = {
if (!jdbcOptions.pushDownAggregate) return false
val dialect = JdbcDialects.get(jdbcOptions.url)
- val compiledAgg = JDBCRDD.compileAggregates(aggregation.aggregateExpressions, dialect)
- if (compiledAgg.isEmpty) return false
+ val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileAggregate)
+ if (compiledAggs.length != aggregation.aggregateExpressions.length) return false
val groupByCols = aggregation.groupByColumns.map { col =>
if (col.fieldNames.length != 1) return false
@@ -77,7 +95,7 @@ case class JDBCScanBuilder(
// e.g. "DEPT","NAME",MAX("SALARY"),MIN("BONUS") =>
// SELECT "DEPT","NAME",MAX("SALARY"),MIN("BONUS") FROM "test"."employee"
// GROUP BY "DEPT", "NAME"
- val selectList = groupByCols ++ compiledAgg.get
+ val selectList = groupByCols ++ compiledAggs
val groupByClause = if (groupByCols.isEmpty) {
""
} else {
@@ -98,6 +116,38 @@ case class JDBCScanBuilder(
}
}
+ override def pushTableSample(
+ lowerBound: Double,
+ upperBound: Double,
+ withReplacement: Boolean,
+ seed: Long): Boolean = {
+ if (jdbcOptions.pushDownTableSample &&
+ JdbcDialects.get(jdbcOptions.url).supportsTableSample) {
+ this.tableSample = Some(TableSampleInfo(lowerBound, upperBound, withReplacement, seed))
+ return true
+ }
+ false
+ }
+
+ override def pushLimit(limit: Int): Boolean = {
+ if (jdbcOptions.pushDownLimit) {
+ pushedLimit = limit
+ return true
+ }
+ false
+ }
+
+ override def pushTopN(orders: Array[SortOrder], limit: Int): Boolean = {
+ if (jdbcOptions.pushDownLimit) {
+ pushedLimit = limit
+ sortOrders = orders
+ return true
+ }
+ false
+ }
+
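+  // With more than one JDBC partition, LIMIT / top-N is only applied per partition by the
+  // source, so the push-down is partial and Spark must still apply the final limit and sort.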
+ override def isPartiallyPushed(): Boolean = jdbcOptions.numPartitions.map(_ > 1).getOrElse(false)
+
override def pruneColumns(requiredSchema: StructType): Unit = {
// JDBC doesn't support nested column pruning.
// TODO (SPARK-32593): JDBC support nested column and nested column pruning.
@@ -122,7 +172,7 @@ case class JDBCScanBuilder(
// "DEPT","NAME",MAX("SALARY"),MIN("BONUS"), instead of getting column names from
// prunedSchema and quote them (will become "MAX(SALARY)", "MIN(BONUS)" and can't
// be used in sql string.
- JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedFilter,
- pushedAggregateList, pushedGroupByCols)
+ JDBCScan(JDBCRelation(schema, parts, jdbcOptions)(session), finalSchema, pushedPredicate,
+ pushedAggregateList, pushedGroupByCols, tableSample, pushedLimit, sortOrders)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala
index 5e11ea66be4c6..793b72727b9ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTable.scala
@@ -23,13 +23,16 @@ import scala.collection.JavaConverters._
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog._
import org.apache.spark.sql.connector.catalog.TableCapability._
+import org.apache.spark.sql.connector.catalog.index.{SupportsIndex, TableIndex}
+import org.apache.spark.sql.connector.expressions.NamedReference
import org.apache.spark.sql.connector.write.{LogicalWriteInfo, WriteBuilder}
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcOptionsInWrite, JdbcUtils}
+import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOptions)
- extends Table with SupportsRead with SupportsWrite {
+ extends Table with SupportsRead with SupportsWrite with SupportsIndex {
override def name(): String = ident.toString
@@ -48,4 +51,39 @@ case class JDBCTable(ident: Identifier, schema: StructType, jdbcOptions: JDBCOpt
jdbcOptions.parameters.originalMap ++ info.options.asCaseSensitiveMap().asScala)
JDBCWriteBuilder(schema, mergedOptions)
}
+
+ override def createIndex(
+ indexName: String,
+ columns: Array[NamedReference],
+ columnsProperties: util.Map[NamedReference, util.Map[String, String]],
+ properties: util.Map[String, String]): Unit = {
+ JdbcUtils.withConnection(jdbcOptions) { conn =>
+ JdbcUtils.classifyException(s"Failed to create index $indexName in $name",
+ JdbcDialects.get(jdbcOptions.url)) {
+ JdbcUtils.createIndex(
+ conn, indexName, name, columns, columnsProperties, properties, jdbcOptions)
+ }
+ }
+ }
+
+ override def indexExists(indexName: String): Boolean = {
+ JdbcUtils.withConnection(jdbcOptions) { conn =>
+ JdbcUtils.indexExists(conn, indexName, name, jdbcOptions)
+ }
+ }
+
+ override def dropIndex(indexName: String): Unit = {
+ JdbcUtils.withConnection(jdbcOptions) { conn =>
+ JdbcUtils.classifyException(s"Failed to drop index: $indexName",
+ JdbcDialects.get(jdbcOptions.url)) {
+ JdbcUtils.dropIndex(conn, indexName, name, jdbcOptions)
+ }
+ }
+ }
+
+ override def listIndexes(): Array[TableIndex] = {
+ JdbcUtils.withConnection(jdbcOptions) { conn =>
+ JdbcUtils.listIndexes(conn, name, jdbcOptions)
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
index a90ab564ddb50..03200d5a6f371 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCTableCatalog.scala
@@ -16,12 +16,11 @@
*/
package org.apache.spark.sql.execution.datasources.v2.jdbc
-import java.sql.{Connection, SQLException}
+import java.sql.SQLException
import java.util
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.collection.mutable.ArrayBuilder
import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.catalog.{Identifier, NamespaceChange, SupportsNamespaces, Table, TableCatalog, TableChange}
@@ -57,7 +56,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
override def listTables(namespace: Array[String]): Array[Identifier] = {
checkNamespace(namespace)
- withConnection { conn =>
+ JdbcUtils.withConnection(options) { conn =>
val schemaPattern = if (namespace.length == 1) namespace.head else null
val rs = conn.getMetaData
.getTables(null, schemaPattern, "%", Array("TABLE"));
@@ -72,14 +71,14 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
checkNamespace(ident.namespace())
val writeOptions = new JdbcOptionsInWrite(
options.parameters + (JDBCOptions.JDBC_TABLE_NAME -> getTableName(ident)))
- classifyException(s"Failed table existence check: $ident") {
- withConnection(JdbcUtils.tableExists(_, writeOptions))
+ JdbcUtils.classifyException(s"Failed table existence check: $ident", dialect) {
+ JdbcUtils.withConnection(options)(JdbcUtils.tableExists(_, writeOptions))
}
}
override def dropTable(ident: Identifier): Boolean = {
checkNamespace(ident.namespace())
- withConnection { conn =>
+ JdbcUtils.withConnection(options) { conn =>
try {
JdbcUtils.dropTable(conn, getTableName(ident), options)
true
@@ -91,8 +90,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = {
checkNamespace(oldIdent.namespace())
- withConnection { conn =>
- classifyException(s"Failed table renaming from $oldIdent to $newIdent") {
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.classifyException(s"Failed table renaming from $oldIdent to $newIdent", dialect) {
JdbcUtils.renameTable(conn, getTableName(oldIdent), getTableName(newIdent), options)
}
}
@@ -151,8 +150,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
val writeOptions = new JdbcOptionsInWrite(tableOptions)
val caseSensitive = SQLConf.get.caseSensitiveAnalysis
- withConnection { conn =>
- classifyException(s"Failed table creation: $ident") {
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.classifyException(s"Failed table creation: $ident", dialect) {
JdbcUtils.createTable(conn, getTableName(ident), schema, caseSensitive, writeOptions)
}
}
@@ -162,8 +161,8 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
override def alterTable(ident: Identifier, changes: TableChange*): Table = {
checkNamespace(ident.namespace())
- withConnection { conn =>
- classifyException(s"Failed table altering: $ident") {
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.classifyException(s"Failed table altering: $ident", dialect) {
JdbcUtils.alterTable(conn, getTableName(ident), changes, options)
}
loadTable(ident)
@@ -172,24 +171,15 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
override def namespaceExists(namespace: Array[String]): Boolean = namespace match {
case Array(db) =>
- withConnection { conn =>
- val rs = conn.getMetaData.getSchemas(null, db)
- while (rs.next()) {
- if (rs.getString(1) == db) return true;
- }
- false
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.schemaExists(conn, options, db)
}
case _ => false
}
override def listNamespaces(): Array[Array[String]] = {
- withConnection { conn =>
- val schemaBuilder = ArrayBuilder.make[Array[String]]
- val rs = conn.getMetaData.getSchemas()
- while (rs.next()) {
- schemaBuilder += Array(rs.getString(1))
- }
- schemaBuilder.result
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.listSchemas(conn, options)
}
}
@@ -234,9 +224,9 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
}
}
}
- withConnection { conn =>
- classifyException(s"Failed create name space: $db") {
- JdbcUtils.createNamespace(conn, options, db, comment)
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.classifyException(s"Failed create name space: $db", dialect) {
+ JdbcUtils.createSchema(conn, options, db, comment)
}
}
@@ -253,8 +243,10 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
changes.foreach {
case set: NamespaceChange.SetProperty =>
if (set.property() == SupportsNamespaces.PROP_COMMENT) {
- withConnection { conn =>
- JdbcUtils.createNamespaceComment(conn, options, db, set.value)
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.classifyException(s"Failed create comment on name space: $db", dialect) {
+ JdbcUtils.alterSchemaComment(conn, options, db, set.value)
+ }
}
} else {
throw QueryCompilationErrors.cannotSetJDBCNamespaceWithPropertyError(set.property)
@@ -262,8 +254,10 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
case unset: NamespaceChange.RemoveProperty =>
if (unset.property() == SupportsNamespaces.PROP_COMMENT) {
- withConnection { conn =>
- JdbcUtils.removeNamespaceComment(conn, options, db)
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.classifyException(s"Failed remove comment on name space: $db", dialect) {
+ JdbcUtils.removeSchemaComment(conn, options, db)
+ }
}
} else {
throw QueryCompilationErrors.cannotUnsetJDBCNamespaceWithPropertyError(unset.property)
@@ -278,14 +272,13 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
}
}
- override def dropNamespace(namespace: Array[String]): Boolean = namespace match {
+ override def dropNamespace(
+ namespace: Array[String],
+ cascade: Boolean): Boolean = namespace match {
case Array(db) if namespaceExists(namespace) =>
- if (listTables(Array(db)).nonEmpty) {
- throw QueryExecutionErrors.namespaceNotEmptyError(namespace)
- }
- withConnection { conn =>
- classifyException(s"Failed drop name space: $db") {
- JdbcUtils.dropNamespace(conn, options, db)
+ JdbcUtils.withConnection(options) { conn =>
+ JdbcUtils.classifyException(s"Failed drop name space: $db", dialect) {
+ JdbcUtils.dropSchema(conn, options, db, cascade)
true
}
}
@@ -301,24 +294,7 @@ class JDBCTableCatalog extends TableCatalog with SupportsNamespaces with Logging
}
}
- private def withConnection[T](f: Connection => T): T = {
- val conn = JdbcUtils.createConnectionFactory(options)()
- try {
- f(conn)
- } finally {
- conn.close()
- }
- }
-
private def getTableName(ident: Identifier): String = {
(ident.namespace() :+ ident.name()).map(dialect.quoteIdentifier).mkString(".")
}
-
- private def classifyException[T](message: String)(f: => T): T = {
- try {
- f
- } catch {
- case e: Throwable => throw dialect.classifyException(message, e)
- }
- }
}
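The two private helpers removed above now live in JdbcUtils, shared with the dialect APIs. A minimal, self-contained sketch of the same connection-loan and exception-classification pattern (the names and the plain RuntimeException wrapper are illustrative, not the exact JdbcUtils/AnalysisException signatures):

    import java.sql.{Connection, DriverManager}

    object ConnectionHelpersSketch {
      // Borrow a connection, run the body, always close the connection.
      def withConnection[T](url: String)(f: Connection => T): T = {
        val conn = DriverManager.getConnection(url)
        try f(conn) finally conn.close()
      }

      // Wrap any failure with a descriptive message; a real dialect would map
      // vendor error codes to typed AnalysisExceptions here.
      def classifyException[T](message: String)(f: => T): T = {
        try f catch {
          case e: Throwable => throw new RuntimeException(message, e)
        }
      }
    }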
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
index 0e6c72c2cc331..7449f66ee020f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/jdbc/JDBCWriteBuilder.scala
@@ -20,6 +20,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.execution.datasources.jdbc.{JdbcOptionsInWrite, JdbcUtils}
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.jdbc.JdbcDialects
import org.apache.spark.sql.sources.InsertableRelation
import org.apache.spark.sql.types.StructType
@@ -37,7 +38,8 @@ case class JDBCWriteBuilder(schema: StructType, options: JdbcOptionsInWrite) ext
override def toInsertableRelation: InsertableRelation = (data: DataFrame, _: Boolean) => {
// TODO (SPARK-32595): do truncate and append atomically.
if (isTruncate) {
- val conn = JdbcUtils.createConnectionFactory(options)()
+ val dialect = JdbcDialects.get(options.url)
+ val conn = dialect.createConnectionFactory(options)(-1)
JdbcUtils.truncateTable(conn, options)
}
JdbcUtils.saveTable(data, Some(schema), SQLConf.get.caseSensitiveAnalysis, options)
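The connection for the truncate path above now comes from the dialect's createConnectionFactory, called with -1 because this code runs on the driver (see the scaladoc added to JdbcDialects.scala later in this patch). A hedged sketch of such an Int => Connection factory; the H2 URL and the routing logic are assumptions for illustration:

    import java.sql.{Connection, DriverManager}

    object ConnectionFactorySketch {
      // The Int argument is the RDD partition id; -1 means the connection is
      // created on the driver. A sharded dialect could route on partitionId;
      // this sketch ignores it, like the default implementation.
      def connectionFactory(url: String): Int => Connection =
        (partitionId: Int) => DriverManager.getConnection(url)

      def main(args: Array[String]): Unit = {
        // Requires an H2 driver on the classpath.
        val factory = connectionFactory("jdbc:h2:mem:demo")
        val driverSideConn = factory(-1)   // driver-side use, as in JDBCWriteBuilder
        driverSideConn.close()
      }
    }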
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
index 29eb8bec9a589..9ab367136fc97 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScan.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.json.JsonDataSource
-import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan}
+import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -83,10 +83,6 @@ case class JsonScan(
dataSchema, readDataSchema, readPartitionSchema, parsedOptions, pushedFilters)
}
- override def withFilters(
- partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
- this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
-
override def equals(obj: Any): Boolean = obj match {
case j: JsonScan => super.equals(j) && dataSchema == j.dataSchema && options == j.options &&
equivalentFilters(pushedFilters, j.pushedFilters)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala
index cf1204566ddbd..c581617a4b7e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/json/JsonScanBuilder.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.json
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.StructFilters
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
+import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.sources.Filter
@@ -31,7 +31,7 @@ class JsonScanBuilder (
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap)
- extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters {
+ extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
override def build(): Scan = {
JsonScan(
sparkSession,
@@ -40,17 +40,16 @@ class JsonScanBuilder (
readDataSchema(),
readPartitionSchema(),
options,
- pushedFilters())
+ pushedDataFilters,
+ partitionFilters,
+ dataFilters)
}
- private var _pushedFilters: Array[Filter] = Array.empty
-
- override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.jsonFilterPushDown) {
- _pushedFilters = StructFilters.pushedFilters(filters, dataSchema)
+ StructFilters.pushedFilters(dataFilters, dataSchema)
+ } else {
+ Array.empty[Filter]
}
- filters
}
-
- override def pushedFilters(): Array[Filter] = _pushedFilters
}
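JsonScanBuilder now implements the pushDataFilters hook from FileScanBuilder instead of SupportsPushDownFilters: the builder returns only the data filters it can evaluate, and Spark keeps re-evaluating the rest. A self-contained sketch of that contract using the public sources.Filter API (the schema and filter instances are illustrative):

    import org.apache.spark.sql.sources.{EqualTo, Filter, IsNotNull}
    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

    object PushDataFiltersSketch {
      private val dataSchema = StructType(Seq(StructField("a", IntegerType)))

      // Keep only filters whose referenced columns exist in the data schema,
      // mirroring what StructFilters.pushedFilters does for the JSON source.
      def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] =
        dataFilters.filter(_.references.forall(dataSchema.fieldNames.contains))

      def main(args: Array[String]): Unit = {
        val pushed = pushDataFilters(Array(EqualTo("a", 1), IsNotNull("a"), EqualTo("b", 2)))
        println(pushed.mkString(", "))   // EqualTo(a,1), IsNotNull(a)
      }
    }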
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 414252cc12481..79c34827c0bec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -23,14 +23,15 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.{JobID, TaskAttemptID, TaskID, TaskType}
import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
-import org.apache.orc.{OrcConf, OrcFile, TypeDescription}
+import org.apache.orc.{OrcConf, OrcFile, Reader, TypeDescription}
import org.apache.orc.mapred.OrcStruct
import org.apache.orc.mapreduce.OrcInputFormat
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
-import org.apache.spark.sql.execution.datasources.PartitionedFile
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitionedFile}
import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
@@ -54,7 +55,8 @@ case class OrcPartitionReaderFactory(
dataSchema: StructType,
readDataSchema: StructType,
partitionSchema: StructType,
- filters: Array[Filter]) extends FilePartitionReaderFactory {
+ filters: Array[Filter],
+ aggregation: Option[Aggregation]) extends FilePartitionReaderFactory {
private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val capacity = sqlConf.orcVectorizedReaderBatchSize
@@ -79,17 +81,14 @@ case class OrcPartitionReaderFactory(
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
val conf = broadcastedConf.value.value
-
- OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
-
val filePath = new Path(new URI(file.filePath))
- pushDownPredicates(filePath, conf)
+ if (aggregation.nonEmpty) {
+ return buildReaderWithAggregates(filePath, conf)
+ }
- val fs = filePath.getFileSystem(conf)
- val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
- Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader =>
+ Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
OrcUtils.requestedColumnIds(
isCaseSensitive, dataSchema, readDataSchema, reader, conf)
}
@@ -126,17 +125,14 @@ case class OrcPartitionReaderFactory(
override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
val conf = broadcastedConf.value.value
-
- OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
-
val filePath = new Path(new URI(file.filePath))
- pushDownPredicates(filePath, conf)
+ if (aggregation.nonEmpty) {
+ return buildColumnarReaderWithAggregates(filePath, conf)
+ }
- val fs = filePath.getFileSystem(conf)
- val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
val resultedColPruneInfo =
- Utils.tryWithResource(OrcFile.createReader(filePath, readerOptions)) { reader =>
+ Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
OrcUtils.requestedColumnIds(
isCaseSensitive, dataSchema, readDataSchema, reader, conf)
}
@@ -171,4 +167,67 @@ case class OrcPartitionReaderFactory(
}
}
+ private def createORCReader(filePath: Path, conf: Configuration): Reader = {
+ OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(conf, isCaseSensitive)
+
+ pushDownPredicates(filePath, conf)
+
+ val fs = filePath.getFileSystem(conf)
+ val readerOptions = OrcFile.readerOptions(conf).filesystem(fs)
+ OrcFile.createReader(filePath, readerOptions)
+ }
+
+ /**
+ * Build reader with aggregate push down.
+ */
+ private def buildReaderWithAggregates(
+ filePath: Path,
+ conf: Configuration): PartitionReader[InternalRow] = {
+ new PartitionReader[InternalRow] {
+ private var hasNext = true
+ private lazy val row: InternalRow = {
+ Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
+ OrcUtils.createAggInternalRowFromFooter(
+ reader, filePath.toString, dataSchema, partitionSchema, aggregation.get, readDataSchema)
+ }
+ }
+
+ override def next(): Boolean = hasNext
+
+ override def get(): InternalRow = {
+ hasNext = false
+ row
+ }
+
+ override def close(): Unit = {}
+ }
+ }
+
+ /**
+ * Build columnar reader with aggregate push down.
+ */
+ private def buildColumnarReaderWithAggregates(
+ filePath: Path,
+ conf: Configuration): PartitionReader[ColumnarBatch] = {
+ new PartitionReader[ColumnarBatch] {
+ private var hasNext = true
+ private lazy val batch: ColumnarBatch = {
+ Utils.tryWithResource(createORCReader(filePath, conf)) { reader =>
+ val row = OrcUtils.createAggInternalRowFromFooter(
+ reader, filePath.toString, dataSchema, partitionSchema, aggregation.get,
+ readDataSchema)
+ AggregatePushDownUtils.convertAggregatesRowToBatch(row, readDataSchema, offHeap = false)
+ }
+ }
+
+ override def next(): Boolean = hasNext
+
+ override def get(): ColumnarBatch = {
+ hasNext = false
+ batch
+ }
+
+ override def close(): Unit = {}
+ }
+ }
}
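Both new readers above follow the same one-shot pattern: a single row (or batch) is computed lazily from the footer, and next() flips to false after the first get(). A generic sketch of that idiom against the public PartitionReader interface:

    import org.apache.spark.sql.connector.read.PartitionReader

    // Serves exactly one value, computed lazily on first access, then reports exhaustion.
    class SingleValueReader[T](compute: => T) extends PartitionReader[T] {
      private var hasNextValue = true
      private lazy val value: T = compute   // e.g. an InternalRow built from ORC footer stats

      override def next(): Boolean = hasNextValue

      override def get(): T = {
        hasNextValue = false
        value
      }

      override def close(): Unit = {}
    }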
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 8fa7f8dc41ead..6b9d181a7f4c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -21,8 +21,9 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.PartitionReaderFactory
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType
@@ -37,10 +38,25 @@ case class OrcScan(
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap,
+ pushedAggregate: Option[Aggregation] = None,
pushedFilters: Array[Filter],
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
- override def isSplitable(path: Path): Boolean = true
+ override def isSplitable(path: Path): Boolean = {
+ // If aggregate is pushed down, only the file footer will be read once,
+    // so the file should not be split across multiple tasks.
+ pushedAggregate.isEmpty
+ }
+
+ override def readSchema(): StructType = {
+    // If aggregate is pushed down, the schema has already been pruned in `OrcScanBuilder`,
+    // so there is no need to call super.readSchema().
+ if (pushedAggregate.nonEmpty) {
+ readDataSchema
+ } else {
+ super.readSchema()
+ }
+ }
override def createReaderFactory(): PartitionReaderFactory = {
val broadcastedConf = sparkSession.sparkContext.broadcast(
@@ -48,28 +64,39 @@ case class OrcScan(
// The partition values are already truncated in `FileScan.partitions`.
// We should use `readPartitionSchema` as the partition schema here.
OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
- dataSchema, readDataSchema, readPartitionSchema, pushedFilters)
+ dataSchema, readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate)
}
override def equals(obj: Any): Boolean = obj match {
case o: OrcScan =>
+ val pushedDownAggEqual = if (pushedAggregate.nonEmpty && o.pushedAggregate.nonEmpty) {
+ AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, o.pushedAggregate.get)
+ } else {
+ pushedAggregate.isEmpty && o.pushedAggregate.isEmpty
+ }
super.equals(o) && dataSchema == o.dataSchema && options == o.options &&
- equivalentFilters(pushedFilters, o.pushedFilters)
-
+ equivalentFilters(pushedFilters, o.pushedFilters) && pushedDownAggEqual
case _ => false
}
override def hashCode(): Int = getClass.hashCode()
+ lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
+ (seqToString(pushedAggregate.get.aggregateExpressions),
+ seqToString(pushedAggregate.get.groupByColumns))
+ } else {
+ ("[]", "[]")
+ }
+
override def description(): String = {
- super.description() + ", PushedFilters: " + seqToString(pushedFilters)
+ super.description() + ", PushedFilters: " + seqToString(pushedFilters) +
+ ", PushedAggregation: " + pushedAggregationsStr +
+ ", PushedGroupBy: " + pushedGroupByStr
}
override def getMetaData(): Map[String, String] = {
- super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
+ super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++
+ Map("PushedAggregation" -> pushedAggregationsStr) ++
+ Map("PushedGroupBy" -> pushedGroupByStr)
}
-
- override def withFilters(
- partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
- this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
index dc59526bb316b..d2c17fda4a382 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2.orc
import scala.collection.JavaConverters._
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.orc.OrcFilters
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.internal.SQLConf
@@ -35,30 +36,59 @@ case class OrcScanBuilder(
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap)
- extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters {
+ extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
+ with SupportsPushDownAggregates {
+
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap)
}
+ private var finalSchema = new StructType()
+
+ private var pushedAggregations = Option.empty[Aggregation]
+
override protected val supportsNestedSchemaPruning: Boolean = true
override def build(): Scan = {
- OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema,
- readDataSchema(), readPartitionSchema(), options, pushedFilters())
+    // The `finalSchema` is either pruned in pushAggregation (if aggregates are
+    // pushed down) or in readDataSchema() (regular column pruning). The two
+    // are mutually exclusive.
+ if (pushedAggregations.isEmpty) {
+ finalSchema = readDataSchema()
+ }
+ OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
+ readPartitionSchema(), options, pushedAggregations, pushedDataFilters, partitionFilters,
+ dataFilters)
}
- private var _pushedFilters: Array[Filter] = Array.empty
-
- override def pushFilters(filters: Array[Filter]): Array[Filter] = {
+ override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = {
if (sparkSession.sessionState.conf.orcFilterPushDown) {
val dataTypeMap = OrcFilters.getSearchableTypeMap(
readDataSchema(), SQLConf.get.caseSensitiveAnalysis)
- _pushedFilters = OrcFilters.convertibleFilters(dataTypeMap, filters).toArray
+ OrcFilters.convertibleFilters(dataTypeMap, dataFilters).toArray
+ } else {
+ Array.empty[Filter]
}
- filters
}
- override def pushedFilters(): Array[Filter] = _pushedFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+ if (!sparkSession.sessionState.conf.orcAggregatePushDown) {
+ return false
+ }
+
+ AggregatePushDownUtils.getSchemaForPushedAggregation(
+ aggregation,
+ schema,
+ partitionNameSet,
+ dataFilters) match {
+
+ case Some(schema) =>
+ finalSchema = schema
+ this.pushedAggregations = Some(aggregation)
+ true
+ case _ => false
+ }
+ }
}
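A usage sketch for the new pushAggregation path. It assumes the `orcAggregatePushDown` flag checked above is exposed as the SQL config spark.sql.orc.aggregatePushDown and that the v2 ORC source is in use (see spark.sql.sources.useV1SourceList); paths and data are illustrative:

    import org.apache.spark.sql.SparkSession

    object OrcAggPushDownExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .config("spark.sql.orc.aggregatePushDown", "true")   // assumed key for orcAggregatePushDown
          .config("spark.sql.sources.useV1SourceList", "")     // so the v2 ORC scan is used
          .getOrCreate()

        spark.range(1000).toDF("id").write.mode("overwrite").orc("/tmp/orc_agg_demo")

        // MIN/MAX/COUNT without grouping on data columns can be answered from file
        // footer statistics; the plan should show a non-empty "PushedAggregation".
        val df = spark.read.orc("/tmp/orc_agg_demo").selectExpr("max(id)", "count(id)")
        df.explain()
        df.show()

        spark.stop()
      }
    }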
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 058669b0937fa..6f021ff2e97f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -25,16 +25,18 @@ import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
import org.apache.parquet.filter2.compat.FilterCompat
import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
-import org.apache.parquet.format.converter.ParquetMetadataConverter.SKIP_ROW_GROUPS
+import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS}
import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
+import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
-import org.apache.spark.sql.execution.datasources.{DataSourceUtils, PartitionedFile, RecordReaderIterator}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile, RecordReaderIterator}
import org.apache.spark.sql.execution.datasources.parquet._
import org.apache.spark.sql.execution.datasources.v2._
import org.apache.spark.sql.internal.SQLConf
@@ -53,6 +55,7 @@ import org.apache.spark.util.SerializableConfiguration
* @param readDataSchema Required schema of Parquet files.
* @param partitionSchema Schema of partitions.
* @param filters Filters to be pushed down in the batch scan.
+ * @param aggregation Aggregation to be pushed down in the batch scan.
* @param parquetOptions The options of Parquet datasource that are set for the read.
*/
case class ParquetPartitionReaderFactory(
@@ -62,6 +65,7 @@ case class ParquetPartitionReaderFactory(
readDataSchema: StructType,
partitionSchema: StructType,
filters: Array[Filter],
+ aggregation: Option[Aggregation],
parquetOptions: ParquetOptions) extends FilePartitionReaderFactory with Logging {
private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
private val resultSchema = StructType(partitionSchema.fields ++ readDataSchema.fields)
@@ -80,6 +84,30 @@ case class ParquetPartitionReaderFactory(
private val datetimeRebaseModeInRead = parquetOptions.datetimeRebaseModeInRead
private val int96RebaseModeInRead = parquetOptions.int96RebaseModeInRead
+ private def getFooter(file: PartitionedFile): ParquetMetadata = {
+ val conf = broadcastedConf.value.value
+ val filePath = new Path(new URI(file.filePath))
+
+ if (aggregation.isEmpty) {
+ ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS)
+ } else {
+ // For aggregate push down, we will get max/min/count from footer statistics.
+ // We want to read the footer for the whole file instead of reading multiple
+ // footers for every split of the file. Basically if the start (the beginning of)
+ // the offset in PartitionedFile is 0, we will read the footer. Otherwise, it means
+ // that we have already read footer for that file, so we will skip reading again.
+ if (file.start != 0) return null
+ ParquetFooterReader.readFooter(conf, filePath, NO_FILTER)
+ }
+ }
+
+ private def getDatetimeRebaseMode(
+ footerFileMetaData: FileMetaData): LegacyBehaviorPolicy.Value = {
+ DataSourceUtils.datetimeRebaseMode(
+ footerFileMetaData.getKeyValueMetaData.get,
+ datetimeRebaseModeInRead)
+ }
+
override def supportColumnarReads(partition: InputPartition): Boolean = {
sqlConf.parquetVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
@@ -87,18 +115,44 @@ case class ParquetPartitionReaderFactory(
}
override def buildReader(file: PartitionedFile): PartitionReader[InternalRow] = {
- val reader = if (enableVectorizedReader) {
- createVectorizedReader(file)
- } else {
- createRowBaseReader(file)
- }
+ val fileReader = if (aggregation.isEmpty) {
+ val reader = if (enableVectorizedReader) {
+ createVectorizedReader(file)
+ } else {
+ createRowBaseReader(file)
+ }
+
+ new PartitionReader[InternalRow] {
+ override def next(): Boolean = reader.nextKeyValue()
- val fileReader = new PartitionReader[InternalRow] {
- override def next(): Boolean = reader.nextKeyValue()
+ override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
- override def get(): InternalRow = reader.getCurrentValue.asInstanceOf[InternalRow]
+ override def close(): Unit = reader.close()
+ }
+ } else {
+ new PartitionReader[InternalRow] {
+ private var hasNext = true
+ private lazy val row: InternalRow = {
+ val footer = getFooter(file)
+ if (footer != null && footer.getBlocks.size > 0) {
+ ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath, dataSchema,
+ partitionSchema, aggregation.get, readDataSchema,
+ getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+ } else {
+ null
+ }
+ }
+ override def next(): Boolean = {
+ hasNext && row != null
+ }
- override def close(): Unit = reader.close()
+ override def get(): InternalRow = {
+ hasNext = false
+ row
+ }
+
+ override def close(): Unit = {}
+ }
}
new PartitionReaderWithPartitionValues(fileReader, readDataSchema,
@@ -106,17 +160,47 @@ case class ParquetPartitionReaderFactory(
}
override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
- val vectorizedReader = createVectorizedReader(file)
- vectorizedReader.enableReturningBatches()
+ val fileReader = if (aggregation.isEmpty) {
+ val vectorizedReader = createVectorizedReader(file)
+ vectorizedReader.enableReturningBatches()
+
+ new PartitionReader[ColumnarBatch] {
+ override def next(): Boolean = vectorizedReader.nextKeyValue()
- new PartitionReader[ColumnarBatch] {
- override def next(): Boolean = vectorizedReader.nextKeyValue()
+ override def get(): ColumnarBatch =
+ vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
- override def get(): ColumnarBatch =
- vectorizedReader.getCurrentValue.asInstanceOf[ColumnarBatch]
+ override def close(): Unit = vectorizedReader.close()
+ }
+ } else {
+ new PartitionReader[ColumnarBatch] {
+ private var hasNext = true
+ private val batch: ColumnarBatch = {
+ val footer = getFooter(file)
+ if (footer != null && footer.getBlocks.size > 0) {
+ val row = ParquetUtils.createAggInternalRowFromFooter(footer, file.filePath,
+ dataSchema, partitionSchema, aggregation.get, readDataSchema,
+ getDatetimeRebaseMode(footer.getFileMetaData), isCaseSensitive)
+ AggregatePushDownUtils.convertAggregatesRowToBatch(
+ row, readDataSchema, enableOffHeapColumnVector && Option(TaskContext.get()).isDefined)
+ } else {
+ null
+ }
+ }
+
+ override def next(): Boolean = {
+ hasNext && batch != null
+ }
+
+ override def get(): ColumnarBatch = {
+ hasNext = false
+ batch
+ }
- override def close(): Unit = vectorizedReader.close()
+ override def close(): Unit = {}
+ }
}
+ fileReader
}
private def buildReaderBase[T](
@@ -131,11 +215,8 @@ case class ParquetPartitionReaderFactory(
val filePath = new Path(new URI(file.filePath))
val split = new FileSplit(filePath, file.start, file.length, Array.empty[String])
- lazy val footerFileMetaData =
- ParquetFooterReader.readFooter(conf, filePath, SKIP_ROW_GROUPS).getFileMetaData
- val datetimeRebaseMode = DataSourceUtils.datetimeRebaseMode(
- footerFileMetaData.getKeyValueMetaData.get,
- datetimeRebaseModeInRead)
+ lazy val footerFileMetaData = getFooter(file).getFileMetaData
+ val datetimeRebaseMode = getDatetimeRebaseMode(footerFileMetaData)
// Try to push down filters when filter push-down is enabled.
val pushed = if (enableParquetFilterPushDown) {
val parquetSchema = footerFileMetaData.getSchema
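The getFooter comment above describes a per-file rather than per-split footer read. The rule in isolation, as a small sketch (PartitionedFileLike is a stand-in for Spark's PartitionedFile):

    object FooterReadRuleSketch {
      // Stand-in for org.apache.spark.sql.execution.datasources.PartitionedFile.
      final case class PartitionedFileLike(filePath: String, start: Long, length: Long)

      // Without aggregate push down each split reads its footer metadata (SKIP_ROW_GROUPS).
      // With aggregate push down only the split starting at byte 0 reads the whole-file
      // footer (NO_FILTER); the remaining splits of that file produce no aggregate row.
      def shouldReadFooter(file: PartitionedFileLike, aggregatePushedDown: Boolean): Boolean =
        !aggregatePushedDown || file.start == 0

      def main(args: Array[String]): Unit = {
        val splits = Seq(
          PartitionedFileLike("f.parquet", start = 0, length = 128),
          PartitionedFileLike("f.parquet", start = 128, length = 128))
        println(splits.map(shouldReadFooter(_, aggregatePushedDown = true)))   // List(true, false)
      }
    }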
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
index 60573ba10ccb6..b92ed82190ae8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScan.scala
@@ -24,8 +24,9 @@ import org.apache.parquet.hadoop.ParquetInputFormat
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
import org.apache.spark.sql.connector.read.PartitionReaderFactory
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetReadSupport, ParquetWriteSupport}
import org.apache.spark.sql.execution.datasources.v2.FileScan
import org.apache.spark.sql.internal.SQLConf
@@ -43,10 +44,17 @@ case class ParquetScan(
readPartitionSchema: StructType,
pushedFilters: Array[Filter],
options: CaseInsensitiveStringMap,
+ pushedAggregate: Option[Aggregation] = None,
partitionFilters: Seq[Expression] = Seq.empty,
dataFilters: Seq[Expression] = Seq.empty) extends FileScan {
override def isSplitable(path: Path): Boolean = true
+ override def readSchema(): StructType = {
+    // If aggregate is pushed down, the schema has already been pruned in `ParquetScanBuilder`,
+    // so there is no need to call super.readSchema().
+ if (pushedAggregate.nonEmpty) readDataSchema else super.readSchema()
+ }
+
override def createReaderFactory(): PartitionReaderFactory = {
val readDataSchemaAsJson = readDataSchema.json
hadoopConf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[ParquetReadSupport].getName)
@@ -86,27 +94,40 @@ case class ParquetScan(
readDataSchema,
readPartitionSchema,
pushedFilters,
+ pushedAggregate,
new ParquetOptions(options.asCaseSensitiveMap.asScala.toMap, sqlConf))
}
override def equals(obj: Any): Boolean = obj match {
case p: ParquetScan =>
+ val pushedDownAggEqual = if (pushedAggregate.nonEmpty && p.pushedAggregate.nonEmpty) {
+ AggregatePushDownUtils.equivalentAggregations(pushedAggregate.get, p.pushedAggregate.get)
+ } else {
+ pushedAggregate.isEmpty && p.pushedAggregate.isEmpty
+ }
super.equals(p) && dataSchema == p.dataSchema && options == p.options &&
- equivalentFilters(pushedFilters, p.pushedFilters)
+ equivalentFilters(pushedFilters, p.pushedFilters) && pushedDownAggEqual
case _ => false
}
override def hashCode(): Int = getClass.hashCode()
+ lazy private val (pushedAggregationsStr, pushedGroupByStr) = if (pushedAggregate.nonEmpty) {
+ (seqToString(pushedAggregate.get.aggregateExpressions),
+ seqToString(pushedAggregate.get.groupByColumns))
+ } else {
+ ("[]", "[]")
+ }
+
override def description(): String = {
- super.description() + ", PushedFilters: " + seqToString(pushedFilters)
+ super.description() + ", PushedFilters: " + seqToString(pushedFilters) +
+ ", PushedAggregation: " + pushedAggregationsStr +
+ ", PushedGroupBy: " + pushedGroupByStr
}
override def getMetaData(): Map[String, String] = {
- super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters))
+ super.getMetaData() ++ Map("PushedFilters" -> seqToString(pushedFilters)) ++
+ Map("PushedAggregation" -> pushedAggregationsStr) ++
+ Map("PushedGroupBy" -> pushedGroupByStr)
}
-
- override def withFilters(
- partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
- this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
index 4b3f4e7edca6c..d198321eacdb6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetScanBuilder.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.datasources.v2.parquet
import scala.collection.JavaConverters._
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownFilters}
-import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
+import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownAggregates}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, PartitioningAwareFileIndex}
import org.apache.spark.sql.execution.datasources.parquet.{ParquetFilters, SparkToParquetSchemaConverter}
import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder
import org.apache.spark.sql.internal.SQLConf.LegacyBehaviorPolicy
@@ -35,7 +36,8 @@ case class ParquetScanBuilder(
schema: StructType,
dataSchema: StructType,
options: CaseInsensitiveStringMap)
- extends FileScanBuilder(sparkSession, fileIndex, dataSchema) with SupportsPushDownFilters {
+ extends FileScanBuilder(sparkSession, fileIndex, dataSchema)
+    with SupportsPushDownAggregates {
lazy val hadoopConf = {
val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap
// Hadoop Configurations are case sensitive.
@@ -63,25 +65,50 @@ case class ParquetScanBuilder(
// The rebase mode doesn't matter here because the filters are used to determine
// whether they are convertible.
LegacyBehaviorPolicy.CORRECTED)
- parquetFilters.convertibleFilters(this.filters).toArray
+ parquetFilters.convertibleFilters(pushedDataFilters).toArray
}
- override protected val supportsNestedSchemaPruning: Boolean = true
+ private var finalSchema = new StructType()
- private var filters: Array[Filter] = Array.empty
+ private var pushedAggregations = Option.empty[Aggregation]
- override def pushFilters(filters: Array[Filter]): Array[Filter] = {
- this.filters = filters
- this.filters
- }
+ override protected val supportsNestedSchemaPruning: Boolean = true
+
+ override def pushDataFilters(dataFilters: Array[Filter]): Array[Filter] = dataFilters
// Note: for Parquet, the actual filter push down happens in [[ParquetPartitionReaderFactory]].
// It requires the Parquet physical schema to determine whether a filter is convertible.
// All filters that can be converted to Parquet are pushed down.
override def pushedFilters(): Array[Filter] = pushedParquetFilters
+ override def pushAggregation(aggregation: Aggregation): Boolean = {
+ if (!sparkSession.sessionState.conf.parquetAggregatePushDown) {
+ return false
+ }
+
+ AggregatePushDownUtils.getSchemaForPushedAggregation(
+ aggregation,
+ schema,
+ partitionNameSet,
+ dataFilters) match {
+
+ case Some(schema) =>
+ finalSchema = schema
+ this.pushedAggregations = Some(aggregation)
+ true
+ case _ => false
+ }
+ }
+
override def build(): Scan = {
- ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, readDataSchema(),
- readPartitionSchema(), pushedParquetFilters, options)
+    // The `finalSchema` is either pruned in pushAggregation (if aggregates are
+    // pushed down) or in readDataSchema() (regular column pruning). The two
+    // are mutually exclusive.
+ if (pushedAggregations.isEmpty) {
+ finalSchema = readDataSchema()
+ }
+ ParquetScan(sparkSession, hadoopConf, fileIndex, dataSchema, finalSchema,
+ readPartitionSchema(), pushedParquetFilters, options, pushedAggregations,
+ partitionFilters, dataFilters)
}
}
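The Parquet counterpart mirrors the ORC example earlier; grouping by a partition column is also eligible because getSchemaForPushedAggregation receives partitionNameSet. Again the config key spark.sql.parquet.aggregatePushDown is assumed to back parquetAggregatePushDown, and paths and data are illustrative:

    import org.apache.spark.sql.SparkSession

    object ParquetAggPushDownExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          .config("spark.sql.parquet.aggregatePushDown", "true")  // assumed key for parquetAggregatePushDown
          .config("spark.sql.sources.useV1SourceList", "")        // so the v2 Parquet scan is used
          .getOrCreate()
        import spark.implicits._

        Seq((1, "a"), (2, "a"), (3, "b")).toDF("id", "p")
          .write.mode("overwrite").partitionBy("p").parquet("/tmp/parquet_agg_demo")

        // MAX/COUNT grouped by the partition column "p" can be served from footers;
        // look for "PushedAggregation" and "PushedGroupBy" in the plan.
        val df = spark.read.parquet("/tmp/parquet_agg_demo")
          .groupBy("p").agg("id" -> "max", "id" -> "count")
        df.explain()
        df.show()

        spark.stop()
      }
    }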
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
index e75de2c4a4079..c7b0fec34b4e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScan.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.connector.read.PartitionReaderFactory
import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex
import org.apache.spark.sql.execution.datasources.text.TextOptions
-import org.apache.spark.sql.execution.datasources.v2.{FileScan, TextBasedFileScan}
+import org.apache.spark.sql.execution.datasources.v2.TextBasedFileScan
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.SerializableConfiguration
@@ -33,6 +33,7 @@ import org.apache.spark.util.SerializableConfiguration
case class TextScan(
sparkSession: SparkSession,
fileIndex: PartitioningAwareFileIndex,
+ dataSchema: StructType,
readDataSchema: StructType,
readPartitionSchema: StructType,
options: CaseInsensitiveStringMap,
@@ -71,10 +72,6 @@ case class TextScan(
readPartitionSchema, textOptions)
}
- override def withFilters(
- partitionFilters: Seq[Expression], dataFilters: Seq[Expression]): FileScan =
- this.copy(partitionFilters = partitionFilters, dataFilters = dataFilters)
-
override def equals(obj: Any): Boolean = obj match {
case t: TextScan => super.equals(t) && options == t.options
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala
index b2b518c12b01a..0ebb098bfc1df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/text/TextScanBuilder.scala
@@ -33,6 +33,7 @@ case class TextScanBuilder(
extends FileScanBuilder(sparkSession, fileIndex, dataSchema) {
override def build(): Scan = {
- TextScan(sparkSession, fileIndex, readDataSchema(), readPartitionSchema(), options)
+ TextScan(sparkSession, fileIndex, dataSchema, readDataSchema(), readPartitionSchema(), options,
+ partitionFilters, dataFilters)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
index 0b394db5c8932..9bf25aa0d633f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala
@@ -17,9 +17,12 @@
package org.apache.spark.sql.jdbc
-import java.sql.Types
+import java.sql.{SQLException, Types}
import java.util.Locale
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.types._
private object DB2Dialect extends JdbcDialect {
@@ -27,6 +30,37 @@ private object DB2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:db2")
+ // See https://www.ibm.com/docs/en/db2/11.5?topic=functions-aggregate
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VARIANCE($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VARIANCE_SAMP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_SAMP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
+ assert(f.children().length == 2)
+ Some(s"COVARIANCE(${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 2)
+ Some(s"COVARIANCE_SAMP(${f.children().head}, ${f.children().last})")
+ case _ => None
+ }
+ )
+ }
+
override def getCatalystType(
sqlType: Int,
typeName: String,
@@ -79,4 +113,28 @@ private object DB2Dialect extends JdbcDialect {
val nullable = if (isNullable) "DROP NOT NULL" else "SET NOT NULL"
s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable"
}
+
+ override def removeSchemaCommentQuery(schema: String): String = {
+ s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS ''"
+ }
+
+ override def classifyException(message: String, e: Throwable): AnalysisException = {
+ e match {
+ case sqlException: SQLException =>
+ sqlException.getSQLState match {
+ // https://www.ibm.com/docs/en/db2/11.5?topic=messages-sqlstate
+ case "42893" => throw NonEmptyNamespaceException(message, cause = Some(e))
+ case _ => super.classifyException(message, e)
+ }
+ case _ => super.classifyException(message, e)
+ }
+ }
+
+ override def dropSchema(schema: String, cascade: Boolean): String = {
+ if (cascade) {
+ s"DROP SCHEMA ${quoteIdentifier(schema)} CASCADE"
+ } else {
+ s"DROP SCHEMA ${quoteIdentifier(schema)} RESTRICT"
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
index 020733aaee8c0..36c3c6be4a05c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DerbyDialect.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.jdbc
import java.sql.Types
import java.util.Locale
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types._
@@ -29,6 +30,27 @@ private object DerbyDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:derby")
+ // See https://db.apache.org/derby/docs/10.15/ref/index.html
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"VAR_POP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"VAR_SAMP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"STDDEV_POP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"STDDEV_SAMP(${f.children().head})")
+ case _ => None
+ }
+ )
+ }
+
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.REAL) Option(FloatType) else None
@@ -47,7 +69,7 @@ private object DerbyDialect extends JdbcDialect {
override def isCascadingTruncateTable(): Option[Boolean] = Some(false)
- // See https://db.apache.org/derby/docs/10.5/ref/rrefsqljrenametablestatement.html
+ // See https://db.apache.org/derby/docs/10.15/ref/rrefsqljrenametablestatement.html
override def renameTable(oldTable: String, newTable: String): String = {
s"RENAME TABLE $oldTable TO $newTable"
}
@@ -57,4 +79,8 @@ private object DerbyDialect extends JdbcDialect {
override def getTableCommentQuery(table: String, comment: String): String = {
throw QueryExecutionErrors.commentOnTableUnsupportedError()
}
+
+ override def getLimitClause(limit: Integer): String = {
+ ""
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
index 9c727957ffab8..6681aee778dbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala
@@ -20,13 +20,76 @@ package org.apache.spark.sql.jdbc
import java.sql.SQLException
import java.util.Locale
+import scala.util.control.NonFatal
+
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{NoSuchNamespaceException, NoSuchTableException, TableAlreadyExistsException}
+import org.apache.spark.sql.connector.expressions.Expression
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
+import org.apache.spark.sql.errors.QueryCompilationErrors
private object H2Dialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:h2")
+ class H2SQLBuilder extends JDBCSQLBuilder {
+ override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
+ funcName match {
+ case "WIDTH_BUCKET" =>
+ val functionInfo = super.visitSQLFunction(funcName, inputs)
+ throw QueryCompilationErrors.noSuchFunctionError("H2", functionInfo)
+ case _ => super.visitSQLFunction(funcName, inputs)
+ }
+ }
+ }
+
+ override def compileExpression(expr: Expression): Option[String] = {
+ val h2SQLBuilder = new H2SQLBuilder()
+ try {
+ Some(h2SQLBuilder.build(expr))
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error occurs while compiling V2 expression", e)
+ None
+ }
+ }
+
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR_POP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR_SAMP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_POP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_SAMP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
+ assert(f.children().length == 2)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" =>
+ assert(f.children().length == 2)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "CORR" =>
+ assert(f.children().length == 2)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"CORR($distinct${f.children().head}, ${f.children().last})")
+ case _ => None
+ }
+ )
+ }
+
override def classifyException(message: String, e: Throwable): AnalysisException = {
if (e.isInstanceOf[SQLException]) {
// Error codes are from https://www.h2database.com/javadoc/org/h2/api/ErrorCode.html
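H2SQLBuilder above shows the intended extension point: a dialect subclasses the SQL builder and throws for anything the target engine cannot run, which makes compileExpression return None so the predicate stays on the Spark side. A hedged sketch of that pattern (the rejected function is just an example):

    import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder

    // Rejects one SQL function during expression compilation; compileExpression's
    // NonFatal handler turns the throw into None, so Spark keeps the predicate.
    class RejectingSQLBuilder extends V2ExpressionSQLBuilder {
      override def visitSQLFunction(funcName: String, inputs: Array[String]): String = {
        if (funcName == "WIDTH_BUCKET") {
          throw new UnsupportedOperationException(s"$funcName is not supported by this engine")
        }
        super.visitSQLFunction(funcName, inputs)
      }
    }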
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
index aa957113b5ca5..397942d7837db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
@@ -17,21 +17,30 @@
package org.apache.spark.sql.jdbc
-import java.sql.{Connection, Date, Timestamp}
+import java.sql.{Connection, Date, Driver, Statement, Timestamp}
import java.time.{Instant, LocalDate}
+import java.util
import scala.collection.mutable.ArrayBuilder
+import scala.util.control.NonFatal
import org.apache.commons.lang3.StringUtils
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter}
import org.apache.spark.sql.connector.catalog.TableChange
import org.apache.spark.sql.connector.catalog.TableChange._
+import org.apache.spark.sql.connector.catalog.index.TableIndex
+import org.apache.spark.sql.connector.expressions.{Expression, Literal, NamedReference}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, Max, Min, Sum}
+import org.apache.spark.sql.connector.util.V2ExpressionSQLBuilder
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.datasources.jdbc.{DriverRegistry, JDBCOptions, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -94,6 +103,29 @@ abstract class JdbcDialect extends Serializable with Logging{
*/
def getJDBCType(dt: DataType): Option[JdbcType] = None
+ /**
+ * Returns a factory for creating connections to the given JDBC URL.
+ * In general, creating a connection has nothing to do with JDBC partition id.
+ * But sometimes it is needed, such as a database with multiple shard nodes.
+ * @param options - JDBC options that contains url, table and other information.
+ * @return The factory method for creating JDBC connections with the RDD partition ID. -1 means
+   *         the connection is being created at the driver side.
+ * @throws IllegalArgumentException if the driver could not open a JDBC connection.
+ */
+ @Since("3.3.0")
+ def createConnectionFactory(options: JDBCOptions): Int => Connection = {
+ val driverClass: String = options.driverClass
+ (partitionId: Int) => {
+ DriverRegistry.register(driverClass)
+ val driver: Driver = DriverRegistry.get(driverClass)
+ val connection =
+ ConnectionProvider.create(driver, options.parameters)
+ require(connection != null,
+ s"The driver could not open a JDBC connection. Check the URL: ${options.url}")
+ connection
+ }
+ }
+
/**
* Quotes the identifier. This is used to put quotes around the identifier in case the column
* name is a reserved keyword, or in case it contains characters that require quotes (e.g. space).
@@ -189,6 +221,110 @@ abstract class JdbcDialect extends Serializable with Logging{
case _ => value
}
+ class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
+ override def visitLiteral(literal: Literal[_]): String = {
+ compileValue(
+ CatalystTypeConverters.convertToScala(literal.value(), literal.dataType())).toString
+ }
+
+ override def visitNamedReference(namedRef: NamedReference): String = {
+ if (namedRef.fieldNames().length > 1) {
+ throw QueryCompilationErrors.commandNotSupportNestedColumnError(
+ "Filter push down", namedRef.toString)
+ }
+ quoteIdentifier(namedRef.fieldNames.head)
+ }
+
+ override def visitCast(l: String, dataType: DataType): String = {
+ val databaseTypeDefinition =
+ getJDBCType(dataType).map(_.databaseTypeDefinition).getOrElse(dataType.typeName)
+ s"CAST($l AS $databaseTypeDefinition)"
+ }
+ }
+
+ /**
+ * Converts V2 expression to String representing a SQL expression.
+ * @param expr The V2 expression to be converted.
+ * @return Converted value.
+ */
+ @Since("3.3.0")
+ def compileExpression(expr: Expression): Option[String] = {
+ val jdbcSQLBuilder = new JDBCSQLBuilder()
+ try {
+ Some(jdbcSQLBuilder.build(expr))
+ } catch {
+ case NonFatal(e) =>
+ logWarning("Error occurs while compiling V2 expression", e)
+ None
+ }
+ }
+
+ /**
+ * Converts aggregate function to String representing a SQL expression.
+ * @param aggFunction The aggregate function to be converted.
+ * @return Converted value.
+ */
+ @Since("3.3.0")
+ def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ aggFunction match {
+ case min: Min =>
+ compileExpression(min.column).map(v => s"MIN($v)")
+ case max: Max =>
+ compileExpression(max.column).map(v => s"MAX($v)")
+ case count: Count =>
+ val distinct = if (count.isDistinct) "DISTINCT " else ""
+ compileExpression(count.column).map(v => s"COUNT($distinct$v)")
+ case sum: Sum =>
+ val distinct = if (sum.isDistinct) "DISTINCT " else ""
+ compileExpression(sum.column).map(v => s"SUM($distinct$v)")
+ case _: CountStar =>
+ Some("COUNT(*)")
+ case avg: Avg =>
+ val distinct = if (avg.isDistinct) "DISTINCT " else ""
+ compileExpression(avg.column).map(v => s"AVG($distinct$v)")
+ case _ => None
+ }
+ }
+
+ /**
+ * Create schema with an optional comment. Empty string means no comment.
+ */
+ def createSchema(statement: Statement, schema: String, comment: String): Unit = {
+ val schemaCommentQuery = if (comment.nonEmpty) {
+      // Generate the comment query first so it can fail before the schema is created.
+ getSchemaCommentQuery(schema, comment)
+ } else {
+ comment
+ }
+ statement.executeUpdate(s"CREATE SCHEMA ${quoteIdentifier(schema)}")
+ if (comment.nonEmpty) {
+ statement.executeUpdate(schemaCommentQuery)
+ }
+ }
+
+ /**
+ * Check schema exists or not.
+ */
+ def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = {
+ val rs = conn.getMetaData.getSchemas(null, schema)
+ while (rs.next()) {
+ if (rs.getString(1) == schema) return true;
+ }
+ false
+ }
+
+ /**
+   * Lists all the schemas in this database.
+ */
+ def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = {
+ val schemaBuilder = ArrayBuilder.make[Array[String]]
+ val rs = conn.getMetaData.getSchemas()
+ while (rs.next()) {
+ schemaBuilder += Array(rs.getString(1))
+ }
+ schemaBuilder.result
+ }
+
/**
* Return Some[true] iff `TRUNCATE TABLE` causes cascading default.
* Some[true] : TRUNCATE TABLE causes cascading.
@@ -287,6 +423,71 @@ abstract class JdbcDialect extends Serializable with Logging{
s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS NULL"
}
+ def dropSchema(schema: String, cascade: Boolean): String = {
+ if (cascade) {
+ s"DROP SCHEMA ${quoteIdentifier(schema)} CASCADE"
+ } else {
+ s"DROP SCHEMA ${quoteIdentifier(schema)}"
+ }
+ }
+
+ /**
+ * Build a create index SQL statement.
+ *
+ * @param indexName the name of the index to be created
+ * @param tableName the table on which index to be created
+ * @param columns the columns on which index to be created
+ * @param columnsProperties the properties of the columns on which index to be created
+ * @param properties the properties of the index to be created
+ * @return the SQL statement to use for creating the index.
+ */
+ def createIndex(
+ indexName: String,
+ tableName: String,
+ columns: Array[NamedReference],
+ columnsProperties: util.Map[NamedReference, util.Map[String, String]],
+ properties: util.Map[String, String]): String = {
+ throw new UnsupportedOperationException("createIndex is not supported")
+ }
+
+ /**
+ * Checks whether an index exists
+ *
+ * @param indexName the name of the index
+ * @param tableName the table name on which index to be checked
+ * @param options JDBCOptions of the table
+ * @return true if the index with `indexName` exists in the table with `tableName`,
+ * false otherwise
+ */
+ def indexExists(
+ conn: Connection,
+ indexName: String,
+ tableName: String,
+ options: JDBCOptions): Boolean = {
+ throw new UnsupportedOperationException("indexExists is not supported")
+ }
+
+ /**
+ * Build a drop index SQL statement.
+ *
+ * @param indexName the name of the index to be dropped.
+ * @param tableName the table name on which index to be dropped.
+ * @return the SQL statement to use for dropping the index.
+ */
+ def dropIndex(indexName: String, tableName: String): String = {
+ throw new UnsupportedOperationException("dropIndex is not supported")
+ }
+
+ /**
+ * Lists all the indexes in this table.
+ */
+ def listIndexes(
+ conn: Connection,
+ tableName: String,
+ options: JDBCOptions): Array[TableIndex] = {
+ throw new UnsupportedOperationException("listIndexes is not supported")
+ }
+
/**
* Gets a dialect exception, classifies it and wraps it by `AnalysisException`.
* @param message The error message to be placed to the returned exception.
@@ -296,6 +497,18 @@ abstract class JdbcDialect extends Serializable with Logging{
def classifyException(message: String, e: Throwable): AnalysisException = {
new AnalysisException(message, cause = Some(e))
}
+
+ /**
+   * Returns the LIMIT clause for the SELECT statement.
+   */
+  def getLimitClause(limit: Integer): String = {
+    if (limit > 0) s"LIMIT $limit" else ""
+ }
+
+ def supportsTableSample: Boolean = false
+
+ def getTableSample(sample: TableSampleInfo): String =
+ throw new UnsupportedOperationException("TableSample is not supported by this data source")
}
/**
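
The JdbcDialect changes above add overridable hooks for schema DDL, index DDL, LIMIT push-down and TABLESAMPLE, each with a conservative default (either a generic SQL string or an UnsupportedOperationException). Below is a minimal, hypothetical sketch of a third-party dialect built on those hooks; `MyDbDialect`, the `jdbc:mydb` prefix and the `FETCH FIRST` syntax are illustrative assumptions, not part of this patch.

```scala
// Hypothetical third-party dialect relying only on the hooks added in this patch.
import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects}

case object MyDbDialect extends JdbcDialect {
  // canHandle is the only abstract member of JdbcDialect.
  override def canHandle(url: String): Boolean = url.startsWith("jdbc:mydb")

  // The base implementation returns "LIMIT n"; a dialect with different syntax
  // overrides it, and one with no trailing limit clause returns "" to skip push-down.
  override def getLimitClause(limit: Integer): String =
    if (limit > 0) s"FETCH FIRST $limit ROWS ONLY" else ""

  // Index DDL (createIndex/dropIndex/indexExists/listIndexes) keeps the default
  // UnsupportedOperationException unless the database actually supports indexes.
}

// Registration uses the existing public JdbcDialects API.
JdbcDialects.registerDialect(MyDbDialect)
```
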
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
index ea9834830e373..8d2fbec55f919 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MsSqlServerDialect.scala
@@ -17,8 +17,12 @@
package org.apache.spark.sql.jdbc
+import java.sql.SQLException
import java.util.Locale
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.NonEmptyNamespaceException
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -36,6 +40,33 @@ private object MsSqlServerDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:sqlserver")
+ // scalastyle:off line.size.limit
+ // See https://docs.microsoft.com/en-us/sql/t-sql/functions/aggregate-functions-transact-sql?view=sql-server-ver15
+ // scalastyle:on line.size.limit
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VARP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDEVP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDEV($distinct${f.children().head})")
+ case _ => None
+ }
+ )
+ }
+
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (typeName.contains("datetimeoffset")) {
@@ -118,4 +149,19 @@ private object MsSqlServerDialect extends JdbcDialect {
override def getTableCommentQuery(table: String, comment: String): String = {
throw QueryExecutionErrors.commentOnTableUnsupportedError()
}
+
+ override def getLimitClause(limit: Integer): String = {
+ ""
+ }
+
+ override def classifyException(message: String, e: Throwable): AnalysisException = {
+ e match {
+ case sqlException: SQLException =>
+ sqlException.getErrorCode match {
+ case 3729 => throw NonEmptyNamespaceException(message, cause = Some(e))
+ case _ => super.classifyException(message, e)
+ }
+ case _ => super.classifyException(message, e)
+ }
+ }
}
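
For readers skimming the SQL Server hunk: the compileAggregate override is essentially a rename table from Spark's ANSI aggregate names to their T-SQL spellings, plus an optional DISTINCT prefix. The stand-alone sketch below (plain Scala, not Spark code; the column name is made up) mirrors that mapping.

```scala
// Stand-alone illustration of the rewrite above: ANSI aggregate name -> T-SQL name.
val tsqlNames = Map(
  "VAR_POP" -> "VARP",
  "VAR_SAMP" -> "VAR",
  "STDDEV_POP" -> "STDEVP",
  "STDDEV_SAMP" -> "STDEV")

def compileForSqlServer(name: String, isDistinct: Boolean, column: String): Option[String] =
  tsqlNames.get(name).map { fn =>
    val distinct = if (isDistinct) "DISTINCT " else ""
    s"$fn($distinct$column)"
  }

// compileForSqlServer("STDDEV_SAMP", isDistinct = false, "price") == Some("STDEV(price)")
```
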
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
index ed107707c9d1f..24f9bac74f86d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala
@@ -17,18 +17,48 @@
package org.apache.spark.sql.jdbc
-import java.sql.Types
+import java.sql.{Connection, SQLException, Types}
+import java.util
import java.util.Locale
+import scala.collection.mutable.ArrayBuilder
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NoSuchIndexException}
+import org.apache.spark.sql.connector.catalog.index.TableIndex
+import org.apache.spark.sql.connector.expressions.{FieldReference, NamedReference}
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
import org.apache.spark.sql.types.{BooleanType, DataType, FloatType, LongType, MetadataBuilder}
-private case object MySQLDialect extends JdbcDialect {
+private case object MySQLDialect extends JdbcDialect with SQLConfHelper {
override def canHandle(url : String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:mysql")
+ // See https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"VAR_POP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"VAR_SAMP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"STDDEV_POP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"STDDEV_SAMP(${f.children().head})")
+ case _ => None
+ }
+ )
+ }
+
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.VARBINARY && typeName.equals("BIT") && size != 1) {
@@ -45,6 +75,25 @@ private case object MySQLDialect extends JdbcDialect {
s"`$colName`"
}
+ override def schemasExists(conn: Connection, options: JDBCOptions, schema: String): Boolean = {
+ listSchemas(conn, options).exists(_.head == schema)
+ }
+
+ override def listSchemas(conn: Connection, options: JDBCOptions): Array[Array[String]] = {
+ val schemaBuilder = ArrayBuilder.make[Array[String]]
+ try {
+ JdbcUtils.executeQuery(conn, options, "SHOW SCHEMAS") { rs =>
+ while (rs.next()) {
+ schemaBuilder += Array(rs.getString("Database"))
+ }
+ }
+ } catch {
+ case _: Exception =>
+ logWarning("Cannot show schemas.")
+ }
+ schemaBuilder.result
+ }
+
override def getTableExistsQuery(table: String): String = {
s"SELECT 1 FROM $table LIMIT 1"
}
@@ -102,4 +151,107 @@ private case object MySQLDialect extends JdbcDialect {
case FloatType => Option(JdbcType("FLOAT", java.sql.Types.FLOAT))
case _ => JdbcUtils.getCommonJDBCType(dt)
}
+
+ override def getSchemaCommentQuery(schema: String, comment: String): String = {
+ throw QueryExecutionErrors.unsupportedCreateNamespaceCommentError()
+ }
+
+ override def removeSchemaCommentQuery(schema: String): String = {
+ throw QueryExecutionErrors.unsupportedRemoveNamespaceCommentError()
+ }
+
+ // CREATE INDEX syntax
+ // https://dev.mysql.com/doc/refman/8.0/en/create-index.html
+ override def createIndex(
+ indexName: String,
+ tableName: String,
+ columns: Array[NamedReference],
+ columnsProperties: util.Map[NamedReference, util.Map[String, String]],
+ properties: util.Map[String, String]): String = {
+ val columnList = columns.map(col => quoteIdentifier(col.fieldNames.head))
+ val (indexType, indexPropertyList) = JdbcUtils.processIndexProperties(properties, "mysql")
+
+ // columnsProperties doesn't apply to MySQL so it is ignored
+ s"CREATE INDEX ${quoteIdentifier(indexName)} $indexType ON" +
+ s" ${quoteIdentifier(tableName)} (${columnList.mkString(", ")})" +
+ s" ${indexPropertyList.mkString(" ")}"
+ }
+
+ // SHOW INDEX syntax
+ // https://dev.mysql.com/doc/refman/8.0/en/show-index.html
+ override def indexExists(
+ conn: Connection,
+ indexName: String,
+ tableName: String,
+ options: JDBCOptions): Boolean = {
+ val sql = s"SHOW INDEXES FROM ${quoteIdentifier(tableName)} WHERE key_name = '$indexName'"
+ JdbcUtils.checkIfIndexExists(conn, sql, options)
+ }
+
+ override def dropIndex(indexName: String, tableName: String): String = {
+ s"DROP INDEX ${quoteIdentifier(indexName)} ON $tableName"
+ }
+
+ // SHOW INDEX syntax
+ // https://dev.mysql.com/doc/refman/8.0/en/show-index.html
+ override def listIndexes(
+ conn: Connection,
+ tableName: String,
+ options: JDBCOptions): Array[TableIndex] = {
+ val sql = s"SHOW INDEXES FROM $tableName"
+ var indexMap: Map[String, TableIndex] = Map()
+ try {
+ JdbcUtils.executeQuery(conn, options, sql) { rs =>
+ while (rs.next()) {
+ val indexName = rs.getString("key_name")
+ val colName = rs.getString("column_name")
+ val indexType = rs.getString("index_type")
+ val indexComment = rs.getString("Index_comment")
+ if (indexMap.contains(indexName)) {
+ val index = indexMap.get(indexName).get
+ val newIndex = new TableIndex(indexName, indexType,
+ index.columns() :+ FieldReference(colName),
+ index.columnProperties, index.properties)
+ indexMap += (indexName -> newIndex)
+ } else {
+ // The only property we are building here is `COMMENT` because it's the only one
+ // we can get from `SHOW INDEXES`.
+            val properties = new util.Properties()
+ if (indexComment.nonEmpty) properties.put("COMMENT", indexComment)
+ val index = new TableIndex(indexName, indexType, Array(FieldReference(colName)),
+ new util.HashMap[NamedReference, util.Properties](), properties)
+ indexMap += (indexName -> index)
+ }
+ }
+ }
+ } catch {
+ case _: Exception =>
+        logWarning("Cannot retrieve index info.")
+ }
+ indexMap.values.toArray
+ }
+
+ override def classifyException(message: String, e: Throwable): AnalysisException = {
+ e match {
+ case sqlException: SQLException =>
+ sqlException.getErrorCode match {
+ // ER_DUP_KEYNAME
+ case 1061 =>
+ throw new IndexAlreadyExistsException(message, cause = Some(e))
+ case 1091 =>
+ throw new NoSuchIndexException(message, cause = Some(e))
+ case _ => super.classifyException(message, e)
+ }
+ case unsupported: UnsupportedOperationException => throw unsupported
+ case _ => super.classifyException(message, e)
+ }
+ }
+
+ override def dropSchema(schema: String, cascade: Boolean): String = {
+ if (cascade) {
+ s"DROP SCHEMA ${quoteIdentifier(schema)}"
+ } else {
+ throw QueryExecutionErrors.unsupportedDropNamespaceRestrictError()
+ }
+ }
}
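
A quick usage sketch of the MySQL DDL helpers added above, assuming this patch is applied. `JdbcDialects.get` is the existing public lookup; the index, table and schema names are made up, and the methods only build SQL strings, they do not execute anything.

```scala
// Assumes this patch is applied; names are illustrative.
import org.apache.spark.sql.jdbc.JdbcDialects

val mysql = JdbcDialects.get("jdbc:mysql://localhost/test")

// MySQL quotes identifiers with backticks (see quoteIdentifier above).
mysql.dropIndex("idx_people_name", "people")
// -> DROP INDEX `idx_people_name` ON people

// MySQL's DROP DATABASE is always cascading, so only cascade = true is accepted.
mysql.dropSchema("test", cascade = true)
// -> DROP SCHEMA `test`
```
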
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
index b741ece8dda9b..40333c1757c4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp, Types}
import java.util.{Locale, TimeZone}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
@@ -33,6 +34,38 @@ private case object OracleDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:oracle")
+ // scalastyle:off line.size.limit
+ // https://docs.oracle.com/en/database/oracle/oracle-database/19/sqlrf/Aggregate-Functions.html#GUID-62BE676B-AF18-4E63-BD14-25206FEA0848
+ // scalastyle:on line.size.limit
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"VAR_POP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"VAR_SAMP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"STDDEV_POP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 1)
+ Some(s"STDDEV_SAMP(${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
+ assert(f.children().length == 2)
+ Some(s"COVAR_POP(${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 2)
+ Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false =>
+ assert(f.children().length == 2)
+ Some(s"CORR(${f.children().head}, ${f.children().last})")
+ case _ => None
+ }
+ )
+ }
+
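
In addition to the single-column variance and standard-deviation functions, the Oracle override above forwards the two-argument aggregates (COVAR_POP, COVAR_SAMP, CORR), but only when DISTINCT is not requested, since Oracle's versions take no DISTINCT qualifier. A stand-alone sketch of that guard, with illustrative column names:

```scala
// Two-argument aggregates are only pushed when DISTINCT is absent, because
// Oracle's COVAR_POP/COVAR_SAMP/CORR do not accept a DISTINCT qualifier.
def compileTwoArg(name: String, isDistinct: Boolean, x: String, y: String): Option[String] =
  if (!isDistinct && Set("COVAR_POP", "COVAR_SAMP", "CORR").contains(name)) {
    Some(s"$name($x, $y)")
  } else {
    None
  }

// compileTwoArg("CORR", isDistinct = false, "a", "b") == Some("CORR(a, b)")
// compileTwoArg("CORR", isDistinct = true,  "a", "b") == None, left for Spark to evaluate
```
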
private def supportTimeZoneTypes: Boolean = {
val timeZone = DateTimeUtils.getTimeZone(SQLConf.get.sessionLocalTimeZone)
// TODO: support timezone types when users are not using the JVM timezone, which
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 3ce785ed844c5..a668d66ee2f9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -17,18 +17,62 @@
package org.apache.spark.sql.jdbc
-import java.sql.{Connection, Types}
+import java.sql.{Connection, SQLException, Types}
+import java.util
import java.util.Locale
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.analysis.{IndexAlreadyExistsException, NonEmptyNamespaceException, NoSuchIndexException}
+import org.apache.spark.sql.connector.expressions.NamedReference
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.v2.TableSampleInfo
import org.apache.spark.sql.types._
-private object PostgresDialect extends JdbcDialect {
+private object PostgresDialect extends JdbcDialect with SQLConfHelper {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:postgresql")
+ // See https://www.postgresql.org/docs/8.4/functions-aggregate.html
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR_POP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR_SAMP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_POP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_SAMP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_POP" =>
+ assert(f.children().length == 2)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"COVAR_POP($distinct${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" =>
+ assert(f.children().length == 2)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"COVAR_SAMP($distinct${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "CORR" =>
+ assert(f.children().length == 2)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"CORR($distinct${f.children().head}, ${f.children().last})")
+ case _ => None
+ }
+ )
+ }
+
override def getCatalystType(
sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = {
if (sqlType == Types.REAL) {
@@ -154,4 +198,66 @@ private object PostgresDialect extends JdbcDialect {
val nullable = if (isNullable) "DROP NOT NULL" else "SET NOT NULL"
s"ALTER TABLE $tableName ALTER COLUMN ${quoteIdentifier(columnName)} $nullable"
}
+
+ override def supportsTableSample: Boolean = true
+
+ override def getTableSample(sample: TableSampleInfo): String = {
+    // Hard-coded to BERNOULLI for now because Spark doesn't have a way to specify
+    // the sampling method name.
+ s"TABLESAMPLE BERNOULLI" +
+ s" (${(sample.upperBound - sample.lowerBound) * 100}) REPEATABLE (${sample.seed})"
+ }
+
+ // CREATE INDEX syntax
+ // https://www.postgresql.org/docs/14/sql-createindex.html
+ override def createIndex(
+ indexName: String,
+ tableName: String,
+ columns: Array[NamedReference],
+ columnsProperties: util.Map[NamedReference, util.Map[String, String]],
+ properties: util.Map[String, String]): String = {
+ val columnList = columns.map(col => quoteIdentifier(col.fieldNames.head))
+ var indexProperties = ""
+ val (indexType, indexPropertyList) = JdbcUtils.processIndexProperties(properties, "postgresql")
+
+ if (indexPropertyList.nonEmpty) {
+ indexProperties = "WITH (" + indexPropertyList.mkString(", ") + ")"
+ }
+
+ s"CREATE INDEX ${quoteIdentifier(indexName)} ON ${quoteIdentifier(tableName)}" +
+ s" $indexType (${columnList.mkString(", ")}) $indexProperties"
+ }
+
+ // SHOW INDEX syntax
+ // https://www.postgresql.org/docs/14/view-pg-indexes.html
+ override def indexExists(
+ conn: Connection,
+ indexName: String,
+ tableName: String,
+ options: JDBCOptions): Boolean = {
+ val sql = s"SELECT * FROM pg_indexes WHERE tablename = '$tableName' AND" +
+ s" indexname = '$indexName'"
+ JdbcUtils.checkIfIndexExists(conn, sql, options)
+ }
+
+ // DROP INDEX syntax
+ // https://www.postgresql.org/docs/14/sql-dropindex.html
+ override def dropIndex(indexName: String, tableName: String): String = {
+ s"DROP INDEX ${quoteIdentifier(indexName)}"
+ }
+
+ override def classifyException(message: String, e: Throwable): AnalysisException = {
+ e match {
+ case sqlException: SQLException =>
+ sqlException.getSQLState match {
+ // https://www.postgresql.org/docs/14/errcodes-appendix.html
+ case "42P07" => throw new IndexAlreadyExistsException(message, cause = Some(e))
+ case "42704" => throw new NoSuchIndexException(message, cause = Some(e))
+ case "2BP01" => throw NonEmptyNamespaceException(message, cause = Some(e))
+ case _ => super.classifyException(message, e)
+ }
+ case unsupported: UnsupportedOperationException => throw unsupported
+ case _ => super.classifyException(message, e)
+ }
+ }
}
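
The getTableSample override above converts Spark's fractional sample bounds into the percentage that PostgreSQL's TABLESAMPLE BERNOULLI expects. A small stand-alone sketch of the arithmetic (field names follow TableSampleInfo as used above):

```scala
// (upperBound - lowerBound) is the sampled fraction; PostgreSQL wants a percentage.
def postgresTableSample(lowerBound: Double, upperBound: Double, seed: Long): String =
  s"TABLESAMPLE BERNOULLI (${(upperBound - lowerBound) * 100}) REPEATABLE ($seed)"

// postgresTableSample(0.0, 0.5, 42L) == "TABLESAMPLE BERNOULLI (50.0) REPEATABLE (42)"
```
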
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
index 58fe62cb6e088..79fb710cf03b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/TeradataDialect.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.jdbc
import java.util.Locale
+import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, GeneralAggregateFunc}
import org.apache.spark.sql.types._
@@ -27,6 +28,42 @@ private case object TeradataDialect extends JdbcDialect {
override def canHandle(url: String): Boolean =
url.toLowerCase(Locale.ROOT).startsWith("jdbc:teradata")
+ // scalastyle:off line.size.limit
+ // See https://docs.teradata.com/r/Teradata-VantageTM-SQL-Functions-Expressions-and-Predicates/March-2019/Aggregate-Functions
+ // scalastyle:on line.size.limit
+ override def compileAggregate(aggFunction: AggregateFunc): Option[String] = {
+ super.compileAggregate(aggFunction).orElse(
+ aggFunction match {
+ case f: GeneralAggregateFunc if f.name() == "VAR_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR_POP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "VAR_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"VAR_SAMP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_POP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_POP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "STDDEV_SAMP" =>
+ assert(f.children().length == 1)
+ val distinct = if (f.isDistinct) "DISTINCT " else ""
+ Some(s"STDDEV_SAMP($distinct${f.children().head})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_POP" && f.isDistinct == false =>
+ assert(f.children().length == 2)
+ Some(s"COVAR_POP(${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "COVAR_SAMP" && f.isDistinct == false =>
+ assert(f.children().length == 2)
+ Some(s"COVAR_SAMP(${f.children().head}, ${f.children().last})")
+ case f: GeneralAggregateFunc if f.name() == "CORR" && f.isDistinct == false =>
+ assert(f.children().length == 2)
+ Some(s"CORR(${f.children().head}, ${f.children().last})")
+ case _ => None
+ }
+ )
+ }
+
override def getJDBCType(dt: DataType): Option[JdbcType] = dt match {
case StringType => Some(JdbcType("VARCHAR(255)", java.sql.Types.VARCHAR))
case BooleanType => Option(JdbcType("CHAR(1)", java.sql.Types.CHAR))
@@ -55,4 +92,8 @@ private case object TeradataDialect extends JdbcDialect {
override def renameTable(oldTable: String, newTable: String): String = {
s"RENAME TABLE $oldTable TO $newTable"
}
+
+ override def getLimitClause(limit: Integer): String = {
+ ""
+ }
}
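
Like MsSqlServerDialect, TeradataDialect returns an empty string from getLimitClause: neither database accepts a trailing `LIMIT n`, so the JDBC scan simply omits the clause and Spark applies the limit itself. A one-line stand-alone sketch of that contract:

```scala
// Returning "" tells the JDBC scan builder not to append a limit clause;
// Spark then enforces the limit itself after the rows come back.
def limitClause(supportsTrailingLimit: Boolean, limit: Int): String =
  if (supportsTrailingLimit && limit > 0) s"LIMIT $limit" else ""
```
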
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java
new file mode 100644
index 0000000000000..ec532da61042f
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/connector/JavaAdvancedDataSourceV2WithV2Filter.java
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.connector;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
+import org.apache.spark.sql.connector.TestingV2Source;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.expressions.FieldReference;
+import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.connector.expressions.LiteralValue;
+import org.apache.spark.sql.connector.expressions.filter.Predicate;
+import org.apache.spark.sql.connector.read.*;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+
+public class JavaAdvancedDataSourceV2WithV2Filter implements TestingV2Source {
+
+ @Override
+ public Table getTable(CaseInsensitiveStringMap options) {
+ return new JavaSimpleBatchTable() {
+ @Override
+ public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
+ return new AdvancedScanBuilderWithV2Filter();
+ }
+ };
+ }
+
+ static class AdvancedScanBuilderWithV2Filter implements ScanBuilder, Scan,
+ SupportsPushDownV2Filters, SupportsPushDownRequiredColumns {
+
+ private StructType requiredSchema = TestingV2Source.schema();
+ private Predicate[] predicates = new Predicate[0];
+
+ @Override
+ public void pruneColumns(StructType requiredSchema) {
+ this.requiredSchema = requiredSchema;
+ }
+
+ @Override
+ public StructType readSchema() {
+ return requiredSchema;
+ }
+
+ @Override
+ public Predicate[] pushPredicates(Predicate[] predicates) {
+ Predicate[] supported = Arrays.stream(predicates).filter(f -> {
+ if (f.name().equals(">")) {
+ assert(f.children()[0] instanceof FieldReference);
+ FieldReference column = (FieldReference) f.children()[0];
+ assert(f.children()[1] instanceof LiteralValue);
+ Literal value = (Literal) f.children()[1];
+ return column.describe().equals("i") && value.value() instanceof Integer;
+ } else {
+ return false;
+ }
+ }).toArray(Predicate[]::new);
+
+ Predicate[] unsupported = Arrays.stream(predicates).filter(f -> {
+ if (f.name().equals(">")) {
+ assert(f.children()[0] instanceof FieldReference);
+ FieldReference column = (FieldReference) f.children()[0];
+ assert(f.children()[1] instanceof LiteralValue);
+ Literal value = (LiteralValue) f.children()[1];
+ return !column.describe().equals("i") || !(value.value() instanceof Integer);
+ } else {
+ return true;
+ }
+ }).toArray(Predicate[]::new);
+
+ this.predicates = supported;
+ return unsupported;
+ }
+
+ @Override
+ public Predicate[] pushedPredicates() {
+ return predicates;
+ }
+
+ @Override
+ public Scan build() {
+ return this;
+ }
+
+ @Override
+ public Batch toBatch() {
+ return new AdvancedBatchWithV2Filter(requiredSchema, predicates);
+ }
+ }
+
+ public static class AdvancedBatchWithV2Filter implements Batch {
+ // Exposed for testing.
+ public StructType requiredSchema;
+ public Predicate[] predicates;
+
+ AdvancedBatchWithV2Filter(StructType requiredSchema, Predicate[] predicates) {
+ this.requiredSchema = requiredSchema;
+ this.predicates = predicates;
+ }
+
+ @Override
+ public InputPartition[] planInputPartitions() {
+      List<InputPartition> res = new ArrayList<>();
+
+ Integer lowerBound = null;
+ for (Predicate predicate : predicates) {
+ if (predicate.name().equals(">")) {
+ assert(predicate.children()[0] instanceof FieldReference);
+ FieldReference column = (FieldReference) predicate.children()[0];
+ assert(predicate.children()[1] instanceof LiteralValue);
+ Literal value = (Literal) predicate.children()[1];
+ if ("i".equals(column.describe()) && value.value() instanceof Integer) {
+ lowerBound = (Integer) value.value();
+ break;
+ }
+ }
+ }
+
+ if (lowerBound == null) {
+ res.add(new JavaRangeInputPartition(0, 5));
+ res.add(new JavaRangeInputPartition(5, 10));
+ } else if (lowerBound < 4) {
+ res.add(new JavaRangeInputPartition(lowerBound + 1, 5));
+ res.add(new JavaRangeInputPartition(5, 10));
+ } else if (lowerBound < 9) {
+ res.add(new JavaRangeInputPartition(lowerBound + 1, 10));
+ }
+
+ return res.stream().toArray(InputPartition[]::new);
+ }
+
+ @Override
+ public PartitionReaderFactory createReaderFactory() {
+ return new AdvancedReaderFactoryWithV2Filter(requiredSchema);
+ }
+ }
+
+ static class AdvancedReaderFactoryWithV2Filter implements PartitionReaderFactory {
+ StructType requiredSchema;
+
+ AdvancedReaderFactoryWithV2Filter(StructType requiredSchema) {
+ this.requiredSchema = requiredSchema;
+ }
+
+ @Override
+    public PartitionReader<InternalRow> createReader(InputPartition partition) {
+ JavaRangeInputPartition p = (JavaRangeInputPartition) partition;
+      return new PartitionReader<InternalRow>() {
+ private int current = p.start - 1;
+
+ @Override
+ public boolean next() throws IOException {
+ current += 1;
+ return current < p.end;
+ }
+
+ @Override
+ public InternalRow get() {
+ Object[] values = new Object[requiredSchema.size()];
+ for (int i = 0; i < values.length; i++) {
+ if ("i".equals(requiredSchema.apply(i).name())) {
+ values[i] = current;
+ } else if ("j".equals(requiredSchema.apply(i).name())) {
+ values[i] = -current;
+ }
+ }
+ return new GenericInternalRow(values);
+ }
+
+ @Override
+ public void close() throws IOException {
+
+ }
+ };
+ }
+ }
+}
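
Both this Java test source and its Scala counterpart later in the diff plan input partitions from the pushed `i > n` predicate. The stand-alone sketch below mirrors that planning logic, with integer ranges standing in for JavaRangeInputPartition/RangeInputPartition:

```scala
// Integer ranges stand in for the test's input partitions.
def planPartitions(lowerBound: Option[Int]): Seq[(Int, Int)] = lowerBound match {
  case None             => Seq((0, 5), (5, 10))      // no pushed filter: full 0..9 range
  case Some(b) if b < 4 => Seq((b + 1, 5), (5, 10))  // filter trims the first partition
  case Some(b) if b < 9 => Seq((b + 1, 10))          // only one partition remains
  case _                => Seq.empty                 // predicate filters out every row
}

// planPartitions(Some(3)) == Seq((4, 5), (5, 10)), i.e. rows 4..9 for the filter i > 3.
```
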
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index 001b6a00af52f..910f159cc49a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -731,6 +731,28 @@ class FileBasedDataSourceSuite extends QueryTest
}
}
+ test("SPARK-36568: FileScan statistics estimation takes read schema into account") {
+ withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
+ withTempDir { dir =>
+ spark.range(1000).map(x => (x / 100, x, x)).toDF("k", "v1", "v2").
+ write.partitionBy("k").mode(SaveMode.Overwrite).orc(dir.toString)
+ val dfAll = spark.read.orc(dir.toString)
+ val dfK = dfAll.select("k")
+ val dfV1 = dfAll.select("v1")
+ val dfV2 = dfAll.select("v2")
+ val dfV1V2 = dfAll.select("v1", "v2")
+
+ def sizeInBytes(df: DataFrame): BigInt = df.queryExecution.optimizedPlan.stats.sizeInBytes
+
+ assert(sizeInBytes(dfAll) === BigInt(getLocalDirSize(dir)))
+ assert(sizeInBytes(dfK) < sizeInBytes(dfAll))
+ assert(sizeInBytes(dfV1) < sizeInBytes(dfAll))
+ assert(sizeInBytes(dfV2) === sizeInBytes(dfV1))
+ assert(sizeInBytes(dfV1V2) < sizeInBytes(dfAll))
+ }
+ }
+ }
+
test("File source v2: support partition pruning") {
withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") {
allFileBasedDataSources.foreach { format =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
index 4e7fe8455ff93..14b59ba23d09f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileScanSuite.scala
@@ -354,11 +354,11 @@ class FileScanSuite extends FileScanSuiteBase {
val scanBuilders = Seq[(String, ScanBuilder, Seq[String])](
("ParquetScan",
(s, fi, ds, rds, rps, f, o, pf, df) =>
- ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, pf, df),
+ ParquetScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, f, o, None, pf, df),
Seq.empty),
("OrcScan",
(s, fi, ds, rds, rps, f, o, pf, df) =>
- OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, f, pf, df),
+ OrcScan(s, s.sessionState.newHadoopConf(), fi, ds, rds, rps, o, None, f, pf, df),
Seq.empty),
("CSVScan",
(s, fi, ds, rds, rps, f, o, pf, df) => CSVScan(s, fi, ds, rds, rps, o, f, pf, df),
@@ -367,7 +367,7 @@ class FileScanSuite extends FileScanSuiteBase {
(s, fi, ds, rds, rps, f, o, pf, df) => JsonScan(s, fi, ds, rds, rps, o, f, pf, df),
Seq.empty),
("TextScan",
- (s, fi, _, rds, rps, _, o, pf, df) => TextScan(s, fi, rds, rps, o, pf, df),
+ (s, fi, ds, rds, rps, _, o, pf, df) => TextScan(s, fi, ds, rds, rps, o, pf, df),
Seq("dataSchema", "pushedFilters")))
run(scanBuilders)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
index 91ac7db335cc3..e9c8131fe9bec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSessionCatalogSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{DataFrame, QueryTest, SaveMode}
@@ -97,7 +95,7 @@ class InMemoryTableSessionCatalog extends TestV2SessionCatalogBase[InMemoryTable
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): InMemoryTable = {
+ properties: java.util.Map[String, String]): InMemoryTable = {
new InMemoryTable(name, schema, partitions, properties)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index d5417be0f229f..e4ba33c619a7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
import java.util.Collections
import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage, JavaLongAdd, JavaStrLen}
@@ -35,7 +34,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
class DataSourceV2FunctionSuite extends DatasourceV2SQLBase {
- private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String]
+ private val emptyProps: java.util.Map[String, String] = Collections.emptyMap[String, String]
private def addFunction(ident: Identifier, fn: UnboundFunction): Unit = {
catalog("testcat").asInstanceOf[InMemoryCatalog].createFunction(ident, fn)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
index a326b82dbaf1e..7b941ab0d8f7d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala
@@ -1609,6 +1609,24 @@ class DataSourceV2SQLSuite
}
}
+ test("create table using - with sorted bucket") {
+ val identifier = "testcat.table_name"
+ withTable(identifier) {
+ sql(s"CREATE TABLE $identifier (a int, b string, c int) USING $v2Source PARTITIONED BY (c)" +
+ s" CLUSTERED BY (b) SORTED by (a) INTO 4 BUCKETS")
+ val table = getTableMetadata(identifier)
+ val describe = spark.sql(s"DESCRIBE $identifier")
+ val part1 = describe
+ .filter("col_name = 'Part 0'")
+ .select("data_type").head.getString(0)
+ assert(part1 === "c")
+ val part2 = describe
+ .filter("col_name = 'Part 1'")
+ .select("data_type").head.getString(0)
+ assert(part2 === "bucket(4, b, a)")
+ }
+ }
+
test("REFRESH TABLE: v2 table") {
val t = "testcat.ns1.ns2.tbl"
withTable(t) {
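
The new test above checks that `CLUSTERED BY (b) SORTED BY (a) INTO 4 BUCKETS` shows up in DESCRIBE as the single transform `bucket(4, b, a)`, with the sort column appended after the bucket column. A spark-shell style sketch of the same check; the `testcat` catalog and the USING clause are placeholders taken from the test fixture:

```scala
// Sketch only: assumes a session with a v2 catalog named `testcat` and a v2 provider
// behind the USING clause, as in the test fixture.
spark.sql(
  """CREATE TABLE testcat.table_name (a INT, b STRING, c INT) USING foo
    |PARTITIONED BY (c) CLUSTERED BY (b) SORTED BY (a) INTO 4 BUCKETS""".stripMargin)

val part1 = spark.sql("DESCRIBE testcat.table_name")
  .filter("col_name = 'Part 1'")
  .select("data_type").head.getString(0)
assert(part1 == "bucket(4, b, a)")  // sort columns are appended to the bucket transform
```
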
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
index b42d48d873fee..cff58d7367317 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2Suite.scala
@@ -18,11 +18,8 @@
package org.apache.spark.sql.connector
import java.io.File
-import java.util
import java.util.OptionalLong
-import scala.collection.JavaConverters._
-
import test.org.apache.spark.sql.connector._
import org.apache.spark.SparkException
@@ -30,7 +27,8 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability, TableProvider}
import org.apache.spark.sql.connector.catalog.TableCapability._
-import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.expressions.{Literal, Transform}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.read.partitioning.{ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -54,6 +52,13 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}.head
}
+ private def getBatchWithV2Filter(query: DataFrame): AdvancedBatchWithV2Filter = {
+ query.queryExecution.executedPlan.collect {
+ case d: BatchScanExec =>
+ d.batch.asInstanceOf[AdvancedBatchWithV2Filter]
+ }.head
+ }
+
private def getJavaBatch(query: DataFrame): JavaAdvancedDataSourceV2.AdvancedBatch = {
query.queryExecution.executedPlan.collect {
case d: BatchScanExec =>
@@ -61,6 +66,14 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}.head
}
+ private def getJavaBatchWithV2Filter(
+ query: DataFrame): JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter = {
+ query.queryExecution.executedPlan.collect {
+ case d: BatchScanExec =>
+ d.batch.asInstanceOf[JavaAdvancedDataSourceV2WithV2Filter.AdvancedBatchWithV2Filter]
+ }.head
+ }
+
test("simplest implementation") {
Seq(classOf[SimpleDataSourceV2], classOf[JavaSimpleDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
@@ -131,6 +144,66 @@ class DataSourceV2Suite extends QueryTest with SharedSparkSession with AdaptiveS
}
}
+ test("advanced implementation with V2 Filter") {
+ Seq(classOf[AdvancedDataSourceV2WithV2Filter], classOf[JavaAdvancedDataSourceV2WithV2Filter])
+ .foreach { cls =>
+ withClue(cls.getName) {
+ val df = spark.read.format(cls.getName).load()
+ checkAnswer(df, (0 until 10).map(i => Row(i, -i)))
+
+ val q1 = df.select('j)
+ checkAnswer(q1, (0 until 10).map(i => Row(-i)))
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) {
+ val batch = getBatchWithV2Filter(q1)
+ assert(batch.predicates.isEmpty)
+ assert(batch.requiredSchema.fieldNames === Seq("j"))
+ } else {
+ val batch = getJavaBatchWithV2Filter(q1)
+ assert(batch.predicates.isEmpty)
+ assert(batch.requiredSchema.fieldNames === Seq("j"))
+ }
+
+ val q2 = df.filter('i > 3)
+ checkAnswer(q2, (4 until 10).map(i => Row(i, -i)))
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) {
+ val batch = getBatchWithV2Filter(q2)
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i"))
+ assert(batch.requiredSchema.fieldNames === Seq("i", "j"))
+ } else {
+ val batch = getJavaBatchWithV2Filter(q2)
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i"))
+ assert(batch.requiredSchema.fieldNames === Seq("i", "j"))
+ }
+
+ val q3 = df.select('i).filter('i > 6)
+ checkAnswer(q3, (7 until 10).map(i => Row(i)))
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) {
+ val batch = getBatchWithV2Filter(q3)
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i"))
+ assert(batch.requiredSchema.fieldNames === Seq("i"))
+ } else {
+ val batch = getJavaBatchWithV2Filter(q3)
+ assert(batch.predicates.flatMap(_.references.map(_.describe)).toSet == Set("i"))
+ assert(batch.requiredSchema.fieldNames === Seq("i"))
+ }
+
+ val q4 = df.select('j).filter('j < -10)
+ checkAnswer(q4, Nil)
+ if (cls == classOf[AdvancedDataSourceV2WithV2Filter]) {
+ val batch = getBatchWithV2Filter(q4)
+          // 'j < -10 is not supported by the testing data source.
+ assert(batch.predicates.isEmpty)
+ assert(batch.requiredSchema.fieldNames === Seq("j"))
+ } else {
+ val batch = getJavaBatchWithV2Filter(q4)
+          // 'j < -10 is not supported by the testing data source.
+ assert(batch.predicates.isEmpty)
+ assert(batch.requiredSchema.fieldNames === Seq("j"))
+ }
+ }
+ }
+ }
+
test("columnar batch scan implementation") {
Seq(classOf[ColumnarDataSourceV2], classOf[JavaColumnarDataSourceV2]).foreach { cls =>
withClue(cls.getName) {
@@ -466,7 +539,7 @@ abstract class SimpleBatchTable extends Table with SupportsRead {
override def name(): String = this.getClass.toString
- override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava
+ override def capabilities(): java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ)
}
abstract class SimpleScanBuilder extends ScanBuilder
@@ -489,7 +562,7 @@ trait TestingV2Source extends TableProvider {
override def getTable(
schema: StructType,
partitioning: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
getTable(new CaseInsensitiveStringMap(properties))
}
@@ -597,6 +670,75 @@ class AdvancedBatch(val filters: Array[Filter], val requiredSchema: StructType)
}
}
+class AdvancedDataSourceV2WithV2Filter extends TestingV2Source {
+
+ override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable {
+ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
+ new AdvancedScanBuilderWithV2Filter()
+ }
+ }
+}
+
+class AdvancedScanBuilderWithV2Filter extends ScanBuilder
+ with Scan with SupportsPushDownV2Filters with SupportsPushDownRequiredColumns {
+
+ var requiredSchema = TestingV2Source.schema
+ var predicates = Array.empty[Predicate]
+
+ override def pruneColumns(requiredSchema: StructType): Unit = {
+ this.requiredSchema = requiredSchema
+ }
+
+ override def readSchema(): StructType = requiredSchema
+
+ override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
+ val (supported, unsupported) = predicates.partition {
+ case p: Predicate if p.name() == ">" => true
+ case _ => false
+ }
+ this.predicates = supported
+ unsupported
+ }
+
+ override def pushedPredicates(): Array[Predicate] = predicates
+
+ override def build(): Scan = this
+
+ override def toBatch: Batch = new AdvancedBatchWithV2Filter(predicates, requiredSchema)
+}
+
+class AdvancedBatchWithV2Filter(
+ val predicates: Array[Predicate],
+ val requiredSchema: StructType) extends Batch {
+
+ override def planInputPartitions(): Array[InputPartition] = {
+ val lowerBound = predicates.collectFirst {
+ case p: Predicate if p.name().equals(">") =>
+ val value = p.children()(1)
+ assert(value.isInstanceOf[Literal[_]])
+ value.asInstanceOf[Literal[_]]
+ }
+
+ val res = scala.collection.mutable.ArrayBuffer.empty[InputPartition]
+
+ if (lowerBound.isEmpty) {
+ res.append(RangeInputPartition(0, 5))
+ res.append(RangeInputPartition(5, 10))
+ } else if (lowerBound.get.value.asInstanceOf[Integer] < 4) {
+ res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 5))
+ res.append(RangeInputPartition(5, 10))
+ } else if (lowerBound.get.value.asInstanceOf[Integer] < 9) {
+ res.append(RangeInputPartition(lowerBound.get.value.asInstanceOf[Integer] + 1, 10))
+ }
+
+ res.toArray
+ }
+
+ override def createReaderFactory(): PartitionReaderFactory = {
+ new AdvancedReaderFactory(requiredSchema)
+ }
+}
+
class AdvancedReaderFactory(requiredSchema: StructType) extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
@@ -640,7 +782,7 @@ class SchemaRequiredDataSource extends TableProvider {
override def getTable(
schema: StructType,
partitioning: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val userGivenSchema = schema
new SimpleBatchTable {
override def schema(): StructType = userGivenSchema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
index db71eeb75eae0..e3d61a846fdb4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/LocalScanSuite.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
-import scala.collection.JavaConverters._
-
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability}
@@ -63,7 +59,7 @@ class TestLocalScanCatalog extends BasicInMemoryTableCatalog {
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val table = new TestLocalScanTable(ident.toString)
tables.put(ident, table)
table
@@ -78,7 +74,8 @@ object TestLocalScanTable {
class TestLocalScanTable(override val name: String) extends Table with SupportsRead {
override def schema(): StructType = TestLocalScanTable.schema
- override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_READ).asJava
+ override def capabilities(): java.util.Set[TableCapability] =
+ java.util.EnumSet.of(TableCapability.BATCH_READ)
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
new TestLocalScanBuilder
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
index bb2acecc782b2..64c893ed74fdb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/SimpleWritableDataSource.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.connector
import java.io.{BufferedReader, InputStreamReader, IOException}
-import java.util
import scala.collection.JavaConverters._
@@ -138,8 +137,8 @@ class SimpleWritableDataSource extends TestingV2Source {
new MyWriteBuilder(path, info)
}
- override def capabilities(): util.Set[TableCapability] =
- Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava
+ override def capabilities(): java.util.Set[TableCapability] =
+ java.util.EnumSet.of(BATCH_READ, BATCH_WRITE, TRUNCATE)
}
override def getTable(options: CaseInsensitiveStringMap): Table = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
index ce94d3b5c2fc0..5f2e0b28aeccc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TableCapabilityCheckSuite.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
-import scala.collection.JavaConverters._
-
import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
@@ -217,7 +213,11 @@ private case object TestRelation extends LeafNode with NamedRelation {
private case class CapabilityTable(_capabilities: TableCapability*) extends Table {
override def name(): String = "capability_test_table"
override def schema(): StructType = TableCapabilityCheckSuite.schema
- override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava
+ override def capabilities(): java.util.Set[TableCapability] = {
+ val set = java.util.EnumSet.noneOf(classOf[TableCapability])
+ _capabilities.foreach(set.add)
+ set
+ }
}
private class TestStreamSourceProvider extends StreamSourceProvider {
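
Several test files in this diff replace `Set(...).asJava` with `java.util.EnumSet`, which removes the JavaConverters import and builds the capability set directly as a Java collection. A minimal sketch of the before/after spellings (the val name is illustrative):

```scala
import org.apache.spark.sql.connector.catalog.TableCapability
import org.apache.spark.sql.connector.catalog.TableCapability._

// Before: build a Scala Set and convert it.
//   import scala.collection.JavaConverters._
//   val caps: java.util.Set[TableCapability] = Set(BATCH_READ, BATCH_WRITE).asJava

// After: build the Java EnumSet directly; no converter import needed.
val caps: java.util.Set[TableCapability] = java.util.EnumSet.of(BATCH_READ, BATCH_WRITE)
```
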
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
index bf2749d1afc53..0a0aaa8021996 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/TestV2SessionCatalogBase.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicBoolean
@@ -35,7 +34,7 @@ import org.apache.spark.sql.types.StructType
*/
private[connector] trait TestV2SessionCatalogBase[T <: Table] extends DelegatingCatalogExtension {
- protected val tables: util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
+ protected val tables: java.util.Map[Identifier, T] = new ConcurrentHashMap[Identifier, T]()
private val tableCreated: AtomicBoolean = new AtomicBoolean(false)
@@ -48,7 +47,7 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): T
+ properties: java.util.Map[String, String]): T
override def loadTable(ident: Identifier): Table = {
if (tables.containsKey(ident)) {
@@ -69,12 +68,12 @@ private[connector] trait TestV2SessionCatalogBase[T <: Table] extends Delegating
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
val key = TestV2SessionCatalogBase.SIMULATE_ALLOW_EXTERNAL_PROPERTY
val propsWithLocation = if (properties.containsKey(key)) {
// Always set a location so that CREATE EXTERNAL TABLE won't fail with LOCATION not specified.
if (!properties.containsKey(TableCatalog.PROP_LOCATION)) {
- val newProps = new util.HashMap[String, String]()
+ val newProps = new java.util.HashMap[String, String]()
newProps.putAll(properties)
newProps.put(TableCatalog.PROP_LOCATION, "file:/abc")
newProps
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
index 847953e09cef7..c5be222645b19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1ReadFallbackSuite.scala
@@ -17,10 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
-import scala.collection.JavaConverters._
-
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, QueryTest, Row, SparkSession, SQLContext}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, SupportsRead, Table, TableCapability}
@@ -106,7 +102,7 @@ class V1ReadFallbackCatalog extends BasicInMemoryTableCatalog {
ident: Identifier,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): Table = {
+ properties: java.util.Map[String, String]): Table = {
// To simplify the test implementation, only support fixed schema.
if (schema != V1ReadFallbackCatalog.schema || partitions.nonEmpty) {
throw new UnsupportedOperationException
@@ -131,8 +127,8 @@ class TableWithV1ReadFallback(override val name: String) extends Table with Supp
override def schema(): StructType = V1ReadFallbackCatalog.schema
- override def capabilities(): util.Set[TableCapability] = {
- Set(TableCapability.BATCH_READ).asJava
+ override def capabilities(): java.util.Set[TableCapability] = {
+ java.util.EnumSet.of(TableCapability.BATCH_READ)
}
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
index 7effc747ab323..992c46cc6cdb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/V1WriteFallbackSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.connector
-import java.util
-
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -223,7 +221,7 @@ class V1FallbackTableCatalog extends TestV2SessionCatalogBase[InMemoryTableWithV
name: String,
schema: StructType,
partitions: Array[Transform],
- properties: util.Map[String, String]): InMemoryTableWithV1Fallback = {
+ properties: java.util.Map[String, String]): InMemoryTableWithV1Fallback = {
val t = new InMemoryTableWithV1Fallback(name, schema, partitions, properties)
InMemoryV1Provider.tables.put(name, t)
tables.put(Identifier.of(Array("default"), name), t)
@@ -321,7 +319,7 @@ class InMemoryTableWithV1Fallback(
override val name: String,
override val schema: StructType,
override val partitioning: Array[Transform],
- override val properties: util.Map[String, String])
+ override val properties: java.util.Map[String, String])
extends Table
with SupportsWrite with SupportsRead {
@@ -331,11 +329,11 @@ class InMemoryTableWithV1Fallback(
}
}
- override def capabilities: util.Set[TableCapability] = Set(
+ override def capabilities: java.util.Set[TableCapability] = java.util.EnumSet.of(
TableCapability.BATCH_READ,
TableCapability.V1_BATCH_WRITE,
TableCapability.OVERWRITE_BY_FILTER,
- TableCapability.TRUNCATE).asJava
+ TableCapability.TRUNCATE)
@volatile private var dataMap: mutable.Map[Seq[Any], Seq[Row]] = mutable.Map.empty
private val partFieldNames = partitioning.flatMap(_.references).toSeq.flatMap(_.fieldNames)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
new file mode 100644
index 0000000000000..c787493fbdcc1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceAggregatePushDownSuite.scala
@@ -0,0 +1,617 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row}
+import org.apache.spark.sql.execution.datasources.orc.OrcTest
+import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.functions.min
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DateType, Decimal, DecimalType, DoubleType, FloatType, IntegerType, LongType, ShortType, StringType, StructField, StructType, TimestampType}
+
+/**
+ * A test suite that tests aggregate push down for Parquet and ORC.
+ */
+trait FileSourceAggregatePushDownSuite
+ extends QueryTest
+ with FileBasedDataSourceTest
+ with SharedSparkSession
+ with ExplainSuiteHelper {
+
+ import testImplicits._
+
+ protected def format: String
+ // The SQL config key for enabling aggregate push down.
+ protected val aggPushDownEnabledKey: String
+
+ test("nested column: Max(top level column) not push down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ withDataSourceTable(data, "t") {
+ val max = sql("SELECT Max(_1) FROM t")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ }
+ }
+ }
+
+ test("nested column: Count(top level column) push down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ withDataSourceTable(data, "t") {
+ val count = sql("SELECT Count(_1) FROM t")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [COUNT(_1)]"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+
+ test("nested column: Max(nested sub-field) not push down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+    withSQLConf(aggPushDownEnabledKey -> "true") {
+ withDataSourceTable(data, "t") {
+ val max = sql("SELECT Max(_1._2[0]) FROM t")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ }
+ }
+ }
+
+ test("nested column: Count(nested sub-field) not push down") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ withDataSourceTable(data, "t") {
+ val count = sql("SELECT Count(_1._2[0]) FROM t")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+
+ test("Max(partition column): not push down") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ val max = sql("SELECT Max(p) FROM tmp")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ checkAnswer(max, Seq(Row(2)))
+ }
+ }
+ }
+ }
+
+ test("Count(partition column): push down") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("if(id % 2 = 0, null, id) AS n", "id % 3 as p")
+ .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
+ Seq("false", "true").foreach { enableVectorizedReader =>
+ withSQLConf(aggPushDownEnabledKey -> "true",
+ vectorizedReaderEnabledKey -> enableVectorizedReader) {
+ val count = sql("SELECT COUNT(p) FROM tmp")
+ count.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [COUNT(p)]"
+ checkKeywordsExistsInExplain(count, expected_plan_fragment)
+ }
+ checkAnswer(count, Seq(Row(10)))
+ }
+ }
+ }
+ }
+ }
+
+ test("filter alias over aggregate") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withDataSourceTable(data, "t") {
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ val selectAgg = sql("SELECT min(_1) + max(_1) as res FROM t having res > 1")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1), MAX(_1)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(7)))
+ }
+ }
+ }
+
+ test("alias over aggregate") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withDataSourceTable(data, "t") {
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ val selectAgg = sql("SELECT min(_1) + 1 as minPlus1, min(_1) + 2 as minPlus2 FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-1, 0)))
+ }
+ }
+ }
+
+ test("aggregate over alias push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withDataSourceTable(data, "t") {
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ val df = spark.table("t")
+ val query = df.select($"_1".as("col1")).agg(min($"col1"))
+ query.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_1)]"
+ checkKeywordsExistsInExplain(query, expected_plan_fragment)
+ }
+ checkAnswer(query, Seq(Row(-2)))
+ }
+ }
+ }
+
+ test("query with group by not push down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withDataSourceTable(data, "t") {
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ // aggregates are not pushed down when there is a GROUP BY
+ val selectAgg = sql("SELECT min(_1) FROM t GROUP BY _3 ")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-2), Row(0), Row(2), Row(3)))
+ }
+ }
+ }
+
+ test("aggregate with data filter cannot be pushed down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withDataSourceTable(data, "t") {
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ // aggregates are not pushed down when there is a data filter
+ val selectAgg = sql("SELECT min(_3) FROM t WHERE _1 > 0")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(2)))
+ }
+ }
+ }
+
+ test("aggregate with partition filter can be pushed down") {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
+ Seq("false", "true").foreach { enableVectorizedReader =>
+ withSQLConf(aggPushDownEnabledKey -> "true",
+ vectorizedReaderEnabledKey -> enableVectorizedReader) {
+ val max = sql("SELECT max(id), min(id), count(id) FROM tmp WHERE p = 0")
+ max.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(id), MIN(id), COUNT(id)]"
+ checkKeywordsExistsInExplain(max, expected_plan_fragment)
+ }
+ checkAnswer(max, Seq(Row(9, 0, 4)))
+ }
+ }
+ }
+ }
+ }
+
+ test("push down only if all the aggregates can be pushed down") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 7))
+ withDataSourceTable(data, "t") {
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ // nothing is pushed down since SUM cannot be pushed down
+ val selectAgg = sql("SELECT min(_1), sum(_3) FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(-2, 41)))
+ }
+ }
+ }
+
+ test("aggregate push down - MIN/MAX/COUNT") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withDataSourceTable(data, "t") {
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ val selectAgg = sql("SELECT min(_3), min(_3), max(_3), min(_1), max(_1), max(_1)," +
+ " count(*), count(_1), count(_2), count(_3) FROM t")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(_3), " +
+ "MAX(_3), " +
+ "MIN(_1), " +
+ "MAX(_1), " +
+ "COUNT(*), " +
+ "COUNT(_1), " +
+ "COUNT(_2), " +
+ "COUNT(_3)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+
+ checkAnswer(selectAgg, Seq(Row(2, 2, 19, -2, 9, 9, 6, 6, 4, 6)))
+ }
+ }
+ }
+
+ test("aggregate not push down - MIN/MAX/COUNT with CASE WHEN") {
+ val data = Seq((-2, "abc", 2), (3, "def", 4), (6, "ghi", 2), (0, null, 19),
+ (9, "mno", 7), (2, null, 6))
+ withDataSourceTable(data, "t") {
+ withSQLConf(aggPushDownEnabledKey -> "true") {
+ val selectAgg = sql(
+ """
+ |SELECT
+ | min(CASE WHEN _1 < 0 THEN 0 ELSE _1 END),
+ | min(CASE WHEN _3 > 5 THEN 1 ELSE 0 END),
+ | max(CASE WHEN _1 < 0 THEN 0 ELSE _1 END),
+ | max(CASE WHEN NOT(_3 > 5) THEN 1 ELSE 0 END),
+ | count(CASE WHEN _1 < 0 AND _2 IS NOT NULL THEN 0 ELSE _1 END),
+ | count(CASE WHEN _3 != 5 OR _2 IS NULL THEN 1 ELSE 0 END)
+ |FROM t
+ """.stripMargin)
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+
+ checkAnswer(selectAgg, Seq(Row(0, 0, 9, 1, 6, 6)))
+ }
+ }
+ }
+
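+ // Shared helper: writes `inputRows` using a schema that covers string, boolean, numeric,
+ // binary, decimal, date and timestamp columns, then checks which MIN/MAX/COUNT aggregates
+ // are pushed down and verifies the query results against the expected rows.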
+ private def testPushDownForAllDataTypes(
+ inputRows: Seq[Row],
+ expectedMinWithAllTypes: Seq[Row],
+ expectedMinWithOutTSAndBinary: Seq[Row],
+ expectedMaxWithAllTypes: Seq[Row],
+ expectedMaxWithOutTSAndBinary: Seq[Row],
+ expectedCount: Seq[Row]): Unit = {
+ implicit class StringToDate(s: String) {
+ def date: Date = Date.valueOf(s)
+ }
+
+ implicit class StringToTs(s: String) {
+ def ts: Timestamp = Timestamp.valueOf(s)
+ }
+
+ val schema = StructType(List(StructField("StringCol", StringType, true),
+ StructField("BooleanCol", BooleanType, false),
+ StructField("ByteCol", ByteType, false),
+ StructField("BinaryCol", BinaryType, false),
+ StructField("ShortCol", ShortType, false),
+ StructField("IntegerCol", IntegerType, true),
+ StructField("LongCol", LongType, false),
+ StructField("FloatCol", FloatType, false),
+ StructField("DoubleCol", DoubleType, false),
+ StructField("DecimalCol", DecimalType(25, 5), true),
+ StructField("DateCol", DateType, false),
+ StructField("TimestampCol", TimestampType, false)).toArray)
+
+ val rdd = sparkContext.parallelize(inputRows)
+ withTempPath { file =>
+ spark.createDataFrame(rdd, schema).write.format(format).save(file.getCanonicalPath)
+ withTempView("test") {
+ spark.read.format(format).load(file.getCanonicalPath).createOrReplaceTempView("test")
+ Seq("false", "true").foreach { enableVectorizedReader =>
+ withSQLConf(aggPushDownEnabledKey -> "true",
+ vectorizedReaderEnabledKey -> enableVectorizedReader) {
+
+ val testMinWithAllTypes = sql("SELECT min(StringCol), min(BooleanCol), min(ByteCol), " +
+ "min(BinaryCol), min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+ "min(DoubleCol), min(DecimalCol), min(DateCol), min(TimestampCol) FROM test")
+
+ // The INT96 (Timestamp) sort order is undefined and Parquet does not return statistics
+ // for this type, so aggregates on it are not pushed down.
+ // In addition, Parquet Binary min/max statistics could be truncated, so we disable aggregate
+ // push down for Parquet Binary columns (which may back Spark StringType, BinaryType or
+ // DecimalType). Aggregates are not pushed down for ORC for the same reason.
+ testMinWithAllTypes.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(testMinWithAllTypes, expected_plan_fragment)
+ }
+
+ checkAnswer(testMinWithAllTypes, expectedMinWithAllTypes)
+
+ val testMinWithOutTSAndBinary = sql("SELECT min(BooleanCol), min(ByteCol), " +
+ "min(ShortCol), min(IntegerCol), min(LongCol), min(FloatCol), " +
+ "min(DoubleCol), min(DateCol) FROM test")
+
+ testMinWithOutTSAndBinary.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MIN(BooleanCol), " +
+ "MIN(ByteCol), " +
+ "MIN(ShortCol), " +
+ "MIN(IntegerCol), " +
+ "MIN(LongCol), " +
+ "MIN(FloatCol), " +
+ "MIN(DoubleCol), " +
+ "MIN(DateCol)]"
+ checkKeywordsExistsInExplain(testMinWithOutTSAndBinary, expected_plan_fragment)
+ }
+
+ checkAnswer(testMinWithOutTSAndBinary, expectedMinWithOutTSAndBinary)
+
+ val testMaxWithAllTypes = sql("SELECT max(StringCol), max(BooleanCol), " +
+ "max(ByteCol), max(BinaryCol), max(ShortCol), max(IntegerCol), max(LongCol), " +
+ "max(FloatCol), max(DoubleCol), max(DecimalCol), max(DateCol), max(TimestampCol) " +
+ "FROM test")
+
+ // The INT96 (Timestamp) sort order is undefined and Parquet does not return statistics
+ // for this type, so aggregates on it are not pushed down.
+ // In addition, Parquet Binary min/max statistics could be truncated, so we disable aggregate
+ // push down for Parquet Binary columns (which may back Spark StringType, BinaryType or
+ // DecimalType). Aggregates are not pushed down for ORC for the same reason.
+ testMaxWithAllTypes.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: []"
+ checkKeywordsExistsInExplain(testMaxWithAllTypes, expected_plan_fragment)
+ }
+
+ checkAnswer(testMaxWithAllTypes, expectedMaxWithAllTypes)
+
+ val testMaxWithoutTSAndBinary = sql("SELECT max(BooleanCol), max(ByteCol), " +
+ "max(ShortCol), max(IntegerCol), max(LongCol), max(FloatCol), " +
+ "max(DoubleCol), max(DateCol) FROM test")
+
+ testMaxWithoutTSAndBinary.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(BooleanCol), " +
+ "MAX(ByteCol), " +
+ "MAX(ShortCol), " +
+ "MAX(IntegerCol), " +
+ "MAX(LongCol), " +
+ "MAX(FloatCol), " +
+ "MAX(DoubleCol), " +
+ "MAX(DateCol)]"
+ checkKeywordsExistsInExplain(testMaxWithoutTSAndBinary, expected_plan_fragment)
+ }
+
+ checkAnswer(testMaxWithoutTSAndBinary, expectedMaxWithOutTSAndBinary)
+
+ val testCount = sql("SELECT count(StringCol), count(BooleanCol)," +
+ " count(ByteCol), count(BinaryCol), count(ShortCol), count(IntegerCol)," +
+ " count(LongCol), count(FloatCol), count(DoubleCol)," +
+ " count(DecimalCol), count(DateCol), count(TimestampCol) FROM test")
+
+ testCount.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [" +
+ "COUNT(StringCol), " +
+ "COUNT(BooleanCol), " +
+ "COUNT(ByteCol), " +
+ "COUNT(BinaryCol), " +
+ "COUNT(ShortCol), " +
+ "COUNT(IntegerCol), " +
+ "COUNT(LongCol), " +
+ "COUNT(FloatCol), " +
+ "COUNT(DoubleCol), " +
+ "COUNT(DecimalCol), " +
+ "COUNT(DateCol), " +
+ "COUNT(TimestampCol)]"
+ checkKeywordsExistsInExplain(testCount, expected_plan_fragment)
+ }
+
+ checkAnswer(testCount, expectedCount)
+ }
+ }
+ }
+ }
+ }
+
+ test("aggregate push down - different data types") {
+ implicit class StringToDate(s: String) {
+ def date: Date = Date.valueOf(s)
+ }
+
+ implicit class StringToTs(s: String) {
+ def ts: Timestamp = Timestamp.valueOf(s)
+ }
+
+ val rows =
+ Seq(
+ Row(
+ "a string",
+ true,
+ 10.toByte,
+ "Spark SQL".getBytes,
+ 12.toShort,
+ 3,
+ Long.MaxValue,
+ 0.15.toFloat,
+ 0.75D,
+ Decimal("12.345678"),
+ ("2021-01-01").date,
+ ("2015-01-01 23:50:59.123").ts),
+ Row(
+ "test string",
+ false,
+ 1.toByte,
+ "Parquet".getBytes,
+ 2.toShort,
+ null,
+ Long.MinValue,
+ 0.25.toFloat,
+ 0.85D,
+ Decimal("1.2345678"),
+ ("2015-01-01").date,
+ ("2021-01-01 23:50:59.123").ts),
+ Row(
+ null,
+ true,
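+ // Note: 10000.toByte wraps to 16 (only the low 8 bits are kept), which is why the
+ // expected max for ByteCol below is 16.toByte.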
+ 10000.toByte,
+ "Spark ML".getBytes,
+ 222.toShort,
+ 113,
+ 11111111L,
+ 0.25.toFloat,
+ 0.75D,
+ Decimal("12345.678"),
+ ("2004-06-19").date,
+ ("1999-08-26 10:43:59.123").ts)
+ )
+
+ testPushDownForAllDataTypes(
+ rows,
+ Seq(Row("a string", false, 1.toByte,
+ "Parquet".getBytes, 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D,
+ 1.23457, ("2004-06-19").date, ("1999-08-26 10:43:59.123").ts)),
+ Seq(Row(false, 1.toByte,
+ 2.toShort, 3, -9223372036854775808L, 0.15.toFloat, 0.75D, ("2004-06-19").date)),
+ Seq(Row("test string", true, 16.toByte,
+ "Spark SQL".getBytes, 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D,
+ 12345.678, ("2021-01-01").date, ("2021-01-01 23:50:59.123").ts)),
+ Seq(Row(true, 16.toByte,
+ 222.toShort, 113, 9223372036854775807L, 0.25.toFloat, 0.85D, ("2021-01-01").date)),
+ Seq(Row(2, 3, 3, 3, 3, 2, 3, 3, 3, 3, 3, 3))
+ )
+
+ // Test with zero rows (an empty file)
+ val nullRow = Row.fromSeq((1 to 12).map(_ => null))
+ val nullRowWithOutTSAndBinary = Row.fromSeq((1 to 8).map(_ => null))
+ val zeroCount = Row.fromSeq((1 to 12).map(_ => 0))
+ testPushDownForAllDataTypes(Seq.empty, Seq(nullRow), Seq(nullRowWithOutTSAndBinary),
+ Seq(nullRow), Seq(nullRowWithOutTSAndBinary), Seq(zeroCount))
+ }
+
+ test("column name case sensitivity") {
+ Seq("false", "true").foreach { enableVectorizedReader =>
+ withSQLConf(aggPushDownEnabledKey -> "true",
+ vectorizedReaderEnabledKey -> enableVectorizedReader) {
+ withTempPath { dir =>
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").format(format).save(dir.getCanonicalPath)
+ withTempView("tmp") {
+ spark.read.format(format).load(dir.getCanonicalPath).createOrReplaceTempView("tmp")
+ val selectAgg = sql("SELECT max(iD), min(Id) FROM tmp")
+ selectAgg.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregation: [MAX(id), MIN(id)]"
+ checkKeywordsExistsInExplain(selectAgg, expected_plan_fragment)
+ }
+ checkAnswer(selectAgg, Seq(Row(9, 0)))
+ }
+ }
+ }
+ }
+ }
+}
+
+abstract class ParquetAggregatePushDownSuite
+ extends FileSourceAggregatePushDownSuite with ParquetTest {
+
+ override def format: String = "parquet"
+ override protected val aggPushDownEnabledKey: String =
+ SQLConf.PARQUET_AGGREGATE_PUSHDOWN_ENABLED.key
+}
+
+class ParquetV1AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+ override protected def sparkConf: SparkConf =
+ super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet")
+}
+
+class ParquetV2AggregatePushDownSuite extends ParquetAggregatePushDownSuite {
+
+ override protected def sparkConf: SparkConf =
+ super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "")
+}
+
+abstract class OrcAggregatePushDownSuite extends OrcTest with FileSourceAggregatePushDownSuite {
+
+ override def format: String = "orc"
+ override protected val aggPushDownEnabledKey: String =
+ SQLConf.ORC_AGGREGATE_PUSHDOWN_ENABLED.key
+}
+
+class OrcV1AggregatePushDownSuite extends OrcAggregatePushDownSuite {
+
+ override protected def sparkConf: SparkConf =
+ super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "orc")
+}
+
+class OrcV2AggregatePushDownSuite extends OrcAggregatePushDownSuite {
+
+ override protected def sparkConf: SparkConf =
+ super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "")
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 58921485b207d..e71f3b8c35e25 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -2965,16 +2965,14 @@ class JsonV2Suite extends JsonSuite {
withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "true") {
withTempPath { file =>
val scanBuilder = getBuilder(file.getCanonicalPath)
- assert(scanBuilder.pushFilters(filters) === filters)
- assert(scanBuilder.pushedFilters() === filters)
+ assert(scanBuilder.pushDataFilters(filters) === filters)
}
}
withSQLConf(SQLConf.JSON_FILTER_PUSHDOWN_ENABLED.key -> "false") {
withTempPath { file =>
val scanBuilder = getBuilder(file.getCanonicalPath)
- assert(scanBuilder.pushFilters(filters) === filters)
- assert(scanBuilder.pushedFilters() === Array.empty[sources.Filter])
+ assert(scanBuilder.pushDataFilters(filters) === Array.empty[sources.Filter])
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
new file mode 100644
index 0000000000000..6296da47cca51
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
+import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.BooleanType
+
+class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
+ test("SPARK-36644: Push down boolean column filter") {
+ testTranslateFilter(Symbol("col").boolean,
+ Some(new Predicate("=", Array(FieldReference("col"), LiteralValue(true, BooleanType)))))
+ }
+
+ /**
+ * Translates the given Catalyst [[Expression]] into a data source V2 [[Predicate]]
+ * and verifies it against the expected [[Predicate]].
+ */
+ def testTranslateFilter(catalystFilter: Expression, result: Option[Predicate]): Unit = {
+ assertResult(result) {
+ DataSourceV2Strategy.translateFilterV2(catalystFilter, true)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
new file mode 100644
index 0000000000000..2d6e6fcf16174
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2PredicateSuite.scala
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, Literal, LiteralValue}
+import org.apache.spark.sql.connector.expressions.filter._
+import org.apache.spark.sql.execution.datasources.v2.V2PredicateSuite.ref
+import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class V2PredicateSuite extends SparkFunSuite {
+
+ test("nested columns") {
+ val predicate1 =
+ new Predicate("=", Array[Expression](ref("a", "B"), LiteralValue(1, IntegerType)))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a.B"))
+ assert(predicate1.describe.equals("a.B = 1"))
+
+ val predicate2 =
+ new Predicate("=", Array[Expression](ref("a", "b.c"), LiteralValue(1, IntegerType)))
+ assert(predicate2.references.map(_.describe()).toSeq == Seq("a.`b.c`"))
+ assert(predicate2.describe.equals("a.`b.c` = 1"))
+
+ val predicate3 =
+ new Predicate("=", Array[Expression](ref("`a`.b", "c"), LiteralValue(1, IntegerType)))
+ assert(predicate3.references.map(_.describe()).toSeq == Seq("```a``.b`.c"))
+ assert(predicate3.describe.equals("```a``.b`.c = 1"))
+ }
+
+ test("AlwaysTrue") {
+ val predicate1 = new AlwaysTrue
+ val predicate2 = new AlwaysTrue
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).length == 0)
+ assert(predicate1.describe.equals("TRUE"))
+ }
+
+ test("AlwaysFalse") {
+ val predicate1 = new AlwaysFalse
+ val predicate2 = new AlwaysFalse
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).length == 0)
+ assert(predicate1.describe.equals("FALSE"))
+ }
+
+ test("EqualTo") {
+ val predicate1 = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))
+ val predicate2 = new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("a = 1"))
+ }
+
+ test("EqualNullSafe") {
+ val predicate1 = new Predicate("<=>", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))
+ val predicate2 = new Predicate("<=>", Array[Expression](ref("a"), LiteralValue(1, IntegerType)))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("(a = 1) OR (a IS NULL AND 1 IS NULL)"))
+ }
+
+ test("In") {
+ val predicate1 = new Predicate("IN",
+ Array(ref("a"), LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
+ LiteralValue(3, IntegerType), LiteralValue(4, IntegerType)))
+ val predicate2 = new Predicate("IN",
+ Array(ref("a"), LiteralValue(4, IntegerType), LiteralValue(2, IntegerType),
+ LiteralValue(3, IntegerType), LiteralValue(1, IntegerType)))
+ assert(!predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("a IN (1, 2, 3, 4)"))
+ val values: Array[Literal[_]] = new Array[Literal[_]](1000)
+ var expected = "a IN ("
+ for (i <- 0 until 1000) {
+ values(i) = LiteralValue(i, IntegerType)
+ expected += i + ", "
+ }
+ val predicate3 = new Predicate("IN", (ref("a") +: values).toArray[Expression])
+ expected = expected.dropRight(2) // remove the last ", "
+ expected += ")"
+ assert(predicate3.describe.equals(expected))
+ }
+
+ test("IsNull") {
+ val predicate1 = new Predicate("IS_NULL", Array[Expression](ref("a")))
+ val predicate2 = new Predicate("IS_NULL", Array[Expression](ref("a")))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("a IS NULL"))
+ }
+
+ test("IsNotNull") {
+ val predicate1 = new Predicate("IS_NOT_NULL", Array[Expression](ref("a")))
+ val predicate2 = new Predicate("IS_NOT_NULL", Array[Expression](ref("a")))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("a IS NOT NULL"))
+ }
+
+ test("Not") {
+ val predicate1 = new Not(
+ new Predicate("<", Array[Expression](ref("a"), LiteralValue(1, IntegerType))))
+ val predicate2 = new Not(
+ new Predicate("<", Array[Expression](ref("a"), LiteralValue(1, IntegerType))))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("NOT (a < 1)"))
+ }
+
+ test("And") {
+ val predicate1 = new And(
+ new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
+ new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType))))
+ val predicate2 = new And(
+ new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
+ new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType))))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b"))
+ assert(predicate1.describe.equals("(a = 1) AND (b = 1)"))
+ }
+
+ test("Or") {
+ val predicate1 = new Or(
+ new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
+ new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType))))
+ val predicate2 = new Or(
+ new Predicate("=", Array[Expression](ref("a"), LiteralValue(1, IntegerType))),
+ new Predicate("=", Array[Expression](ref("b"), LiteralValue(1, IntegerType))))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a", "b"))
+ assert(predicate1.describe.equals("(a = 1) OR (b = 1)"))
+ }
+
+ test("StringStartsWith") {
+ val literal = LiteralValue(UTF8String.fromString("str"), StringType)
+ val predicate1 = new Predicate("STARTS_WITH",
+ Array[Expression](ref("a"), literal))
+ val predicate2 = new Predicate("STARTS_WITH",
+ Array[Expression](ref("a"), literal))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("a LIKE 'str%'"))
+ }
+
+ test("StringEndsWith") {
+ val literal = LiteralValue(UTF8String.fromString("str"), StringType)
+ val predicate1 = new Predicate("ENDS_WITH",
+ Array[Expression](ref("a"), literal))
+ val predicate2 = new Predicate("ENDS_WITH",
+ Array[Expression](ref("a"), literal))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("a LIKE '%str'"))
+ }
+
+ test("StringContains") {
+ val literal = LiteralValue(UTF8String.fromString("str"), StringType)
+ val predicate1 = new Predicate("CONTAINS",
+ Array[Expression](ref("a"), literal))
+ val predicate2 = new Predicate("CONTAINS",
+ Array[Expression](ref("a"), literal))
+ assert(predicate1.equals(predicate2))
+ assert(predicate1.references.map(_.describe()).toSeq == Seq("a"))
+ assert(predicate1.describe.equals("a LIKE '%str%'"))
+ }
+}
+
+object V2PredicateSuite {
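+ // Helper for building a (possibly nested) FieldReference from its name parts.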
+ private[sql] def ref(parts: String*): FieldReference = {
+ new FieldReference(parts)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
index 1a4f08418f8d3..1a52dc4da009f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2SessionCatalogSuite.scala
@@ -67,10 +67,10 @@ class V2SessionCatalogTableSuite extends V2SessionCatalogBaseSuite {
override protected def afterAll(): Unit = {
val catalog = newCatalog()
- catalog.dropNamespace(Array("db"))
- catalog.dropNamespace(Array("db2"))
- catalog.dropNamespace(Array("ns"))
- catalog.dropNamespace(Array("ns2"))
+ catalog.dropNamespace(Array("db"), cascade = true)
+ catalog.dropNamespace(Array("db2"), cascade = true)
+ catalog.dropNamespace(Array("ns"), cascade = true)
+ catalog.dropNamespace(Array("ns2"), cascade = true)
super.afterAll()
}
@@ -806,7 +806,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(catalog.listNamespaces(Array()) === Array(testNs, defaultNs))
assert(catalog.listNamespaces(testNs) === Array())
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("listNamespaces: fail if missing namespace") {
@@ -844,7 +844,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(catalog.namespaceExists(testNs) === true)
checkMetadata(metadata.asScala, Map("property" -> "value"))
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("loadNamespaceMetadata: empty metadata") {
@@ -859,7 +859,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(catalog.namespaceExists(testNs) === true)
checkMetadata(metadata.asScala, emptyProps.asScala)
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("createNamespace: basic behavior") {
@@ -879,7 +879,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
checkMetadata(metadata, Map("property" -> "value"))
assert(expectedPath === metadata("location"))
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("createNamespace: initialize location") {
@@ -895,7 +895,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
checkMetadata(metadata, Map.empty)
assert(expectedPath === metadata("location"))
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("createNamespace: relative location") {
@@ -912,7 +912,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
checkMetadata(metadata, Map.empty)
assert(expectedPath === metadata("location"))
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("createNamespace: fail if namespace already exists") {
@@ -928,7 +928,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(catalog.namespaceExists(testNs) === true)
checkMetadata(catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value"))
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("createNamespace: fail nested namespace") {
@@ -943,7 +943,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(exc.getMessage.contains("Invalid namespace name: db.nested"))
- catalog.dropNamespace(Array("db"))
+ catalog.dropNamespace(Array("db"), cascade = false)
}
test("createTable: fail if namespace does not exist") {
@@ -964,7 +964,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(catalog.namespaceExists(testNs) === false)
- val ret = catalog.dropNamespace(testNs)
+ val ret = catalog.dropNamespace(testNs, cascade = false)
assert(ret === false)
}
@@ -976,7 +976,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(catalog.namespaceExists(testNs) === true)
- val ret = catalog.dropNamespace(testNs)
+ val ret = catalog.dropNamespace(testNs, cascade = false)
assert(ret === true)
assert(catalog.namespaceExists(testNs) === false)
@@ -988,8 +988,8 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
catalog.createNamespace(testNs, Map("property" -> "value").asJava)
catalog.createTable(testIdent, schema, Array.empty, emptyProps)
- val exc = intercept[IllegalStateException] {
- catalog.dropNamespace(testNs)
+ val exc = intercept[AnalysisException] {
+ catalog.dropNamespace(testNs, cascade = false)
}
assert(exc.getMessage.contains(testNs.quoted))
@@ -997,7 +997,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
checkMetadata(catalog.loadNamespaceMetadata(testNs).asScala, Map("property" -> "value"))
catalog.dropTable(testIdent)
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("alterNamespace: basic behavior") {
@@ -1022,7 +1022,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
catalog.loadNamespaceMetadata(testNs).asScala,
Map("property" -> "value"))
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("alterNamespace: update namespace location") {
@@ -1045,7 +1045,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
catalog.alterNamespace(testNs, NamespaceChange.setProperty("location", "relativeP"))
assert(newRelativePath === spark.catalog.getDatabase(testNs(0)).locationUri)
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("alterNamespace: update namespace comment") {
@@ -1060,7 +1060,7 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(newComment === spark.catalog.getDatabase(testNs(0)).description)
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
test("alterNamespace: fail if namespace doesn't exist") {
@@ -1087,6 +1087,6 @@ class V2SessionCatalogNamespaceSuite extends V2SessionCatalogBaseSuite {
assert(exc.getMessage.contains(s"Cannot remove reserved property: $p"))
}
- catalog.dropNamespace(testNs)
+ catalog.dropNamespace(testNs, cascade = false)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index 8842db2a2aca4..8f690eeaff901 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -24,7 +24,6 @@ import java.util.{Calendar, GregorianCalendar, Properties, TimeZone}
import scala.collection.JavaConverters._
-import org.h2.jdbc.JdbcSQLException
import org.mockito.ArgumentMatchers._
import org.mockito.Mockito._
import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
@@ -38,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeTestUtils
import org.apache.spark.sql.execution.{DataSourceScanExec, ExtendedMode}
import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCommand}
import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRDD, JDBCRelation, JdbcUtils}
+import org.apache.spark.sql.execution.datasources.jdbc.{JDBCOptions, JDBCPartition, JDBCRelation, JdbcUtils}
import org.apache.spark.sql.execution.metric.InputOutputMetricsHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
@@ -54,7 +53,8 @@ class JDBCSuite extends QueryTest
val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
var conn: java.sql.Connection = null
- val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte)
+ val testBytes = Array[Byte](99.toByte, 134.toByte, 135.toByte, 200.toByte, 205.toByte) ++
+ Array.fill(15)(0.toByte)
val testH2Dialect = new JdbcDialect {
override def canHandle(url: String): Boolean = url.startsWith("jdbc:h2")
@@ -87,7 +87,6 @@ class JDBCSuite extends QueryTest
val properties = new Properties()
properties.setProperty("user", "testUser")
properties.setProperty("password", "testPass")
- properties.setProperty("rowId", "false")
conn = DriverManager.getConnection(url, properties)
conn.prepareStatement("create schema test").executeUpdate()
@@ -162,7 +161,7 @@ class JDBCSuite extends QueryTest
|OPTIONS (url '$url', dbtable 'TEST.STRTYPES', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
- conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP)"
+ conn.prepareStatement("create table test.timetypes (a TIME, b DATE, c TIMESTAMP(7))"
).executeUpdate()
conn.prepareStatement("insert into test.timetypes values ('12:34:56', "
+ "'1996-01-01', '2002-02-20 11:22:33.543543543')").executeUpdate()
@@ -177,12 +176,12 @@ class JDBCSuite extends QueryTest
""".stripMargin.replaceAll("\n", " "))
conn.prepareStatement("CREATE TABLE test.timezone (tz TIMESTAMP WITH TIME ZONE) " +
- "AS SELECT '1999-01-08 04:05:06.543543543 GMT-08:00'")
+ "AS SELECT '1999-01-08 04:05:06.543543543-08:00'")
.executeUpdate()
conn.commit()
- conn.prepareStatement("CREATE TABLE test.array (ar ARRAY) " +
- "AS SELECT '(1, 2, 3)'")
+ conn.prepareStatement("CREATE TABLE test.array_table (ar Integer ARRAY) " +
+ "AS SELECT ARRAY[1, 2, 3]")
.executeUpdate()
conn.commit()
@@ -638,7 +637,7 @@ class JDBCSuite extends QueryTest
assert(rows(0).getAs[Array[Byte]](0).sameElements(testBytes))
assert(rows(0).getString(1).equals("Sensitive"))
assert(rows(0).getString(2).equals("Insensitive"))
- assert(rows(0).getString(3).equals("Twenty-byte CHAR"))
+ assert(rows(0).getString(3).equals("Twenty-byte CHAR "))
assert(rows(0).getAs[Array[Byte]](4).sameElements(testBytes))
assert(rows(0).getString(5).equals("I am a clob!"))
}
@@ -729,20 +728,6 @@ class JDBCSuite extends QueryTest
assert(math.abs(rows(0).getDouble(1) - 1.00000023841859331) < 1e-12)
}
- test("Pass extra properties via OPTIONS") {
- // We set rowId to false during setup, which means that _ROWID_ column should be absent from
- // all tables. If rowId is true (default), the query below doesn't throw an exception.
- intercept[JdbcSQLException] {
- sql(
- s"""
- |CREATE OR REPLACE TEMPORARY VIEW abc
- |USING org.apache.spark.sql.jdbc
- |OPTIONS (url '$url', dbtable '(SELECT _ROWID_ FROM test.people)',
- | user 'testUser', password 'testPass')
- """.stripMargin.replaceAll("\n", " "))
- }
- }
-
test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties())
@@ -788,33 +773,36 @@ class JDBCSuite extends QueryTest
}
test("compile filters") {
- val compileFilter = PrivateMethod[Option[String]](Symbol("compileFilter"))
def doCompileFilter(f: Filter): String =
- JDBCRDD invokePrivate compileFilter(f, JdbcDialects.get("jdbc:")) getOrElse("")
- assert(doCompileFilter(EqualTo("col0", 3)) === """"col0" = 3""")
- assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === """(NOT ("col1" = 'abc'))""")
- assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def")))
- === """("col0" = 0) AND ("col1" = 'def')""")
- assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi")))
- === """("col0" = 2) OR ("col1" = 'ghi')""")
- assert(doCompileFilter(LessThan("col0", 5)) === """"col0" < 5""")
- assert(doCompileFilter(LessThan("col3",
- Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col3" < '1995-11-21 00:00:00.0'""")
- assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04")))
- === """"col4" < '1983-08-04'""")
- assert(doCompileFilter(LessThanOrEqual("col0", 5)) === """"col0" <= 5""")
- assert(doCompileFilter(GreaterThan("col0", 3)) === """"col0" > 3""")
- assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === """"col0" >= 3""")
- assert(doCompileFilter(In("col1", Array("jkl"))) === """"col1" IN ('jkl')""")
- assert(doCompileFilter(In("col1", Array.empty)) ===
- """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""")
- assert(doCompileFilter(Not(In("col1", Array("mno", "pqr"))))
- === """(NOT ("col1" IN ('mno', 'pqr')))""")
- assert(doCompileFilter(IsNull("col1")) === """"col1" IS NULL""")
- assert(doCompileFilter(IsNotNull("col1")) === """"col1" IS NOT NULL""")
- assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def")))
- === """((NOT ("col0" != 'abc' OR "col0" IS NULL OR 'abc' IS NULL) """
- + """OR ("col0" IS NULL AND 'abc' IS NULL))) AND ("col1" = 'def')""")
+ JdbcDialects.get("jdbc:").compileExpression(f.toV2).getOrElse("")
+
+ Seq(("col0", "col1"), ("`col0`", "`col1`")).foreach { case(col0, col1) =>
+ assert(doCompileFilter(EqualTo(col0, 3)) === """"col0" = 3""")
+ assert(doCompileFilter(Not(EqualTo(col1, "abc"))) === """NOT ("col1" = 'abc')""")
+ assert(doCompileFilter(And(EqualTo(col0, 0), EqualTo(col1, "def")))
+ === """("col0" = 0) AND ("col1" = 'def')""")
+ assert(doCompileFilter(Or(EqualTo(col0, 2), EqualTo(col1, "ghi")))
+ === """("col0" = 2) OR ("col1" = 'ghi')""")
+ assert(doCompileFilter(LessThan(col0, 5)) === """"col0" < 5""")
+ assert(doCompileFilter(LessThan(col0,
+ Timestamp.valueOf("1995-11-21 00:00:00.0"))) === """"col0" < '1995-11-21 00:00:00.0'""")
+ assert(doCompileFilter(LessThan(col0, Date.valueOf("1983-08-04")))
+ === """"col0" < '1983-08-04'""")
+ assert(doCompileFilter(LessThanOrEqual(col0, 5)) === """"col0" <= 5""")
+ assert(doCompileFilter(GreaterThan(col0, 3)) === """"col0" > 3""")
+ assert(doCompileFilter(GreaterThanOrEqual(col0, 3)) === """"col0" >= 3""")
+ assert(doCompileFilter(In(col1, Array("jkl"))) === """"col1" IN ('jkl')""")
+ assert(doCompileFilter(In(col1, Array.empty)) ===
+ """CASE WHEN "col1" IS NULL THEN NULL ELSE FALSE END""")
+ assert(doCompileFilter(Not(In(col1, Array("mno", "pqr"))))
+ === """NOT ("col1" IN ('mno', 'pqr'))""")
+ assert(doCompileFilter(IsNull(col1)) === """"col1" IS NULL""")
+ assert(doCompileFilter(IsNotNull(col1)) === """"col1" IS NOT NULL""")
+ assert(doCompileFilter(And(EqualNullSafe(col0, "abc"), EqualTo(col1, "def")))
+ === """(("col0" = 'abc') OR ("col0" IS NULL AND 'abc' IS NULL))"""
+ + """ AND ("col1" = 'def')""")
+ }
+ assert(doCompileFilter(EqualTo("col0.nested", 3)).isEmpty)
}
test("Dialect unregister") {
@@ -1375,7 +1363,7 @@ class JDBCSuite extends QueryTest
}.getMessage
assert(e.contains("Unsupported type TIMESTAMP_WITH_TIMEZONE"))
e = intercept[SQLException] {
- spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY", new Properties()).collect()
+ spark.read.jdbc(urlWithUserAndPass, "TEST.ARRAY_TABLE", new Properties()).collect()
}.getMessage
assert(e.contains("Unsupported type ARRAY"))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
index 526dad91e5e19..94f044a0a6755 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCV2Suite.scala
@@ -20,13 +20,14 @@ package org.apache.spark.sql.jdbc
import java.sql.{Connection, DriverManager}
import java.util.Properties
-import org.apache.spark.SparkConf
-import org.apache.spark.sql.{ExplainSuiteHelper, QueryTest, Row}
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.sql.{DataFrame, ExplainSuiteHelper, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.CannotReplaceMissingTableException
-import org.apache.spark.sql.catalyst.plans.logical.Filter
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Sort}
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2ScanRelation, V1ScanWrapper}
import org.apache.spark.sql.execution.datasources.v2.jdbc.JDBCTableCatalog
-import org.apache.spark.sql.functions.{lit, sum, udf}
+import org.apache.spark.sql.functions.{abs, avg, ceil, coalesce, count, count_distinct, exp, floor, lit, log => ln, not, pow, sqrt, sum, udf, when}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.util.Utils
@@ -42,6 +43,7 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.set("spark.sql.catalog.h2.url", url)
.set("spark.sql.catalog.h2.driver", "org.h2.Driver")
.set("spark.sql.catalog.h2.pushDownAggregate", "true")
+ .set("spark.sql.catalog.h2.pushDownLimit", "true")
private def withConnection[T](f: Connection => T): T = {
val conn = DriverManager.getConnection(url, new Properties())
@@ -67,17 +69,40 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
conn.prepareStatement("INSERT INTO \"test\".\"people\" VALUES ('mary', 2)").executeUpdate()
conn.prepareStatement(
"CREATE TABLE \"test\".\"employee\" (dept INTEGER, name TEXT(32), salary NUMERIC(20, 2)," +
- " bonus DOUBLE)").executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300)")
- .executeUpdate()
- conn.prepareStatement("INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200)")
+ " bonus DOUBLE, is_manager BOOLEAN)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (1, 'amy', 10000, 1000, true)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (2, 'alex', 12000, 1200, false)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (1, 'cathy', 9000, 1200, false)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (2, 'david', 10000, 1300, true)").executeUpdate()
+ conn.prepareStatement(
+ "INSERT INTO \"test\".\"employee\" VALUES (6, 'jen', 12000, 1200, true)").executeUpdate()
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"dept\" (\"dept id\" INTEGER NOT NULL)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (1)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"dept\" VALUES (2)").executeUpdate()
+
+ // scalastyle:off
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"person\" (\"名\" INTEGER NOT NULL)").executeUpdate()
+ // scalastyle:on
+ conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (1)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"person\" VALUES (2)").executeUpdate()
+ conn.prepareStatement(
+ """CREATE TABLE "test"."view1" ("|col1" INTEGER, "|col2" INTEGER)""").executeUpdate()
+ conn.prepareStatement(
+ """CREATE TABLE "test"."view2" ("|col1" INTEGER, "|col3" INTEGER)""").executeUpdate()
+
+ conn.prepareStatement(
+ "CREATE TABLE \"test\".\"item\" (id INTEGER, name TEXT(32), price NUMERIC(23, 3))")
.executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+ "(1, 'bottle', 11111111111111111111.123)").executeUpdate()
+ conn.prepareStatement("INSERT INTO \"test\".\"item\" VALUES " +
+ "(1, 'bottle', 99999999999999999999.123)").executeUpdate()
}
}
@@ -92,42 +117,369 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
checkAnswer(sql("SELECT name, id FROM h2.test.people"), Seq(Row("fred", 1), Row("mary", 2)))
}
- test("scan with filter push-down") {
- val df = spark.table("h2.test.people").filter($"id" > 1)
- val filters = df.queryExecution.optimizedPlan.collect {
- case f: Filter => f
- }
- assert(filters.isEmpty)
-
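+ // Checks that the DSv2 scan's EXPLAIN output contains the expected push-down information.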
+ private def checkPushedInfo(df: DataFrame, expectedPlanFragment: String): Unit = {
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedFilters: [IsNotNull(ID), GreaterThan(ID,1)]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ checkKeywordsExistsInExplain(df, expectedPlanFragment)
}
+ }
- checkAnswer(df, Row("mary", 2))
+ // TABLESAMPLE ({integer_expression | decimal_expression} PERCENT) and
+ // TABLESAMPLE (BUCKET integer_expression OUT OF integer_expression)
+ // are tested in the JDBC dialect tests because TABLESAMPLE is not supported by every DBMS
+ test("TABLESAMPLE (integer_expression ROWS) is the same as LIMIT") {
+ val df = sql("SELECT NAME FROM h2.test.employee TABLESAMPLE (3 ROWS)")
+ checkSchemaNames(df, Seq("NAME"))
+ checkPushedInfo(df, "PushedFilters: [], PushedLimit: LIMIT 3, ")
+ checkAnswer(df, Seq(Row("amy"), Row("alex"), Row("cathy")))
}
- test("scan with column pruning") {
- val df = spark.table("h2.test.people").select("id")
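+ // Checks that the DSv2 scan's output schema has exactly the given column names.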
+ private def checkSchemaNames(df: DataFrame, names: Seq[String]): Unit = {
val scan = df.queryExecution.optimizedPlan.collectFirst {
case s: DataSourceV2ScanRelation => s
}.get
- assert(scan.schema.names.sameElements(Seq("ID")))
+ assert(scan.schema.names.sameElements(names))
+ }
+
+ test("simple scan with LIMIT") {
+ val df1 = spark.read.table("h2.test.employee")
+ .where($"dept" === 1).limit(1)
+ checkPushedInfo(df1,
+ "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], PushedLimit: LIMIT 1, ")
+ checkAnswer(df1, Seq(Row(1, "amy", 10000.00, 1000.0, true)))
+
+ val df2 = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .filter($"dept" > 1)
+ .limit(1)
+ checkPushedInfo(df2,
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ")
+ checkAnswer(df2, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
+
+ val df3 = sql("SELECT name FROM h2.test.employee WHERE dept > 1 LIMIT 1")
+ checkSchemaNames(df3, Seq("NAME"))
+ checkPushedInfo(df3,
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], PushedLimit: LIMIT 1, ")
+ checkAnswer(df3, Seq(Row("alex")))
+
+ val df4 = spark.read
+ .table("h2.test.employee")
+ .groupBy("DEPT").sum("SALARY")
+ .limit(1)
+ checkPushedInfo(df4,
+ "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT], ")
+ checkAnswer(df4, Seq(Row(1, 19000.00)))
+
+ val name = udf { (x: String) => x.matches("cat|dav|amy") }
+ val sub = udf { (x: String) => x.substring(0, 3) }
+ val df5 = spark.read
+ .table("h2.test.employee")
+ .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+ .filter(name($"shortName"))
+ .limit(1)
+ // LIMIT is pushed down only if all the filters are pushed down
+ checkPushedInfo(df5, "PushedFilters: [], ")
+ checkAnswer(df5, Seq(Row(10000.00, 1000.0, "amy")))
+ }
+
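+ // Checks whether the logical Sort node has been removed from the optimized plan,
+ // i.e. whether the top-N (ORDER BY ... LIMIT) was pushed down to the data source.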
+ private def checkSortRemoved(df: DataFrame, removed: Boolean = true): Unit = {
+ val sorts = df.queryExecution.optimizedPlan.collect {
+ case s: Sort => s
+ }
+ if (removed) {
+ assert(sorts.isEmpty)
+ } else {
+ assert(sorts.nonEmpty)
+ }
+ }
+
+ test("simple scan with top N") {
+ val df1 = spark.read
+ .table("h2.test.employee")
+ .sort("salary")
+ .limit(1)
+ checkSortRemoved(df1)
+ checkPushedInfo(df1,
+ "PushedFilters: [], PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
+ checkAnswer(df1, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
+
+ val df2 = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "1")
+ .table("h2.test.employee")
+ .where($"dept" === 1)
+ .orderBy($"salary")
+ .limit(1)
+ checkSortRemoved(df2)
+ checkPushedInfo(df2, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], " +
+ "PushedTopN: ORDER BY [salary ASC NULLS FIRST] LIMIT 1, ")
+ checkAnswer(df2, Seq(Row(1, "cathy", 9000.00, 1200.0, false)))
+
+ val df3 = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .filter($"dept" > 1)
+ .orderBy($"salary".desc)
+ .limit(1)
+ checkSortRemoved(df3, false)
+ checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
+ "PushedTopN: ORDER BY [salary DESC NULLS LAST] LIMIT 1, ")
+ checkAnswer(df3, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
+
+ val df4 =
+ sql("SELECT name FROM h2.test.employee WHERE dept > 1 ORDER BY salary NULLS LAST LIMIT 1")
+ checkSchemaNames(df4, Seq("NAME"))
+ checkSortRemoved(df4)
+ checkPushedInfo(df4, "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
+ "PushedTopN: ORDER BY [salary ASC NULLS LAST] LIMIT 1, ")
+ checkAnswer(df4, Seq(Row("david")))
+
+ val df5 = spark.read.table("h2.test.employee")
+ .where($"dept" === 1).orderBy($"salary")
+ checkSortRemoved(df5, false)
+ checkPushedInfo(df5, "PushedFilters: [DEPT IS NOT NULL, DEPT = 1], ")
+ checkAnswer(df5,
+ Seq(Row(1, "cathy", 9000.00, 1200.0, false), Row(1, "amy", 10000.00, 1000.0, true)))
+
+ val df6 = spark.read
+ .table("h2.test.employee")
+ .groupBy("DEPT").sum("SALARY")
+ .orderBy("DEPT")
+ .limit(1)
+ checkSortRemoved(df6, false)
+ checkPushedInfo(df6, "PushedAggregates: [SUM(SALARY)]," +
+ " PushedFilters: [], PushedGroupByColumns: [DEPT], ")
+ checkAnswer(df6, Seq(Row(1, 19000.00)))
+
+ val name = udf { (x: String) => x.matches("cat|dav|amy") }
+ val sub = udf { (x: String) => x.substring(0, 3) }
+ val df7 = spark.read
+ .table("h2.test.employee")
+ .select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+ .filter(name($"shortName"))
+ .sort($"SALARY".desc)
+ .limit(1)
+ // LIMIT is pushed down only if all the filters are pushed down
+ checkSortRemoved(df7, false)
+ checkPushedInfo(df7, "PushedFilters: [], ")
+ checkAnswer(df7, Seq(Row(10000.00, 1000.0, "amy")))
+
+ val df8 = spark.read
+ .table("h2.test.employee")
+ .sort(sub($"NAME"))
+ .limit(1)
+ checkSortRemoved(df8, false)
+ checkPushedInfo(df8, "PushedFilters: [], ")
+ checkAnswer(df8, Seq(Row(2, "alex", 12000.00, 1200.0, false)))
+ }
+
+ test("simple scan with top N: order by with alias") {
+ val df1 = spark.read
+ .table("h2.test.employee")
+ .select($"NAME", $"SALARY".as("mySalary"))
+ .sort("mySalary")
+ .limit(1)
+ checkSortRemoved(df1)
+ checkPushedInfo(df1,
+ "PushedFilters: [], PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ")
+ checkAnswer(df1, Seq(Row("cathy", 9000.00)))
+
+ val df2 = spark.read
+ .table("h2.test.employee")
+ .select($"DEPT", $"NAME", $"SALARY".as("mySalary"))
+ .filter($"DEPT" > 1)
+ .sort("mySalary")
+ .limit(1)
+ checkSortRemoved(df2)
+ checkPushedInfo(df2,
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 1], " +
+ "PushedTopN: ORDER BY [SALARY ASC NULLS FIRST] LIMIT 1, ")
+ checkAnswer(df2, Seq(Row(2, "david", 10000.00)))
+ }
+
+ test("scan with filter push-down") {
+ val df = spark.table("h2.test.people").filter($"id" > 1)
+ checkFiltersRemoved(df)
+ checkPushedInfo(df, "PushedFilters: [ID IS NOT NULL, ID > 1], ")
+ checkAnswer(df, Row("mary", 2))
+
+ val df2 = spark.table("h2.test.employee").filter($"name".isin("amy", "cathy"))
+ checkFiltersRemoved(df2)
+ checkPushedInfo(df2, "PushedFilters: [NAME IN ('amy', 'cathy')]")
+ checkAnswer(df2, Seq(Row(1, "amy", 10000, 1000, true), Row(1, "cathy", 9000, 1200, false)))
+
+ val df3 = spark.table("h2.test.employee").filter($"name".startsWith("a"))
+ checkFiltersRemoved(df3)
+ checkPushedInfo(df3, "PushedFilters: [NAME IS NOT NULL, NAME LIKE 'a%']")
+ checkAnswer(df3, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false)))
+
+ val df4 = spark.table("h2.test.employee").filter($"is_manager")
+ checkFiltersRemoved(df4)
+ checkPushedInfo(df4, "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true]")
+ checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "david", 10000, 1300, true),
+ Row(6, "jen", 12000, 1200, true)))
+
+ val df5 = spark.table("h2.test.employee").filter($"is_manager".and($"salary" > 10000))
+ checkFiltersRemoved(df5)
+ checkPushedInfo(df5, "PushedFilters: [IS_MANAGER IS NOT NULL, SALARY IS NOT NULL, " +
+ "IS_MANAGER = true, SALARY > 10000.00]")
+ checkAnswer(df5, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ val df6 = spark.table("h2.test.employee").filter($"is_manager".or($"salary" > 10000))
+ checkFiltersRemoved(df6)
+ checkPushedInfo(df6, "PushedFilters: [(IS_MANAGER = true) OR (SALARY > 10000.00)], ")
+ checkAnswer(df6, Seq(Row(1, "amy", 10000, 1000, true), Row(2, "alex", 12000, 1200, false),
+ Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true)))
+
+ val df7 = spark.table("h2.test.employee").filter(not($"is_manager") === true)
+ checkFiltersRemoved(df7)
+ checkPushedInfo(df7, "PushedFilters: [IS_MANAGER IS NOT NULL, NOT (IS_MANAGER = true)], ")
+ checkAnswer(df7, Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "alex", 12000, 1200, false)))
+
+ val df8 = spark.table("h2.test.employee").filter($"is_manager" === true)
+ checkFiltersRemoved(df8)
+ checkPushedInfo(df8, "PushedFilters: [IS_MANAGER IS NOT NULL, IS_MANAGER = true], ")
+ checkAnswer(df8, Seq(Row(1, "amy", 10000, 1000, true),
+ Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true)))
+
+ val df9 = spark.table("h2.test.employee")
+ .filter(when($"dept" > 1, true).when($"is_manager", false).otherwise($"dept" > 3))
+ checkFiltersRemoved(df9)
+ checkPushedInfo(df9, "PushedFilters: [CASE WHEN DEPT > 1 THEN TRUE " +
+ "WHEN IS_MANAGER = true THEN FALSE ELSE DEPT > 3 END], ")
+ checkAnswer(df9, Seq(Row(2, "alex", 12000, 1200, false),
+ Row(2, "david", 10000, 1300, true), Row(6, "jen", 12000, 1200, true)))
+
+ val df10 = spark.table("h2.test.people")
+ .select($"NAME".as("myName"), $"ID".as("myID"))
+ .filter($"myID" > 1)
+ checkFiltersRemoved(df10)
+ checkPushedInfo(df10, "PushedFilters: [ID IS NOT NULL, ID > 1], ")
+ checkAnswer(df10, Row("mary", 2))
+ }
+
+ test("scan with filter push-down with ansi mode") {
+ Seq(false, true).foreach { ansiMode =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
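+ // Predicates involving arithmetic or casts are pushed down only when ANSI mode is enabled;
+ // the expected plan fragments below differ accordingly.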
+ val df = spark.table("h2.test.people").filter($"id" + 1 > 1)
+ checkFiltersRemoved(df, ansiMode)
+ val expectedPlanFragment = if (ansiMode) {
+ "PushedFilters: [ID IS NOT NULL, (ID + 1) > 1]"
+ } else {
+ "PushedFilters: [ID IS NOT NULL]"
+ }
+ checkPushedInfo(df, expectedPlanFragment)
+ checkAnswer(df, Seq(Row("fred", 1), Row("mary", 2)))
+
+ val df2 = spark.table("h2.test.people").filter($"id" + Int.MaxValue > 1)
+ checkFiltersRemoved(df2, ansiMode)
+ val expectedPlanFragment2 = if (ansiMode) {
+ "PushedFilters: [ID IS NOT NULL, (ID + 2147483647) > 1], "
+ } else {
+ "PushedFilters: [ID IS NOT NULL], "
+ }
+ checkPushedInfo(df2, expectedPlanFragment2)
+ if (ansiMode) {
+ val e = intercept[SparkException] {
+ checkAnswer(df2, Seq.empty)
+ }
+ assert(e.getMessage.contains(
+ "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\""))
+ } else {
+ checkAnswer(df2, Seq.empty)
+ }
+
+ val df3 = sql("""
+ |SELECT * FROM h2.test.employee
+ |WHERE (CASE WHEN SALARY > 10000 THEN BONUS ELSE BONUS + 200 END) > 1200
+ |""".stripMargin)
+
+ checkFiltersRemoved(df3, ansiMode)
+ val expectedPlanFragment3 = if (ansiMode) {
+ "PushedFilters: [(CASE WHEN SALARY > 10000.00 THEN BONUS" +
+ " ELSE BONUS + 200.0 END) > 1200.0]"
+ } else {
+ "PushedFilters: []"
+ }
+ checkPushedInfo(df3, expectedPlanFragment3)
+ checkAnswer(df3,
+ Seq(Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
+
+ val df4 = spark.table("h2.test.employee")
+ .filter(($"salary" > 1000d).and($"salary" < 12000d))
+ checkFiltersRemoved(df4, ansiMode)
+ val expectedPlanFragment4 = if (ansiMode) {
+ "PushedFilters: [SALARY IS NOT NULL, " +
+ "CAST(SALARY AS double) > 1000.0, CAST(SALARY AS double) < 12000.0], "
+ } else {
+ "PushedFilters: [SALARY IS NOT NULL], "
+ }
+ checkPushedInfo(df4, expectedPlanFragment4)
+ checkAnswer(df4, Seq(Row(1, "amy", 10000, 1000, true),
+ Row(1, "cathy", 9000, 1200, false), Row(2, "david", 10000, 1300, true)))
+
+ val df5 = spark.table("h2.test.employee")
+ .filter(abs($"dept" - 3) > 1)
+ .filter(coalesce($"salary", $"bonus") > 2000)
+ checkFiltersRemoved(df5, ansiMode)
+ val expectedPlanFragment5 = if (ansiMode) {
+ "PushedFilters: [DEPT IS NOT NULL, ABS(DEPT - 3) > 1, " +
+ "(COALESCE(CAST(SALARY AS double), BONUS)) > 2000.0]"
+ } else {
+ "PushedFilters: [DEPT IS NOT NULL]"
+ }
+ checkPushedInfo(df5, expectedPlanFragment5)
+ checkAnswer(df5, Seq(Row(1, "amy", 10000, 1000, true),
+ Row(1, "cathy", 9000, 1200, false), Row(6, "jen", 12000, 1200, true)))
+
+ val df6 = spark.table("h2.test.employee")
+ .filter(ln($"dept") > 1)
+ .filter(exp($"salary") > 2000)
+ .filter(pow($"dept", 2) > 4)
+ .filter(sqrt($"salary") > 100)
+ .filter(floor($"dept") > 1)
+ .filter(ceil($"dept") > 1)
+ checkFiltersRemoved(df6, ansiMode)
+ val expectedPlanFragment6 = if (ansiMode) {
+ "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL, " +
+ "LN(CAST(DEPT AS double)) > 1.0, EXP(CAST(SALARY AS double)...,"
+ } else {
+ "PushedFilters: [DEPT IS NOT NULL, SALARY IS NOT NULL]"
+ }
+ checkPushedInfo(df6, expectedPlanFragment6)
+ checkAnswer(df6, Seq(Row(6, "jen", 12000, 1200, true)))
+
+ // H2 does not support width_bucket
+ val df7 = sql("""
+ |SELECT * FROM h2.test.employee
+ |WHERE width_bucket(dept, 1, 6, 3) > 1
+ |""".stripMargin)
+ checkFiltersRemoved(df7, false)
+ checkPushedInfo(df7, "PushedFilters: [DEPT IS NOT NULL]")
+ checkAnswer(df7, Seq(Row(6, "jen", 12000, 1200, true)))
+ }
+ }
+ }
+
+ test("scan with column pruning") {
+ val df = spark.table("h2.test.people").select("id")
+ checkSchemaNames(df, Seq("ID"))
checkAnswer(df, Seq(Row(1), Row(2)))
}
test("scan with filter push-down and column pruning") {
val df = spark.table("h2.test.people").filter($"id" > 1).select("name")
- val filters = df.queryExecution.optimizedPlan.collect {
- case f: Filter => f
- }
- assert(filters.isEmpty)
- val scan = df.queryExecution.optimizedPlan.collectFirst {
- case s: DataSourceV2ScanRelation => s
- }.get
- assert(scan.schema.names.sameElements(Seq("NAME")))
+ checkFiltersRemoved(df)
+ checkSchemaNames(df, Seq("NAME"))
checkAnswer(df, Row("mary"))
}
@@ -168,7 +520,8 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
test("show tables") {
checkAnswer(sql("SHOW TABLES IN h2.test"),
Seq(Row("test", "people", false), Row("test", "empty_table", false),
- Row("test", "employee", false)))
+ Row("test", "employee", false), Row("test", "item", false), Row("test", "dept", false),
+ Row("test", "person", false), Row("test", "view1", false), Row("test", "view2", false)))
}
test("SQL API: create table as select") {
@@ -238,167 +591,195 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
}
}
- test("scan with aggregate push-down: MAX MIN with filter and group by") {
- val df = sql("select MAX(SaLaRY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
- " group by DePt")
- val filters = df.queryExecution.optimizedPlan.collect {
- case f: Filter => f
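+ // Asserts whether the Aggregate node has been removed from the optimized plan,
+ // i.e. whether the aggregate was completely pushed down to the JDBC source.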
+ private def checkAggregateRemoved(df: DataFrame, removed: Boolean = true): Unit = {
+ val aggregates = df.queryExecution.optimizedPlan.collect {
+ case agg: Aggregate => agg
}
- assert(filters.isEmpty)
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
- "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
- "PushedGroupby: [DEPT]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ if (removed) {
+ assert(aggregates.isEmpty)
+ } else {
+ assert(aggregates.nonEmpty)
}
- checkAnswer(df, Seq(Row(10000, 1000), Row(12000, 1200), Row(12000, 1200)))
}
- test("scan with aggregate push-down: MAX MIN with filter without group by") {
- val df = sql("select MAX(ID), MIN(ID) FROM h2.test.people where id > 0")
+ test("scan with aggregate push-down: MAX AVG with filter and group by") {
+ val df = sql("select MAX(SaLaRY), AVG(BONUS) FROM h2.test.employee where dept > 0" +
+ " group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), AVG(BONUS)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], " +
+ "PushedGroupByColumns: [DEPT], ")
+ checkAnswer(df, Seq(Row(10000, 1100.0), Row(12000, 1250.0), Row(12000, 1200.0)))
+ }
+
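+ // Asserts whether the Filter node has been removed from the optimized plan,
+ // i.e. whether all predicates were pushed down to the JDBC source.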
+ private def checkFiltersRemoved(df: DataFrame, removed: Boolean = true): Unit = {
val filters = df.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
- assert(filters.isEmpty)
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [MAX(ID), MIN(ID)], " +
- "PushedFilters: [IsNotNull(ID), GreaterThan(ID,0)], " +
- "PushedGroupby: []"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ if (removed) {
+ assert(filters.isEmpty)
+ } else {
+ assert(filters.nonEmpty)
+ }
+ }
+
+ test("scan with aggregate push-down: MAX AVG with filter without group by") {
+ val df = sql("select MAX(ID), AVG(ID) FROM h2.test.people where id > 0")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [MAX(ID), AVG(ID)], " +
+ "PushedFilters: [ID IS NOT NULL, ID > 0], " +
+ "PushedGroupByColumns: [], ")
+ checkAnswer(df, Seq(Row(2, 1.5)))
+ }
+
+ test("partitioned scan with aggregate push-down: complete push-down only") {
+ withTempView("v") {
+ spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .createTempView("v")
+ val df = sql("select AVG(SALARY) FROM v GROUP BY name")
+ // A partitioned JDBC scan doesn't support complete aggregate push-down, and AVG requires
+ // complete push-down, so the aggregate is not pushed down here.
+ checkAggregateRemoved(df, removed = false)
+ checkAnswer(df, Seq(Row(9000.0), Row(10000.0), Row(10000.0), Row(12000.0), Row(12000.0)))
}
- checkAnswer(df, Seq(Row(2, 1)))
}
test("scan with aggregate push-down: aggregate + number") {
val df = sql("select MAX(SALARY) + 1 FROM h2.test.employee")
+ checkAggregateRemoved(df)
df.queryExecution.optimizedPlan.collect {
case _: DataSourceV2ScanRelation =>
val expected_plan_fragment =
"PushedAggregates: [MAX(SALARY)]"
checkKeywordsExistsInExplain(df, expected_plan_fragment)
}
+ checkPushedInfo(df, "PushedAggregates: [MAX(SALARY)]")
checkAnswer(df, Seq(Row(12001)))
}
test("scan with aggregate push-down: COUNT(*)") {
val df = sql("select COUNT(*) FROM h2.test.employee")
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [COUNT(*)]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [COUNT(*)]")
checkAnswer(df, Seq(Row(5)))
}
test("scan with aggregate push-down: COUNT(col)") {
val df = sql("select COUNT(DEPT) FROM h2.test.employee")
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [COUNT(DEPT)]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [COUNT(DEPT)]")
checkAnswer(df, Seq(Row(5)))
}
test("scan with aggregate push-down: COUNT(DISTINCT col)") {
val df = sql("select COUNT(DISTINCT DEPT) FROM h2.test.employee")
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [COUNT(DISTINCT DEPT)]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [COUNT(DISTINCT DEPT)]")
+ checkAnswer(df, Seq(Row(3)))
+ }
+
+ test("scan with aggregate push-down: cannot partial push down COUNT(DISTINCT col)") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .agg(count_distinct($"DEPT"))
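+ // With two read partitions only partial push-down is possible, and COUNT(DISTINCT)
+ // cannot be partially pushed down, so the aggregate stays in Spark.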
+ checkAggregateRemoved(df, false)
checkAnswer(df, Seq(Row(3)))
}
test("scan with aggregate push-down: SUM without filer and group by") {
val df = sql("SELECT SUM(SALARY) FROM h2.test.employee")
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [SUM(SALARY)]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)]")
checkAnswer(df, Seq(Row(53000)))
}
test("scan with aggregate push-down: DISTINCT SUM without filer and group by") {
val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee")
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [SUM(DISTINCT SALARY)]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)]")
checkAnswer(df, Seq(Row(31000)))
}
test("scan with aggregate push-down: SUM with group by") {
val df = sql("SELECT SUM(SALARY) FROM h2.test.employee GROUP BY DEPT")
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [SUM(SALARY)], " +
- "PushedFilters: [], " +
- "PushedGroupby: [DEPT]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [SUM(SALARY)], " +
+ "PushedFilters: [], PushedGroupByColumns: [DEPT], ")
checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
}
test("scan with aggregate push-down: DISTINCT SUM with group by") {
val df = sql("SELECT SUM(DISTINCT SALARY) FROM h2.test.employee GROUP BY DEPT")
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [SUM(DISTINCT SALARY)], " +
- "PushedFilters: [], " +
- "PushedGroupby: [DEPT]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [SUM(DISTINCT SALARY)], " +
+ "PushedFilters: [], PushedGroupByColumns: [DEPT]")
checkAnswer(df, Seq(Row(19000), Row(22000), Row(12000)))
}
test("scan with aggregate push-down: with multiple group by columns") {
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
" group by DEPT, NAME")
- val filters11 = df.queryExecution.optimizedPlan.collect {
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]")
+ checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300),
+ Row(10000, 1000), Row(12000, 1200)))
+ }
+
+ test("scan with aggregate push-down: with concat multiple group key in project") {
+ val df1 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) FROM h2.test.employee" +
+ " where dept > 0 group by DEPT, NAME")
+ val filters1 = df1.queryExecution.optimizedPlan.collect {
case f: Filter => f
}
- assert(filters11.isEmpty)
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
- "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
- "PushedGroupby: [DEPT, NAME]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ assert(filters1.isEmpty)
+ checkAggregateRemoved(df1)
+ checkPushedInfo(df1, "PushedAggregates: [MAX(SALARY)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]")
+ checkAnswer(df1, Seq(Row("1#amy", 10000), Row("1#cathy", 9000), Row("2#alex", 12000),
+ Row("2#david", 10000), Row("6#jen", 12000)))
+
+ val df2 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" +
+ " FROM h2.test.employee where dept > 0 group by DEPT, NAME")
+ val filters2 = df2.queryExecution.optimizedPlan.collect {
+ case f: Filter => f
}
- checkAnswer(df, Seq(Row(9000, 1200), Row(12000, 1200), Row(10000, 1300),
- Row(10000, 1000), Row(12000, 1200)))
+ assert(filters2.isEmpty)
+ checkAggregateRemoved(df2)
+ checkPushedInfo(df2, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT, NAME]")
+ checkAnswer(df2, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200),
+ Row("2#david", 11300), Row("6#jen", 13200)))
+
+ val df3 = sql("select concat_ws('#', DEPT, NAME), MAX(SALARY) + MIN(BONUS)" +
+ " FROM h2.test.employee where dept > 0 group by concat_ws('#', DEPT, NAME)")
+ checkFiltersRemoved(df3)
+ checkAggregateRemoved(df3, false)
+ checkPushedInfo(df3, "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], ")
+ checkAnswer(df3, Seq(Row("1#amy", 11000), Row("1#cathy", 10200), Row("2#alex", 13200),
+ Row("2#david", 11300), Row("6#jen", 13200)))
}
test("scan with aggregate push-down: with having clause") {
val df = sql("select MAX(SALARY), MIN(BONUS) FROM h2.test.employee where dept > 0" +
" group by DEPT having MIN(BONUS) > 1000")
- val filters = df.queryExecution.optimizedPlan.collect {
- case f: Filter => f // filter over aggregate not push down
- }
- assert(filters.nonEmpty)
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
- "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
- "PushedGroupby: [DEPT]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ // the filter over the aggregate is not pushed down
+ checkFiltersRemoved(df, false)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [MAX(SALARY), MIN(BONUS)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]")
checkAnswer(df, Seq(Row(12000, 1200), Row(12000, 1200)))
}
@@ -406,14 +787,9 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val df = sql("select * from h2.test.employee")
.groupBy($"DEPT")
.min("SALARY").as("total")
- df.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [MIN(SALARY)], " +
- "PushedFilters: [], " +
- "PushedGroupby: [DEPT]"
- checkKeywordsExistsInExplain(df, expected_plan_fragment)
- }
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [MIN(SALARY)], " +
+ "PushedFilters: [], PushedGroupByColumns: [DEPT]")
checkAnswer(df, Seq(Row(1, 9000), Row(2, 10000), Row(6, 12000)))
}
@@ -425,18 +801,10 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
.agg(sum($"SALARY").as("total"))
.filter($"total" > 1000)
.orderBy($"total")
- val filters = query.queryExecution.optimizedPlan.collect {
- case f: Filter => f
- }
- assert(filters.nonEmpty) // filter over aggregate not pushed down
- query.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [SUM(SALARY)], " +
- "PushedFilters: [IsNotNull(DEPT), GreaterThan(DEPT,0)], " +
- "PushedGroupby: [DEPT]"
- checkKeywordsExistsInExplain(query, expected_plan_fragment)
- }
+ checkFiltersRemoved(query, false) // the filter over the aggregate is not pushed down
+ checkAggregateRemoved(query)
+ checkPushedInfo(query, "PushedAggregates: [SUM(SALARY)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]")
checkAnswer(query, Seq(Row(6, 12000), Row(1, 19000), Row(2, 22000)))
}
@@ -444,25 +812,366 @@ class JDBCV2Suite extends QueryTest with SharedSparkSession with ExplainSuiteHel
val df = spark.table("h2.test.employee")
val decrease = udf { (x: Double, y: Double) => x - y }
val query = df.select(decrease(sum($"SALARY"), sum($"BONUS")).as("value"))
- query.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: [SUM(SALARY), SUM(BONUS)"
- checkKeywordsExistsInExplain(query, expected_plan_fragment)
- }
+ checkAggregateRemoved(query)
+ checkPushedInfo(query, "PushedAggregates: [SUM(SALARY), SUM(BONUS)], ")
checkAnswer(query, Seq(Row(47100.0)))
}
- test("scan with aggregate push-down: aggregate over alias NOT push down") {
- val cols = Seq("a", "b", "c", "d")
+ test("scan with aggregate push-down: partition columns are same as group by columns") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .groupBy($"dept")
+ .count()
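+ // The group by column matches the partition column, so every group is confined to a
+ // single partition and the aggregate can still be fully pushed down.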
+ checkAggregateRemoved(df)
+ checkAnswer(df, Seq(Row(1, 2), Row(2, 2), Row(6, 1)))
+ }
+
+ test("scan with aggregate push-down: VAR_POP VAR_SAMP with filter and group by") {
+ val df = sql("select VAR_POP(bonus), VAR_SAMP(bonus) FROM h2.test.employee where dept > 0" +
+ " group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [VAR_POP(BONUS), VAR_SAMP(BONUS)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]")
+ checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+ }
+
+ test("scan with aggregate push-down: STDDEV_POP STDDEV_SAMP with filter and group by") {
+ val df = sql("select STDDEV_POP(bonus), STDDEV_SAMP(bonus) FROM h2.test.employee" +
+ " where dept > 0 group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [STDDEV_POP(BONUS), STDDEV_SAMP(BONUS)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]")
+ checkAnswer(df, Seq(Row(100d, 141.4213562373095d), Row(50d, 70.71067811865476d), Row(0d, null)))
+ }
+
+ test("scan with aggregate push-down: COVAR_POP COVAR_SAMP with filter and group by") {
+ val df = sql("select COVAR_POP(bonus, bonus), COVAR_SAMP(bonus, bonus)" +
+ " FROM h2.test.employee where dept > 0 group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [COVAR_POP(BONUS, BONUS), COVAR_SAMP(BONUS, BONUS)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]")
+ checkAnswer(df, Seq(Row(10000d, 20000d), Row(2500d, 5000d), Row(0d, null)))
+ }
+
+ test("scan with aggregate push-down: CORR with filter and group by") {
+ val df = sql("select CORR(bonus, bonus) FROM h2.test.employee where dept > 0" +
+ " group by DePt")
+ checkFiltersRemoved(df)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [CORR(BONUS, BONUS)], " +
+ "PushedFilters: [DEPT IS NOT NULL, DEPT > 0], PushedGroupByColumns: [DEPT]")
+ checkAnswer(df, Seq(Row(1d), Row(1d), Row(null)))
+ }
+
+ test("scan with aggregate push-down: aggregate over alias push down") {
+ val cols = Seq("a", "b", "c", "d", "e")
val df1 = sql("select * from h2.test.employee").toDF(cols: _*)
val df2 = df1.groupBy().sum("c")
+ checkAggregateRemoved(df2)
df2.queryExecution.optimizedPlan.collect {
- case _: DataSourceV2ScanRelation =>
- val expected_plan_fragment =
- "PushedAggregates: []" // aggregate over alias not push down
- checkKeywordsExistsInExplain(df2, expected_plan_fragment)
+ case relation: DataSourceV2ScanRelation =>
+ val expectedPlanFragment =
+ "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: []"
+ checkKeywordsExistsInExplain(df2, expectedPlanFragment)
+ relation.scan match {
+ case v1: V1ScanWrapper =>
+ assert(v1.pushedDownOperators.aggregation.nonEmpty)
+ }
}
checkAnswer(df2, Seq(Row(53000.00)))
}
+
+ test("scan with aggregate push-down: aggregate with partially pushed down filters" +
+ "will NOT push down") {
+ val df = spark.table("h2.test.employee")
+ val name = udf { (x: String) => x.matches("cat|dav|amy") }
+ val sub = udf { (x: String) => x.substring(0, 3) }
+ val query = df.select($"SALARY", $"BONUS", sub($"NAME").as("shortName"))
+ .filter("SALARY > 100")
+ .filter(name($"shortName"))
+ .agg(sum($"SALARY").as("sum_salary"))
+ checkAggregateRemoved(query, false)
+ query.queryExecution.optimizedPlan.collect {
+ case relation: DataSourceV2ScanRelation => relation.scan match {
+ case v1: V1ScanWrapper =>
+ assert(v1.pushedDownOperators.aggregation.isEmpty)
+ }
+ }
+ checkAnswer(query, Seq(Row(29000.0)))
+ }
+
+ test("scan with aggregate push-down: aggregate function with CASE WHEN") {
+ val df = sql(
+ """
+ |SELECT
+ | COUNT(CASE WHEN SALARY > 8000 AND SALARY < 10000 THEN SALARY ELSE 0 END),
+ | COUNT(CASE WHEN SALARY > 8000 AND SALARY <= 13000 THEN SALARY ELSE 0 END),
+ | COUNT(CASE WHEN SALARY > 11000 OR SALARY < 10000 THEN SALARY ELSE 0 END),
+ | COUNT(CASE WHEN SALARY >= 12000 OR SALARY < 9000 THEN SALARY ELSE 0 END),
+ | COUNT(CASE WHEN SALARY >= 12000 OR NOT(SALARY >= 9000) THEN SALARY ELSE 0 END),
+ | MAX(CASE WHEN NOT(SALARY > 10000) AND SALARY >= 8000 THEN SALARY ELSE 0 END),
+ | MAX(CASE WHEN NOT(SALARY > 10000) OR SALARY > 8000 THEN SALARY ELSE 0 END),
+ | MAX(CASE WHEN NOT(SALARY > 10000) AND NOT(SALARY < 8000) THEN SALARY ELSE 0 END),
+ | MAX(CASE WHEN NOT(SALARY != 0) OR NOT(SALARY < 8000) THEN SALARY ELSE 0 END),
+ | MAX(CASE WHEN NOT(SALARY > 8000 AND SALARY > 8000) THEN 0 ELSE SALARY END),
+ | MIN(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NULL) THEN SALARY ELSE 0 END),
+ | SUM(CASE WHEN SALARY > 10000 THEN 2 WHEN SALARY > 8000 THEN 1 END),
+ | AVG(CASE WHEN NOT(SALARY > 8000 OR SALARY IS NOT NULL) THEN SALARY ELSE 0 END)
+ |FROM h2.test.employee GROUP BY DEPT
+ """.stripMargin)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df,
+ "PushedAggregates: [COUNT(CASE WHEN (SALARY > 8000.00) AND (SALARY < 10000.00)" +
+ " THEN SALARY ELSE 0.00 END), COUNT(CAS..., " +
+ "PushedFilters: [], " +
+ "PushedGroupByColumns: [DEPT], ")
+ checkAnswer(df, Seq(Row(1, 1, 1, 1, 1, 0d, 12000d, 0d, 12000d, 12000d, 0d, 2, 0d),
+ Row(2, 2, 2, 2, 2, 10000d, 10000d, 10000d, 10000d, 10000d, 0d, 2, 0d),
+ Row(2, 2, 2, 2, 2, 10000d, 12000d, 10000d, 12000d, 12000d, 0d, 3, 0d)))
+ }
+
+ test("scan with aggregate push-down: aggregate function with binary arithmetic") {
+ Seq(false, true).foreach { ansiMode =>
+ withSQLConf(SQLConf.ANSI_ENABLED.key -> ansiMode.toString) {
+ val df = sql("SELECT SUM(2147483647 + DEPT) FROM h2.test.employee")
+ checkAggregateRemoved(df, ansiMode)
+ val expectedPlanFragment = if (ansiMode) {
+ "PushedAggregates: [SUM(2147483647 + DEPT)], " +
+ "PushedFilters: [], " +
+ "PushedGroupByColumns: []"
+ } else {
+ "PushedFilters: []"
+ }
+ checkPushedInfo(df, expectedPlanFragment)
+ if (ansiMode) {
+ val e = intercept[SparkException] {
+ checkAnswer(df, Seq(Row(-10737418233L)))
+ }
+ assert(e.getMessage.contains(
+ "org.h2.jdbc.JdbcSQLDataException: Numeric value out of range: \"2147483648\""))
+ } else {
+ checkAnswer(df, Seq(Row(-10737418233L)))
+ }
+ }
+ }
+ }
+
+ test("scan with aggregate push-down: aggregate function with UDF") {
+ val df = spark.table("h2.test.employee")
+ val decrease = udf { (x: Double, y: Double) => x - y }
+ val query = df.select(sum(decrease($"SALARY", $"BONUS")).as("value"))
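+ // The aggregate input is a UDF result, which cannot be translated for the data source,
+ // so the aggregate is not pushed down.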
+ checkAggregateRemoved(query, false)
+ checkPushedInfo(query, "PushedFilters: []")
+ checkAnswer(query, Seq(Row(47100.0)))
+ }
+
+ test("scan with aggregate push-down: partition columns with multi group by columns") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .groupBy($"dept", $"name")
+ .count()
+ checkAggregateRemoved(df, false)
+ checkAnswer(df, Seq(Row(1, "amy", 1), Row(1, "cathy", 1),
+ Row(2, "alex", 1), Row(2, "david", 1), Row(6, "jen", 1)))
+ }
+
+ test("scan with aggregate push-down: partition columns is different from group by columns") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .groupBy($"name")
+ .count()
+ checkAggregateRemoved(df, false)
+ checkAnswer(df,
+ Seq(Row("alex", 1), Row("amy", 1), Row("cathy", 1), Row("david", 1), Row("jen", 1)))
+ }
+
+ test("column name with composite field") {
+ checkAnswer(sql("SELECT `dept id` FROM h2.test.dept"), Seq(Row(1), Row(2)))
+ val df = sql("SELECT COUNT(`dept id`) FROM h2.test.dept")
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [COUNT(`dept id`)]")
+ checkAnswer(df, Seq(Row(2)))
+ }
+
+ test("column name with non-ascii") {
+ // scalastyle:off
+ checkAnswer(sql("SELECT `名` FROM h2.test.person"), Seq(Row(1), Row(2)))
+ val df = sql("SELECT COUNT(`名`) FROM h2.test.person")
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [COUNT(`名`)]")
+ checkAnswer(df, Seq(Row(2)))
+ // scalastyle:on
+ }
+
+ test("scan with aggregate push-down: complete push-down SUM, AVG, COUNT") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "1")
+ .table("h2.test.employee")
+ .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+ checkAggregateRemoved(df)
+ checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]")
+ checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))
+
+ val df2 = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "1")
+ .table("h2.test.employee")
+ .groupBy($"name")
+ .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+ checkAggregateRemoved(df2)
+ checkPushedInfo(df2, "PushedAggregates: [SUM(SALARY), AVG(SALARY), COUNT(SALARY)]")
+ checkAnswer(df2, Seq(
+ Row("alex", 12000.00, 12000.000000, 1),
+ Row("amy", 10000.00, 10000.000000, 1),
+ Row("cathy", 9000.00, 9000.000000, 1),
+ Row("david", 10000.00, 10000.000000, 1),
+ Row("jen", 12000.00, 12000.000000, 1)))
+ }
+
+ test("scan with aggregate push-down: partial push-down SUM, AVG, COUNT") {
+ val df = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
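+ // With two read partitions only partial push-down is possible: AVG is split into
+ // SUM and COUNT and the final average is computed by Spark.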
+ checkAggregateRemoved(df, false)
+ checkPushedInfo(df, "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]")
+ checkAnswer(df, Seq(Row(53000.00, 10600.000000, 5)))
+
+ val df2 = spark.read
+ .option("partitionColumn", "dept")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .groupBy($"name")
+ .agg(sum($"SALARY").as("sum"), avg($"SALARY").as("avg"), count($"SALARY").as("count"))
+ checkAggregateRemoved(df2, false)
+ checkPushedInfo(df2, "PushedAggregates: [SUM(SALARY), COUNT(SALARY)]")
+ checkAnswer(df2, Seq(
+ Row("alex", 12000.00, 12000.000000, 1),
+ Row("amy", 10000.00, 10000.000000, 1),
+ Row("cathy", 9000.00, 9000.000000, 1),
+ Row("david", 10000.00, 10000.000000, 1),
+ Row("jen", 12000.00, 12000.000000, 1)))
+ }
+
+ test("SPARK-37895: JDBC push down with delimited special identifiers") {
+ val df = sql(
+ """SELECT h2.test.view1.`|col1`, h2.test.view1.`|col2`, h2.test.view2.`|col3`
+ |FROM h2.test.view1 LEFT JOIN h2.test.view2
+ |ON h2.test.view1.`|col1` = h2.test.view2.`|col1`""".stripMargin)
+ checkAnswer(df, Seq.empty[Row])
+ }
+
+ test("scan with aggregate push-down: complete push-down aggregate with alias") {
+ val df = spark.table("h2.test.employee")
+ .select($"DEPT", $"SALARY".as("mySalary"))
+ .groupBy($"DEPT")
+ .agg(sum($"mySalary").as("total"))
+ .filter($"total" > 1000)
+ checkAggregateRemoved(df)
+ checkPushedInfo(df,
+ "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]")
+ checkAnswer(df, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
+
+ val df2 = spark.table("h2.test.employee")
+ .select($"DEPT".as("myDept"), $"SALARY".as("mySalary"))
+ .groupBy($"myDept")
+ .agg(sum($"mySalary").as("total"))
+ .filter($"total" > 1000)
+ checkAggregateRemoved(df2)
+ checkPushedInfo(df2,
+ "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [DEPT]")
+ checkAnswer(df2, Seq(Row(1, 19000.00), Row(2, 22000.00), Row(6, 12000.00)))
+ }
+
+ test("scan with aggregate push-down: partial push-down aggregate with alias") {
+ val df = spark.read
+ .option("partitionColumn", "DEPT")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .select($"NAME", $"SALARY".as("mySalary"))
+ .groupBy($"NAME")
+ .agg(sum($"mySalary").as("total"))
+ .filter($"total" > 1000)
+ checkAggregateRemoved(df, false)
+ checkPushedInfo(df,
+ "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]")
+ checkAnswer(df, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
+ Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
+
+ val df2 = spark.read
+ .option("partitionColumn", "DEPT")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.employee")
+ .select($"NAME".as("myName"), $"SALARY".as("mySalary"))
+ .groupBy($"myName")
+ .agg(sum($"mySalary").as("total"))
+ .filter($"total" > 1000)
+ checkAggregateRemoved(df2, false)
+ checkPushedInfo(df2,
+ "PushedAggregates: [SUM(SALARY)], PushedFilters: [], PushedGroupByColumns: [NAME]")
+ checkAnswer(df2, Seq(Row("alex", 12000.00), Row("amy", 10000.00),
+ Row("cathy", 9000.00), Row("david", 10000.00), Row("jen", 12000.00)))
+ }
+
+ test("scan with aggregate push-down: partial push-down AVG with overflow") {
+ def createDataFrame: DataFrame = spark.read
+ .option("partitionColumn", "id")
+ .option("lowerBound", "0")
+ .option("upperBound", "2")
+ .option("numPartitions", "2")
+ .table("h2.test.item")
+ .agg(avg($"PRICE").as("avg"))
+
+ Seq(true, false).foreach { ansiEnabled =>
+ withSQLConf((SQLConf.ANSI_ENABLED.key, ansiEnabled.toString)) {
+ val df = createDataFrame
+ checkAggregateRemoved(df, false)
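+ // Only SUM(PRICE) and COUNT(PRICE) are pushed down; Spark computes the final average,
+ // which can overflow the result decimal when ANSI mode is enabled.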
+ df.queryExecution.optimizedPlan.collect {
+ case _: DataSourceV2ScanRelation =>
+ val expected_plan_fragment =
+ "PushedAggregates: [SUM(PRICE), COUNT(PRICE)]"
+ checkKeywordsExistsInExplain(df, expected_plan_fragment)
+ }
+ if (ansiEnabled) {
+ val e = intercept[SparkException] {
+ df.collect()
+ }
+ assert(e.getCause.isInstanceOf[ArithmeticException])
+ assert(e.getCause.getMessage.contains("cannot be represented as Decimal") ||
+ e.getCause.getMessage.contains("Overflow in sum of decimals"))
+ } else {
+ checkAnswer(df, Seq(Row(null)))
+ }
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index efa2773bfd692..79952e5a6c288 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -227,7 +227,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
JdbcDialects.registerDialect(testH2Dialect)
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
- val m = intercept[org.h2.jdbc.JdbcSQLException] {
+ val m = intercept[org.h2.jdbc.JdbcSQLSyntaxErrorException] {
df.write.option("createTableOptions", "ENGINE tableEngineName")
.jdbc(url1, "TEST.CREATETBLOPTS", properties)
}.getMessage
@@ -326,7 +326,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
test("save errors if wrong user/password combination") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
- val e = intercept[org.h2.jdbc.JdbcSQLException] {
+ val e = intercept[org.h2.jdbc.JdbcSQLInvalidAuthorizationSpecException] {
df.write.format("jdbc")
.option("dbtable", "TEST.SAVETEST")
.option("url", url1)
@@ -427,7 +427,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
// verify the data types of the created table by reading the database catalog of H2
val query =
"""
- |(SELECT column_name, type_name, character_maximum_length
+ |(SELECT column_name, data_type, character_maximum_length
| FROM information_schema.columns WHERE table_name = 'DBCOLTYPETEST')
""".stripMargin
val rows = spark.read.jdbc(url1, query, properties).collect()
@@ -436,7 +436,7 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
val typeName = row.getString(1)
// For CHAR and VARCHAR, we also compare the max length
if (typeName.contains("CHAR")) {
- val charMaxLength = row.getInt(2)
+ val charMaxLength = row.getLong(2)
assert(expectedTypes(row.getString(0)) == s"$typeName($charMaxLength)")
} else {
assert(expectedTypes(row.getString(0)) == typeName)
@@ -452,15 +452,18 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
// out-of-order
- val expected1 = Map("id" -> "BIGINT", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)")
+ val expected1 =
+ Map("id" -> "BIGINT", "first#name" -> "CHARACTER VARYING(123)", "city" -> "CHARACTER(20)")
testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), id BIGINT, city CHAR(20)", expected1)
// partial schema
- val expected2 = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CHAR(20)")
+ val expected2 =
+ Map("id" -> "INTEGER", "first#name" -> "CHARACTER VARYING(123)", "city" -> "CHARACTER(20)")
testUserSpecifiedColTypes(df, "`first#name` VARCHAR(123), city CHAR(20)", expected2)
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") {
// should still respect the original column names
- val expected = Map("id" -> "INTEGER", "first#name" -> "VARCHAR(123)", "city" -> "CLOB")
+ val expected = Map("id" -> "INTEGER", "first#name" -> "CHARACTER VARYING(123)",
+ "city" -> "CHARACTER LARGE OBJECT(9223372036854775807)")
testUserSpecifiedColTypes(df, "`FiRsT#NaMe` VARCHAR(123)", expected)
}
@@ -470,7 +473,9 @@ class JDBCWriteSuite extends SharedSparkSession with BeforeAndAfter {
StructField("First#Name", StringType) ::
StructField("city", StringType) :: Nil)
val df = spark.createDataFrame(sparkContext.parallelize(data), schema)
- val expected = Map("id" -> "INTEGER", "First#Name" -> "VARCHAR(123)", "city" -> "CLOB")
+ val expected =
+ Map("id" -> "INTEGER", "First#Name" -> "CHARACTER VARYING(123)",
+ "city" -> "CHARACTER LARGE OBJECT(9223372036854775807)")
testUserSpecifiedColTypes(df, "`First#Name` VARCHAR(123)", expected)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 0e62be40607a1..ba0b599f2245d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -22,6 +22,7 @@ import java.net.URI
import java.nio.file.Files
import java.util.{Locale, UUID}
+import scala.collection.JavaConverters._
import scala.concurrent.duration._
import scala.language.implicitConversions
import scala.util.control.NonFatal
@@ -459,7 +460,9 @@ private[sql] trait SQLTestUtilsBase
*/
def getLocalDirSize(file: File): Long = {
assert(file.isDirectory)
- file.listFiles.filter(f => DataSourceUtils.isDataFile(f.getName)).map(_.length).sum
+ Files.walk(file.toPath).iterator().asScala
+ .filter(p => Files.isRegularFile(p) && DataSourceUtils.isDataFile(p.getFileName.toString))
+ .map(_.toFile.length).sum
}
}
diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml
index f1dcddd806525..dd3dabb82cc67 100644
--- a/sql/hive-thriftserver/pom.xml
+++ b/sql/hive-thriftserver/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.12
- 3.2.0-kylin-4.x-r60
+ 3.2.0-kylin-4.x-r61
../../pom.xml
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 82bdeaf4e6608..e6bb5d5f49dd2 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -22,7 +22,7 @@
org.apache.spark
spark-parent_2.12
- 3.2.0-kylin-4.x-r60
+ 3.2.0-kylin-4.x-r61
../../pom.xml
diff --git a/streaming/pom.xml b/streaming/pom.xml
index 3a0f9a2f00c71..91db9435a87d4 100644
--- a/streaming/pom.xml
+++ b/streaming/pom.xml
@@ -21,7 +21,7 @@
org.apache.spark
spark-parent_2.12
- 3.2.0-kylin-4.x-r60
+ 3.2.0-kylin-4.x-r61
../pom.xml
diff --git a/tools/pom.xml b/tools/pom.xml
index c2b09a8508e2a..2d5830ad83d1c 100644
--- a/tools/pom.xml
+++ b/tools/pom.xml
@@ -20,7 +20,7 @@
org.apache.spark
spark-parent_2.12
- 3.2.0-kylin-4.x-r60
+ 3.2.0-kylin-4.x-r61
../pom.xml