From 25a4c8e0c5c63ca4722b1da6182e0e0f0f48b73a Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 6 Apr 2016 15:48:28 +0200 Subject: [PATCH 1/7] [SPARK-14396][BUILD][HOT] Fix compilation against Scala 2.10 #### What changes were proposed in this pull request? This PR is to fix the compilation errors in Scala 2.10 build, as shown in the link: https://amplab.cs.berkeley.edu/jenkins/job/spark-master-compile-maven-scala-2.10/735/console ``` [error] /home/jenkins/workspace/spark-master-compile-maven-scala-2.10/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala:266: value contains is not a member of Option[String] [error] assert(desc.viewText.contains("SELECT * FROM tab1")) [error] ^ [error] /home/jenkins/workspace/spark-master-compile-maven-scala-2.10/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala:267: value contains is not a member of Option[String] [error] assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) [error] ^ [error] /home/jenkins/workspace/spark-master-compile-maven-scala-2.10/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala:293: value contains is not a member of Option[String] [error] assert(desc.viewText.contains("SELECT * FROM tab1")) [error] ^ [error] /home/jenkins/workspace/spark-master-compile-maven-scala-2.10/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala:294: value contains is not a member of Option[String] [error] assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) [error] ^ [error] four errors found [error] Compile failed at Apr 5, 2016 10:59:09 PM [10.502s] ``` #### How was this patch tested? Not sure how to trigger Scala 2.10 compilation in the test environment. Author: gatorsmile Closes #12201 from gatorsmile/buildBreak2.10. 
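For context on the fix itself: `Option.contains` was only added in Scala 2.11, which is why the assertions shown in the compile errors above fail on 2.10 and why the patch switches to comparing against `Option(...)`. A minimal, self-contained sketch of the difference (the object name below is illustrative and not part of the patch):

```scala
// Illustrates the Scala 2.10/2.11 difference behind this build break.
object OptionContainsSketch {
  def main(args: Array[String]): Unit = {
    val viewText: Option[String] = Some("SELECT * FROM tab1")

    // Compiles only on Scala 2.11+, where Option.contains was introduced:
    // assert(viewText.contains("SELECT * FROM tab1"))

    // Source-compatible with both 2.10 and 2.11, which is what the patch uses:
    assert(viewText == Option("SELECT * FROM tab1"))

    // An equivalent 2.10-safe spelling:
    assert(viewText.exists(_ == "SELECT * FROM tab1"))
  }
}
```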
--- .../org/apache/spark/sql/hive/HiveDDLCommandSuite.scala | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index b4e5d4adf1728..c5f01da4fabbb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -263,8 +263,8 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW) assert(desc.storage.locationUri.isEmpty) assert(desc.schema == Seq.empty[CatalogColumn]) - assert(desc.viewText.contains("SELECT * FROM tab1")) - assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) + assert(desc.viewText == Option("SELECT * FROM tab1")) + assert(desc.viewOriginalText == Option("SELECT * FROM tab1")) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat.isEmpty) assert(desc.storage.outputFormat.isEmpty) @@ -290,8 +290,8 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.schema == CatalogColumn("col1", null, nullable = true, None) :: CatalogColumn("col3", null, nullable = true, None) :: Nil) - assert(desc.viewText.contains("SELECT * FROM tab1")) - assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) + assert(desc.viewText == Option("SELECT * FROM tab1")) + assert(desc.viewOriginalText == Option("SELECT * FROM tab1")) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat.isEmpty) assert(desc.storage.outputFormat.isEmpty) From 24015199f46b5934d3000960538539495e025acf Mon Sep 17 00:00:00 2001 From: Victor Chima Date: Wed, 6 Apr 2016 15:27:46 +0100 Subject: [PATCH 2/7] Added omitted word in error message ## What changes were proposed in this pull request? Added an omitted word in the error message displayed by the GraphX Pregel API when `maxIterations <= 0` ## How was this patch tested? Manual test Author: Victor Chima Closes #12205 from blazy2k9/hotfix/pregel-error-message. --- graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index d2e51d2ec4438..646462b4a8350 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -119,7 +119,7 @@ object Pregel extends Logging { mergeMsg: (A, A) => A) : Graph[VD, ED] = { - require(maxIterations > 0, s"Maximum of iterations must be greater than 0," + + require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() From 5e64dab868be1a0d415fb6d6dd3463e7171fdd1a Mon Sep 17 00:00:00 2001 From: Prajwal Tuladhar Date: Wed, 6 Apr 2016 15:28:52 +0100 Subject: [PATCH 3/7] [SPARK-14430][BUILD] use https while downloading binaries from build/mvn ## What changes were proposed in this pull request? The `./build/mvn` script was downloading binaries over HTTP instead of HTTPS. This PR fixes that. ## How was this patch tested? By running `./build/mvn clean package` locally Author: Prajwal Tuladhar Closes #12182 from infynyxx/mvn_use_https.
--- build/mvn | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/build/mvn b/build/mvn index 58058c04b8911..41c0850ccb8fb 100755 --- a/build/mvn +++ b/build/mvn @@ -72,7 +72,7 @@ install_mvn() { local MVN_VERSION="3.3.9" install_app \ - "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "https://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ "apache-maven-${MVN_VERSION}-bin.tar.gz" \ "apache-maven-${MVN_VERSION}/bin/mvn" @@ -84,7 +84,7 @@ install_zinc() { local zinc_path="zinc-0.3.9/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 install_app \ - "http://downloads.typesafe.com/zinc/0.3.9" \ + "https://downloads.typesafe.com/zinc/0.3.9" \ "zinc-0.3.9.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" @@ -100,7 +100,7 @@ install_scala() { local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" install_app \ - "http://downloads.typesafe.com/scala/${scala_version}" \ + "https://downloads.typesafe.com/scala/${scala_version}" \ "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" From 59236e5c5b9d24f90fcf8d09b23ae8b06355657e Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Wed, 6 Apr 2016 10:05:02 -0700 Subject: [PATCH 4/7] [SPARK-14288][SQL] Memory Sink for streaming This PR exposes the internal testing `MemorySink` through the data source API. This will allow users to easily test streaming applications in the Spark shell or other local tests. Usage: ```scala inputStream.write .format("memory") .queryName("memStream") .startStream() // Now you can query the result of the stream here. sqlContext.table("memStream") ``` The most complicated part of the logic is choosing the checkpoint directory. There are a few requirements we are attempting to satisfy here: - when working in the shell locally, it should just work with no extra configuration. - when working on a cluster you should be able to make it easily create the checkpoint on a distributed file system so you can test aggregation (state checkpoints are also stored in this directory and must be accessible from workers). - it should be clear that you can't resume since the data is just in memory. The chosen algorithm proceeds as follows: - if the user gives a checkpoint directory, use it - if the conf has a checkpoint location, use `$location/$queryName` - if neither, create a local directory - always check to make sure there are no offsets written to the directory Author: Michael Armbrust Closes #12119 from marmbrus/memorySink.
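The fallback described above boils down to the following standalone sketch. The helper and its parameter names are hypothetical stand-ins for the `checkpointLocation` writer option, the session's configured checkpoint location, and the query name; the real logic lives in `DataFrameWriter.startStream()` in the diff below.

```scala
import java.nio.file.Files

// Hypothetical helper mirroring the three-step fallback for the memory sink's checkpoint dir.
object MemorySinkCheckpointSketch {
  def resolveCheckpointLocation(
      userCheckpoint: Option[String],  // the "checkpointLocation" writer option, if given
      confCheckpoint: Option[String],  // the session's configured checkpoint location, if set
      queryName: String): String = {
    userCheckpoint
      .orElse(confCheckpoint.map(location => s"$location/$queryName"))
      .getOrElse(Files.createTempDirectory("memory.stream").toString)
    // The real code additionally refuses to start if an "offsets" directory already exists
    // at the resolved location, since a memory sink cannot be resumed.
  }
}
```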
--- .../apache/spark/sql/DataFrameWriter.scala | 79 ++++++++++++++---- .../spark/sql/execution/SparkStrategies.scala | 6 ++ .../sql/execution/streaming/memory.scala | 8 ++ .../org/apache/spark/sql/QueryTest.scala | 2 + .../spark/sql/streaming/MemorySinkSuite.scala | 82 +++++++++++++++++++ 5 files changed, 159 insertions(+), 18 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3332a997cda90..54d250867fbb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -29,8 +29,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelation +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -275,23 +277,64 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 2.0.0 */ def startStream(): ContinuousQuery = { - val dataSource = - DataSource( - df.sqlContext, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - - val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) - val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { - new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString - }) - df.sqlContext.sessionState.continuousQueryManager.startQuery( - queryName, - checkpointLocation, - df, - dataSource.createSink(), - trigger) + if (source == "memory") { + val queryName = + extraOptions.getOrElse( + "queryName", throw new AnalysisException("queryName must be specified for memory sink")) + val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified => + new Path(userSpecified).toUri.toString + }.orElse { + val checkpointConfig: Option[String] = + df.sqlContext.conf.getConf( + SQLConf.CHECKPOINT_LOCATION, + None) + + checkpointConfig.map { location => + new Path(location, queryName).toUri.toString + } + }.getOrElse { + Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath + } + + // If offsets have already been created, we trying to resume a query. + val checkpointPath = new Path(checkpointLocation, "offsets") + val fs = checkpointPath.getFileSystem(df.sqlContext.sparkContext.hadoopConfiguration) + if (fs.exists(checkpointPath)) { + throw new AnalysisException( + s"Unable to resume query written to memory sink. 
Delete $checkpointPath to start over.") + } else { + checkpointPath.toUri.toString + } + + val sink = new MemorySink(df.schema) + val resultDf = Dataset.ofRows(df.sqlContext, new MemoryPlan(sink)) + resultDf.registerTempTable(queryName) + val continuousQuery = df.sqlContext.sessionState.continuousQueryManager.startQuery( + queryName, + checkpointLocation, + df, + sink, + trigger) + continuousQuery + } else { + val dataSource = + DataSource( + df.sqlContext, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + + val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) + val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { + new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString + }) + df.sqlContext.sessionState.continuousQueryManager.startQuery( + queryName, + checkpointLocation, + df, + dataSource.createSink(), + trigger) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5f3128d8e42d3..d77aba726098e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -30,6 +31,7 @@ import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescri import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.execution.streaming.MemoryPlan import org.apache.spark.sql.internal.SQLConf private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { @@ -332,6 +334,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil + case MemoryPlan(sink, output) => + val encoder = RowEncoder(sink.schema) + LocalTableScan(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil + case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b652530d7c78c..351ef404a8e39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.types.StructType object MemoryStream { @@ -136,3 +138,9 @@ class MemorySink(val 
schema: StructType) extends Sink with Logging { } } +/** + * Used to query the data that has been written into a [[MemorySink]]. + */ +case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { + def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 4e62fac919f5e..48a077d0e551a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.streaming.MemoryPlan abstract class QueryTest extends PlanTest { @@ -200,6 +201,7 @@ abstract class QueryTest extends PlanTest { logicalPlan.transform { case _: ObjectOperator => return case _: LogicalRelation => return + case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala new file mode 100644 index 0000000000000..5249aa28dd6ca --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.{AnalysisException, Row, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class MemorySinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("registering as a table") { + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + sqlContext.table("memStream").as[Int], + 1, 2, 3) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + sqlContext.table("memStream").as[Int], + 1, 2, 3, 4, 5, 6) + + query.stop() + } + + test("error when no name is specified") { + val error = intercept[AnalysisException] { + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .startStream() + } + + assert(error.message contains "queryName must be specified") + } + + test("error if attempting to resume specific checkpoint") { + val location = Utils.createTempDir("steaming.checkpoint").getCanonicalPath + + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .option("checkpointLocation", location) + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + intercept[AnalysisException] { + input.toDF().write + .format("memory") + .queryName("memStream") + .option("checkpointLocation", location) + .startStream() + } + } +} From 90ca1844865baf96656a9e5efdf56f415f2646be Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Apr 2016 10:46:34 -0700 Subject: [PATCH 5/7] [SPARK-14418][PYSPARK] fix unpersist of Broadcast in Python ## What changes were proposed in this pull request? Currently, Broadcast.unpersist() will remove the file of the broadcast, which should be the behavior of destroy(). This PR added destroy() for Broadcast in Python, to match the semantics in Scala. ## How was this patch tested? Added regression tests. Author: Davies Liu Closes #12189 from davies/py_unpersist. --- python/pyspark/broadcast.py | 17 ++++++++++++++++- python/pyspark/tests.py | 15 +++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 663c9abe0881e..a0b819220e6d3 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -99,11 +99,26 @@ def value(self): def unpersist(self, blocking=False): """ - Delete cached copies of this broadcast on the executors. + Delete cached copies of this broadcast on the executors. If the + broadcast is used after this is called, it will need to be + re-sent to each executor. + + :param blocking: Whether to block until unpersisting has completed """ if self._jbroadcast is None: raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) + + def destroy(self): + """ + Destroy all data and metadata related to this broadcast variable. + Use this with caution; once a broadcast variable has been destroyed, + it cannot be used again. This method blocks until destroy has + completed.
+ """ + if self._jbroadcast is None: + raise Exception("Broadcast can only be destroyed in driver") + self._jbroadcast.destroy() os.unlink(self._path) def __reduce__(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 40fccb8c0090c..15c87e22f98b0 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -694,6 +694,21 @@ def test_large_broadcast(self): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEqual(N, m) + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM From 10494feae0c2c1aca545c73ba61af6d8f743c5bb Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 6 Apr 2016 10:57:46 -0700 Subject: [PATCH 6/7] [SPARK-14426][SQL] Merge PerserUtils and ParseUtils ## What changes were proposed in this pull request? We have ParserUtils and ParseUtils which are both utility collections for use during the parsing process. Those names and what they are used for are very similar, so I think we can merge them. Also, the original unescapeSQLString method may have a fault. When "\u0061" style character literals are passed to the method, they are not unescaped successfully. This patch fixes the bug. ## How was this patch tested? Added a new test case. Author: Kousuke Saruta Closes #12199 from sarutak/merge-ParseUtils-and-ParserUtils. --- scalastyle-config.xml | 2 +- .../spark/sql/catalyst/parser/ParseUtils.java | 135 ------------------ .../sql/catalyst/parser/ParserUtils.scala | 78 +++++++++- .../parser/ExpressionParserSuite.scala | 2 +- .../catalyst/parser/ParserUtilsSuite.scala | 65 +++++++++ 5 files changed, 144 insertions(+), 138 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 37d2ecf48ec02..33c2cbd293533 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -116,7 +116,7 @@ This file is divided into 3 sections: - + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java deleted file mode 100644 index 01f89112a759b..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser; - -import java.nio.charset.StandardCharsets; - -/** - * A couple of utility methods that help with parsing ASTs. - * - * The 'unescapeSQLString' method in this class was take from the SemanticAnalyzer in Hive: - * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java - */ -public final class ParseUtils { - private ParseUtils() { - super(); - } - - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; - - @SuppressWarnings("nls") - public static String unescapeSQLString(String b) { - Character enclosure = null; - - // Some of the strings can be passed in as unicode. For example, the - // delimiter can be passed in as \002 - So, we first check if the - // string is a unicode number, else go back to the old behavior - StringBuilder sb = new StringBuilder(b.length()); - for (int i = 0; i < b.length(); i++) { - - char currentChar = b.charAt(i); - if (enclosure == null) { - if (currentChar == '\'' || b.charAt(i) == '\"') { - enclosure = currentChar; - } - // ignore all other chars outside the enclosure - continue; - } - - if (enclosure.equals(currentChar)) { - enclosure = null; - continue; - } - - if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { - int code = 0; - int base = i + 2; - for (int j = 0; j < 4; j++) { - int digit = Character.digit(b.charAt(j + base), 16); - code += digit * multiplier[j]; - } - sb.append((char)code); - i += 5; - continue; - } - - if (currentChar == '\\' && (i + 4 < b.length())) { - char i1 = b.charAt(i + 1); - char i2 = b.charAt(i + 2); - char i3 = b.charAt(i + 3); - if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') - && (i3 >= '0' && i3 <= '7')) { - byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); - byte[] bValArr = new byte[1]; - bValArr[0] = bVal; - String tmp = new String(bValArr, StandardCharsets.UTF_8); - sb.append(tmp); - i += 3; - continue; - } - } - - if (currentChar == '\\' && (i + 2 < b.length())) { - char n = b.charAt(i + 1); - switch (n) { - case '0': - sb.append("\0"); - break; - case '\'': - sb.append("'"); - break; - case '"': - sb.append("\""); - break; - case 'b': - sb.append("\b"); - break; - case 'n': - sb.append("\n"); - break; - case 'r': - sb.append("\r"); - break; - case 't': - sb.append("\t"); - break; - case 'Z': - sb.append("\u001A"); - break; - case '\\': - sb.append("\\"); - break; - // The following 2 lines are exactly what MySQL does TODO: why do we do this? 
- case '%': - sb.append("\\%"); - break; - case '_': - sb.append("\\_"); - break; - default: - sb.append(n); - } - i++; - } else { - sb.append(currentChar); - } - } - return sb.toString(); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 90b76dc314a54..cb9fefec8f482 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,11 +16,12 @@ */ package org.apache.spark.sql.catalyst.parser +import scala.collection.mutable.StringBuilder + import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} @@ -87,6 +88,81 @@ object ParserUtils { } } + /** Unescape baskslash-escaped string enclosed by quotes. */ + def unescapeSQLString(b: String): String = { + var enclosure: Character = null + val sb = new StringBuilder(b.length()) + + def appendEscapedChar(n: Char) { + n match { + case '0' => sb.append('\u0000') + case '\'' => sb.append('\'') + case '"' => sb.append('\"') + case 'b' => sb.append('\b') + case 'n' => sb.append('\n') + case 'r' => sb.append('\r') + case 't' => sb.append('\t') + case 'Z' => sb.append('\u001A') + case '\\' => sb.append('\\') + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%' => sb.append("\\%") + case '_' => sb.append("\\_") + case _ => sb.append(n) + } + } + + var i = 0 + val strLength = b.length + while (i < strLength) { + val currentChar = b.charAt(i) + if (enclosure == null) { + if (currentChar == '\'' || currentChar == '\"') { + enclosure = currentChar + } + } else if (enclosure == currentChar) { + enclosure = null + } else if (currentChar == '\\') { + + if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') { + // \u0000 style character literals. + + val base = i + 2 + val code = (0 until 4).foldLeft(0) { (mid, j) => + val digit = Character.digit(b.charAt(j + base), 16) + (mid << 4) + digit + } + sb.append(code.asInstanceOf[Char]) + i += 5 + } else if (i + 4 < strLength) { + // \000 style character literals. + + val i1 = b.charAt(i + 1) + val i2 = b.charAt(i + 2) + val i3 = b.charAt(i + 3) + + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) { + val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char] + sb.append(tmp) + i += 3 + } else { + appendEscapedChar(i1) + i += 1 + } + } else if (i + 2 < strLength) { + // escaped character literals. + val n = b.charAt(i + 1) + appendEscapedChar(n) + i += 1 + } + } else { + // non-escaped character literals. + sb.append(currentChar) + } + i += 1 + } + sb.toString() + } + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. 
*/ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index a80d29ce5dcb3..6f40ec67ec6e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -415,7 +415,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") // Unicode - assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)") } test("intervals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala new file mode 100644 index 0000000000000..d090daf7b41eb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +class ParserUtilsSuite extends SparkFunSuite { + + import ParserUtils._ + + test("unescapeSQLString") { + // scalastyle:off nonascii + + // String not including escaped characters and enclosed by double quotes. + assert(unescapeSQLString(""""abcdefg"""") == "abcdefg") + + // String enclosed by single quotes. + assert(unescapeSQLString("""'C0FFEE'""") == "C0FFEE") + + // Strings including single escaped characters. + assert(unescapeSQLString("""'\0'""") == "\u0000") + assert(unescapeSQLString(""""\'"""") == "\'") + assert(unescapeSQLString("""'\"'""") == "\"") + assert(unescapeSQLString(""""\b"""") == "\b") + assert(unescapeSQLString("""'\n'""") == "\n") + assert(unescapeSQLString(""""\r"""") == "\r") + assert(unescapeSQLString("""'\t'""") == "\t") + assert(unescapeSQLString(""""\Z"""") == "\u001A") + assert(unescapeSQLString("""'\\'""") == "\\") + assert(unescapeSQLString(""""\%"""") == "\\%") + assert(unescapeSQLString("""'\_'""") == "\\_") + + // String including '\000' style literal characters. + assert(unescapeSQLString("""'3 + 5 = \070'""") == "3 + 5 = \u0038") + assert(unescapeSQLString(""""\000"""") == "\u0000") + + // String including invalid '\000' style literal characters. + assert(unescapeSQLString(""""\256"""") == "256") + + // String including a '\u0000' style literal characters (\u732B is a cat in Kanji). 
+ assert(unescapeSQLString(""""How cute \u732B are"""") == "How cute \u732B are") + + // String including a surrogate pair character + // (\uD867\uDE3D is Okhotsk atka mackerel in Kanji). + assert(unescapeSQLString(""""\uD867\uDE3D is a fish"""") == "\uD867\uDE3D is a fish") + + // scalastyle:on nonascii + } + + // TODO: Add test cases for other methods in ParserUtils +} From 5abd02c02b3fa3505defdc8ab0c5c5e23a16aa80 Mon Sep 17 00:00:00 2001 From: bomeng Date: Wed, 6 Apr 2016 11:05:52 -0700 Subject: [PATCH 7/7] [SPARK-14429][SQL] Improve LIKE pattern in "SHOW TABLES / FUNCTIONS LIKE " DDL LIKE is commonly used in SHOW TABLES / FUNCTIONS and similar DDL. In the pattern, users can use `|` or `*` as wildcards. 1. Currently, we used `replaceAll()` to replace `*` with `.*`, but the replacement was scattered in several places; I have created a utility method and used it in all those places; 2. Consistency with Hive: the pattern is case insensitive in Hive and white spaces will be trimmed, but the current pattern matching does not do that. For example, suppose we have tables (t1, t2, t3), `SHOW TABLES LIKE ' T* ' ` will list all the t-tables. Please use Hive to verify it. 3. Combined with `|`, the result will be sorted. For a pattern like `' B*|a* '`, it will list the results in a-b order. I've made some changes to the utility method to make sure we will get the same result as Hive does. A new method was created in StringUtils and test cases were added. andrewor14 Author: bomeng Closes #12206 from bomeng/SPARK-14429. --- .../catalyst/catalog/InMemoryCatalog.scala | 13 ++++------- .../sql/catalyst/catalog/SessionCatalog.scala | 10 +++----- .../spark/sql/catalyst/util/StringUtils.scala | 23 ++++++++++++++++++- .../sql/catalyst/util/StringUtilsSuite.scala | 12 ++++++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 12 ++++------ 5 files changed, 46 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 2af0107fa37a0..5d136b663f30c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} - +import org.apache.spark.sql.catalyst.util.StringUtils /** * An in-memory (ephemeral) implementation of the system catalog.
@@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog { // Database name -> description private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] - private def filterPattern(names: Seq[String], pattern: String): Seq[String] = { - val regex = pattern.replaceAll("\\*", ".*").r - names.filter { funcName => regex.pattern.matcher(funcName).matches() } - } - private def functionExists(db: String, funcName: String): Boolean = { requireDbExists(db) catalog(db).functions.contains(funcName) @@ -141,7 +136,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listDatabases(pattern: String): Seq[String] = synchronized { - filterPattern(listDatabases(), pattern) + StringUtils.filterPattern(listDatabases(), pattern) } override def setCurrentDatabase(db: String): Unit = { /* no-op */ } @@ -208,7 +203,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listTables(db: String, pattern: String): Seq[String] = synchronized { - filterPattern(listTables(db), pattern) + StringUtils.filterPattern(listTables(db), pattern) } // -------------------------------------------------------------------------- @@ -322,7 +317,7 @@ class InMemoryCatalog extends ExternalCatalog { override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { requireDbExists(db) - filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) + StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 62a3b1c10590f..2acf584e8ff01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionE import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} - +import org.apache.spark.sql.catalyst.util.StringUtils /** * An internal catalog that is used by a Spark Session. This internal catalog serves as a @@ -297,9 +297,7 @@ class SessionCatalog( def listTables(db: String, pattern: String): Seq[TableIdentifier] = { val dbTables = externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) } - val regex = pattern.replaceAll("\\*", ".*").r - val _tempTables = tempTables.keys.toSeq - .filter { t => regex.pattern.matcher(t).matches() } + val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) .map { t => TableIdentifier(t) } dbTables ++ _tempTables } @@ -613,9 +611,7 @@ class SessionCatalog( def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = { val dbFunctions = externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } - val regex = pattern.replaceAll("\\*", ".*").r - val loadedFunctions = functionRegistry.listFunction() - .filter { f => regex.pattern.matcher(f).matches() } + val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) .map { f => FunctionIdentifier(f) } // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry. // So, the returned list may have two entries for the same function. 
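The `StringUtils.filterPattern` helper called here is added in the StringUtils.scala hunk that follows. A short usage sketch, assuming the patched class is on the classpath; the table names are illustrative, and the behavior mirrors the StringUtilsSuite cases (trimmed input, case-insensitive matching, `|` alternation, sorted results, invalid sub-patterns silently skipped):

```scala
import org.apache.spark.sql.catalyst.util.StringUtils

// Usage sketch for the filterPattern helper added in PATCH 7/7.
object FilterPatternExample {
  def main(args: Array[String]): Unit = {
    val tables = Seq("t1", "T2", "sales", "Staging")

    // Whitespace is trimmed and matching is case-insensitive: yields T2 and t1, sorted.
    println(StringUtils.filterPattern(tables, " t* "))

    // '|' combines sub-patterns; the union comes back in sorted order.
    println(StringUtils.filterPattern(tables, " s*|T* "))

    // A sub-pattern that is not a valid regex is ignored, so this yields an empty result.
    println(StringUtils.filterPattern(tables, "[oops"))
  }
}
```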
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index c2eeb3c5650ab..0f65028261b85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import java.util.regex.Pattern +import java.util.regex.{Pattern, PatternSyntaxException} import org.apache.spark.unsafe.types.UTF8String @@ -52,4 +52,25 @@ object StringUtils { def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) + + /** + * This utility can be used for filtering pattern in the "Like" of "Show Tables / Functions" DDL + * @param names the names list to be filtered + * @param pattern the filter pattern, only '*' and '|' are allowed as wildcards, others will + * follows regular expression convention, case insensitive match and white spaces + * on both ends will be ignored + * @return the filtered names list in order + */ + def filterPattern(names: Seq[String], pattern: String): Seq[String] = { + val funcNames = scala.collection.mutable.SortedSet.empty[String] + pattern.trim().split("\\|").foreach { subPattern => + try { + val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r + funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() } + } catch { + case _: PatternSyntaxException => + } + } + funcNames.toSeq + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index d6f273f9e568a..2ffc18a8d14fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite { assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") } + + test("filter pattern") { + val names = Seq("a1", "a2", "b2", "c3") + assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3")) + assert(filterPattern(names, "*a*") === Seq("a1", "a2")) + assert(filterPattern(names, " *a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a.* ") === Seq("a1", "a2")) + assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2")) + assert(filterPattern(names, " a. 
") === Seq("a1", "a2")) + assert(filterPattern(names, " d* ") === Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5a851b47caf87..2ab7c1581cfa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -56,17 +57,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { - val regex = java.util.regex.Pattern.compile(pattern) - sqlContext.sessionState.functionRegistry.listFunction() - .filter(regex.matcher(_).matches()).map(Row(_)) + StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern) + .map(Row(_)) } - checkAnswer(sql("SHOW functions"), getFunctions(".*")) + checkAnswer(sql("SHOW functions"), getFunctions("*")) Seq("^c*", "*e$", "log*", "*date*").foreach { pattern => // For the pattern part, only '*' and '|' are allowed as wildcards. // For '*', we need to replace it to '.*'. - checkAnswer( - sql(s"SHOW FUNCTIONS '$pattern'"), - getFunctions(pattern.replaceAll("\\*", ".*"))) + checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) } }