diff --git a/build/mvn b/build/mvn index 58058c04b8911..41c0850ccb8fb 100755 --- a/build/mvn +++ b/build/mvn @@ -72,7 +72,7 @@ install_mvn() { local MVN_VERSION="3.3.9" install_app \ - "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ + "https://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \ "apache-maven-${MVN_VERSION}-bin.tar.gz" \ "apache-maven-${MVN_VERSION}/bin/mvn" @@ -84,7 +84,7 @@ install_zinc() { local zinc_path="zinc-0.3.9/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 install_app \ - "http://downloads.typesafe.com/zinc/0.3.9" \ + "https://downloads.typesafe.com/zinc/0.3.9" \ "zinc-0.3.9.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" @@ -100,7 +100,7 @@ install_scala() { local scala_bin="${_DIR}/scala-${scala_version}/bin/scala" install_app \ - "http://downloads.typesafe.com/scala/${scala_version}" \ + "https://downloads.typesafe.com/scala/${scala_version}" \ "scala-${scala_version}.tgz" \ "scala-${scala_version}/bin/scala" diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index d2e51d2ec4438..646462b4a8350 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -119,7 +119,7 @@ object Pregel extends Logging { mergeMsg: (A, A) => A) : Graph[VD, ED] = { - require(maxIterations > 0, s"Maximum of iterations must be greater than 0," + + require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 663c9abe0881e..a0b819220e6d3 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -99,11 +99,26 @@ def value(self): def unpersist(self, blocking=False): """ - Delete cached copies of this broadcast on the executors. + Delete cached copies of this broadcast on the executors. If the + broadcast is used after this is called, it will need to be + re-sent to each executor. + + :param blocking: Whether to block until unpersisting has completed """ if self._jbroadcast is None: raise Exception("Broadcast can only be unpersisted in driver") self._jbroadcast.unpersist(blocking) + + def destroy(self): + """ + Destroy all data and metadata related to this broadcast variable. + Use this with caution; once a broadcast variable has been destroyed, + it cannot be used again. This method blocks until destroy has + completed. 
+ """ + if self._jbroadcast is None: + raise Exception("Broadcast can only be destroyed in driver") + self._jbroadcast.destroy() os.unlink(self._path) def __reduce__(self): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 40fccb8c0090c..15c87e22f98b0 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -694,6 +694,21 @@ def test_large_broadcast(self): m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() self.assertEqual(N, m) + def test_unpersist(self): + N = 1000 + data = [[float(i) for i in range(300)] for i in range(N)] + bdata = self.sc.broadcast(data) # 3MB + bdata.unpersist() + m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + self.assertEqual(N, m) + bdata.destroy() + try: + self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum() + except Exception as e: + pass + else: + raise Exception("job should fail after destroy the broadcast") + def test_multiple_broadcasts(self): N = 1 << 21 b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 37d2ecf48ec02..33c2cbd293533 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -116,7 +116,7 @@ This file is divided into 3 sections: - + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java deleted file mode 100644 index 01f89112a759b..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser; - -import java.nio.charset.StandardCharsets; - -/** - * A couple of utility methods that help with parsing ASTs. - * - * The 'unescapeSQLString' method in this class was take from the SemanticAnalyzer in Hive: - * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java - */ -public final class ParseUtils { - private ParseUtils() { - super(); - } - - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; - - @SuppressWarnings("nls") - public static String unescapeSQLString(String b) { - Character enclosure = null; - - // Some of the strings can be passed in as unicode. 
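The new PySpark destroy() above forwards to the JVM broadcast object via _jbroadcast, so its behaviour is that of the existing Scala Broadcast API. A minimal Scala sketch of how unpersist and destroy differ (an existing SparkContext sc is assumed; the data and variable names are illustrative, not part of this patch):

  val data = (0 until 1000).map(_.toDouble)
  val bc = sc.broadcast(data)

  sc.parallelize(1 to 4).map(_ => bc.value.size).collect()   // fine: value is shipped to executors

  bc.unpersist(blocking = false)                             // drops cached executor copies only
  sc.parallelize(1 to 4).map(_ => bc.value.size).collect()   // still fine: value is re-sent on demand

  bc.destroy()                                               // removes all data and metadata; blocks until done
  // Any job that touches bc.value after this point fails, which is what
  // test_unpersist above asserts on the Python side.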
For example, the - // delimiter can be passed in as \002 - So, we first check if the - // string is a unicode number, else go back to the old behavior - StringBuilder sb = new StringBuilder(b.length()); - for (int i = 0; i < b.length(); i++) { - - char currentChar = b.charAt(i); - if (enclosure == null) { - if (currentChar == '\'' || b.charAt(i) == '\"') { - enclosure = currentChar; - } - // ignore all other chars outside the enclosure - continue; - } - - if (enclosure.equals(currentChar)) { - enclosure = null; - continue; - } - - if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { - int code = 0; - int base = i + 2; - for (int j = 0; j < 4; j++) { - int digit = Character.digit(b.charAt(j + base), 16); - code += digit * multiplier[j]; - } - sb.append((char)code); - i += 5; - continue; - } - - if (currentChar == '\\' && (i + 4 < b.length())) { - char i1 = b.charAt(i + 1); - char i2 = b.charAt(i + 2); - char i3 = b.charAt(i + 3); - if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') - && (i3 >= '0' && i3 <= '7')) { - byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); - byte[] bValArr = new byte[1]; - bValArr[0] = bVal; - String tmp = new String(bValArr, StandardCharsets.UTF_8); - sb.append(tmp); - i += 3; - continue; - } - } - - if (currentChar == '\\' && (i + 2 < b.length())) { - char n = b.charAt(i + 1); - switch (n) { - case '0': - sb.append("\0"); - break; - case '\'': - sb.append("'"); - break; - case '"': - sb.append("\""); - break; - case 'b': - sb.append("\b"); - break; - case 'n': - sb.append("\n"); - break; - case 'r': - sb.append("\r"); - break; - case 't': - sb.append("\t"); - break; - case 'Z': - sb.append("\u001A"); - break; - case '\\': - sb.append("\\"); - break; - // The following 2 lines are exactly what MySQL does TODO: why do we do this? - case '%': - sb.append("\\%"); - break; - case '_': - sb.append("\\_"); - break; - default: - sb.append(n); - } - i++; - } else { - sb.append(currentChar); - } - } - return sb.toString(); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 0c7cd408dfc40..186bbccef1204 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} - +import org.apache.spark.sql.catalyst.util.StringUtils /** * An in-memory (ephemeral) implementation of the system catalog. 
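The Java unescapeSQLString being deleted here is reimplemented in Scala inside ParserUtils.scala later in this patch. Two worked inputs as a quick reference (the string literals are made up; the expected values follow the new implementation and the ParserUtilsSuite added below):

  import org.apache.spark.sql.catalyst.parser.ParserUtils.unescapeSQLString

  // \uXXXX escapes: four hex digits, folded in as (code << 4) + digit
  unescapeSQLString("'\\u0057orld'")   // "World"   (0x0057 == 'W')

  // \OOO octal escapes: up to three digits, the first limited to 0 or 1
  unescapeSQLString("'\\110i'")        // "Hi"      (octal 110 == 72 == 'H')

  // other escaped characters go through appendEscapedChar
  unescapeSQLString("'\\n\\t'")        // a real newline followed by a tab

This also explains the one-line ExpressionParserSuite change further down: the deleted Java code combined the digits of a \uXXXX escape using decimal place values (the multiplier array 1000/100/10/1), so the old test literal starting with '\u0087\u0111' only decoded to "World :)" under that quirk; with the proper hex fold the literal has to be spelled '\u0057\u006F...'.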
@@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog { // Database name -> description private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] - private def filterPattern(names: Seq[String], pattern: String): Seq[String] = { - val regex = pattern.replaceAll("\\*", ".*").r - names.filter { funcName => regex.pattern.matcher(funcName).matches() } - } - private def partitionExists(db: String, table: String, spec: TablePartitionSpec): Boolean = { requireTableExists(db, table) catalog(db).tables(table).partitions.contains(spec) @@ -136,7 +131,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listDatabases(pattern: String): Seq[String] = synchronized { - filterPattern(listDatabases(), pattern) + StringUtils.filterPattern(listDatabases(), pattern) } override def setCurrentDatabase(db: String): Unit = { /* no-op */ } @@ -203,7 +198,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listTables(db: String, pattern: String): Seq[String] = synchronized { - filterPattern(listTables(db), pattern) + StringUtils.filterPattern(listTables(db), pattern) } // -------------------------------------------------------------------------- @@ -322,7 +317,7 @@ class InMemoryCatalog extends ExternalCatalog { override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { requireDbExists(db) - filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) + StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 4825863ea92ec..7db9fd0527ec4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionE import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} - +import org.apache.spark.sql.catalyst.util.StringUtils /** * An internal catalog that is used by a Spark Session. This internal catalog serves as a @@ -297,9 +297,7 @@ class SessionCatalog( def listTables(db: String, pattern: String): Seq[TableIdentifier] = { val dbTables = externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) } - val regex = pattern.replaceAll("\\*", ".*").r - val _tempTables = tempTables.keys.toSeq - .filter { t => regex.pattern.matcher(t).matches() } + val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) .map { t => TableIdentifier(t) } dbTables ++ _tempTables } @@ -610,9 +608,7 @@ class SessionCatalog( def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = { val dbFunctions = externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } - val regex = pattern.replaceAll("\\*", ".*").r - val loadedFunctions = functionRegistry.listFunction() - .filter { f => regex.pattern.matcher(f).matches() } + val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) .map { f => FunctionIdentifier(f) } // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry. 
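These catalog call sites now all delegate to StringUtils.filterPattern, which this patch adds to catalyst's StringUtils further down: the pattern is trimmed, split on '|', '*' is expanded to '.*', matching is case-insensitive, fragments that fail to compile are silently dropped, and the result comes back de-duplicated and sorted. A small usage sketch against that definition (the names are made up):

  import org.apache.spark.sql.catalyst.util.StringUtils.filterPattern

  val names = Seq("sales_2015", "sales_2016", "Users", "tmp_view")

  filterPattern(names, "sales*")        // Seq("sales_2015", "sales_2016")
  filterPattern(names, " users|tmp* ")  // Seq("Users", "tmp_view"): trimmed, '|' alternation, case-insensitive
  filterPattern(names, "[")             // Seq(): an unparsable fragment is simply ignored

The same helper now backs SHOW TABLES and SHOW FUNCTIONS '<pattern>', which is why the SQLQuerySuite test below stops compiling the pattern with java.util.regex and passes plain '*' globs through instead.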
// So, the returned list may have two entries for the same function. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 90b76dc314a54..cb9fefec8f482 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,11 +16,12 @@ */ package org.apache.spark.sql.catalyst.parser +import scala.collection.mutable.StringBuilder + import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} @@ -87,6 +88,81 @@ object ParserUtils { } } + /** Unescape baskslash-escaped string enclosed by quotes. */ + def unescapeSQLString(b: String): String = { + var enclosure: Character = null + val sb = new StringBuilder(b.length()) + + def appendEscapedChar(n: Char) { + n match { + case '0' => sb.append('\u0000') + case '\'' => sb.append('\'') + case '"' => sb.append('\"') + case 'b' => sb.append('\b') + case 'n' => sb.append('\n') + case 'r' => sb.append('\r') + case 't' => sb.append('\t') + case 'Z' => sb.append('\u001A') + case '\\' => sb.append('\\') + // The following 2 lines are exactly what MySQL does TODO: why do we do this? + case '%' => sb.append("\\%") + case '_' => sb.append("\\_") + case _ => sb.append(n) + } + } + + var i = 0 + val strLength = b.length + while (i < strLength) { + val currentChar = b.charAt(i) + if (enclosure == null) { + if (currentChar == '\'' || currentChar == '\"') { + enclosure = currentChar + } + } else if (enclosure == currentChar) { + enclosure = null + } else if (currentChar == '\\') { + + if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') { + // \u0000 style character literals. + + val base = i + 2 + val code = (0 until 4).foldLeft(0) { (mid, j) => + val digit = Character.digit(b.charAt(j + base), 16) + (mid << 4) + digit + } + sb.append(code.asInstanceOf[Char]) + i += 5 + } else if (i + 4 < strLength) { + // \000 style character literals. + + val i1 = b.charAt(i + 1) + val i2 = b.charAt(i + 2) + val i3 = b.charAt(i + 3) + + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) { + val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char] + sb.append(tmp) + i += 3 + } else { + appendEscapedChar(i1) + i += 1 + } + } else if (i + 2 < strLength) { + // escaped character literals. + val n = b.charAt(i + 1) + appendEscapedChar(n) + i += 1 + } + } else { + // non-escaped character literals. + sb.append(currentChar) + } + i += 1 + } + sb.toString() + } + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. 
*/ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index c2eeb3c5650ab..0f65028261b85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import java.util.regex.Pattern +import java.util.regex.{Pattern, PatternSyntaxException} import org.apache.spark.unsafe.types.UTF8String @@ -52,4 +52,25 @@ object StringUtils { def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) + + /** + * This utility can be used for filtering pattern in the "Like" of "Show Tables / Functions" DDL + * @param names the names list to be filtered + * @param pattern the filter pattern, only '*' and '|' are allowed as wildcards, others will + * follows regular expression convention, case insensitive match and white spaces + * on both ends will be ignored + * @return the filtered names list in order + */ + def filterPattern(names: Seq[String], pattern: String): Seq[String] = { + val funcNames = scala.collection.mutable.SortedSet.empty[String] + pattern.trim().split("\\|").foreach { subPattern => + try { + val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r + funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() } + } catch { + case _: PatternSyntaxException => + } + } + funcNames.toSeq + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index a80d29ce5dcb3..6f40ec67ec6e0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -415,7 +415,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") // Unicode - assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)") } test("intervals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala new file mode 100644 index 0000000000000..d090daf7b41eb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +class ParserUtilsSuite extends SparkFunSuite { + + import ParserUtils._ + + test("unescapeSQLString") { + // scalastyle:off nonascii + + // String not including escaped characters and enclosed by double quotes. + assert(unescapeSQLString(""""abcdefg"""") == "abcdefg") + + // String enclosed by single quotes. + assert(unescapeSQLString("""'C0FFEE'""") == "C0FFEE") + + // Strings including single escaped characters. + assert(unescapeSQLString("""'\0'""") == "\u0000") + assert(unescapeSQLString(""""\'"""") == "\'") + assert(unescapeSQLString("""'\"'""") == "\"") + assert(unescapeSQLString(""""\b"""") == "\b") + assert(unescapeSQLString("""'\n'""") == "\n") + assert(unescapeSQLString(""""\r"""") == "\r") + assert(unescapeSQLString("""'\t'""") == "\t") + assert(unescapeSQLString(""""\Z"""") == "\u001A") + assert(unescapeSQLString("""'\\'""") == "\\") + assert(unescapeSQLString(""""\%"""") == "\\%") + assert(unescapeSQLString("""'\_'""") == "\\_") + + // String including '\000' style literal characters. + assert(unescapeSQLString("""'3 + 5 = \070'""") == "3 + 5 = \u0038") + assert(unescapeSQLString(""""\000"""") == "\u0000") + + // String including invalid '\000' style literal characters. + assert(unescapeSQLString(""""\256"""") == "256") + + // String including a '\u0000' style literal characters (\u732B is a cat in Kanji). + assert(unescapeSQLString(""""How cute \u732B are"""") == "How cute \u732B are") + + // String including a surrogate pair character + // (\uD867\uDE3D is Okhotsk atka mackerel in Kanji). + assert(unescapeSQLString(""""\uD867\uDE3D is a fish"""") == "\uD867\uDE3D is a fish") + + // scalastyle:on nonascii + } + + // TODO: Add test cases for other methods in ParserUtils +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index d6f273f9e568a..2ffc18a8d14fb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite { assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") } + + test("filter pattern") { + val names = Seq("a1", "a2", "b2", "c3") + assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3")) + assert(filterPattern(names, "*a*") === Seq("a1", "a2")) + assert(filterPattern(names, " *a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a.* ") === Seq("a1", "a2")) + assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2")) + assert(filterPattern(names, " a. 
") === Seq("a1", "a2")) + assert(filterPattern(names, " d* ") === Nil) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 3332a997cda90..54d250867fbb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -29,8 +29,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.streaming.StreamExecution +import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.HadoopFsRelation +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -275,23 +277,64 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 2.0.0 */ def startStream(): ContinuousQuery = { - val dataSource = - DataSource( - df.sqlContext, - className = source, - options = extraOptions.toMap, - partitionColumns = normalizedParCols.getOrElse(Nil)) - - val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) - val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { - new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString - }) - df.sqlContext.sessionState.continuousQueryManager.startQuery( - queryName, - checkpointLocation, - df, - dataSource.createSink(), - trigger) + if (source == "memory") { + val queryName = + extraOptions.getOrElse( + "queryName", throw new AnalysisException("queryName must be specified for memory sink")) + val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified => + new Path(userSpecified).toUri.toString + }.orElse { + val checkpointConfig: Option[String] = + df.sqlContext.conf.getConf( + SQLConf.CHECKPOINT_LOCATION, + None) + + checkpointConfig.map { location => + new Path(location, queryName).toUri.toString + } + }.getOrElse { + Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath + } + + // If offsets have already been created, we trying to resume a query. + val checkpointPath = new Path(checkpointLocation, "offsets") + val fs = checkpointPath.getFileSystem(df.sqlContext.sparkContext.hadoopConfiguration) + if (fs.exists(checkpointPath)) { + throw new AnalysisException( + s"Unable to resume query written to memory sink. 
Delete $checkpointPath to start over.") + } else { + checkpointPath.toUri.toString + } + + val sink = new MemorySink(df.schema) + val resultDf = Dataset.ofRows(df.sqlContext, new MemoryPlan(sink)) + resultDf.registerTempTable(queryName) + val continuousQuery = df.sqlContext.sessionState.continuousQueryManager.startQuery( + queryName, + checkpointLocation, + df, + sink, + trigger) + continuousQuery + } else { + val dataSource = + DataSource( + df.sqlContext, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + + val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName) + val checkpointLocation = extraOptions.getOrElse("checkpointLocation", { + new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString + }) + df.sqlContext.sessionState.continuousQueryManager.startQuery( + queryName, + checkpointLocation, + df, + dataSource.createSink(), + trigger) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5f3128d8e42d3..d77aba726098e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -30,6 +31,7 @@ import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescri import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _} import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} +import org.apache.spark.sql.execution.streaming.MemoryPlan import org.apache.spark.sql.internal.SQLConf private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { @@ -332,6 +334,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case r: RunnableCommand => ExecutedCommand(r) :: Nil + case MemoryPlan(sink, output) => + val encoder = RowEncoder(sink.schema) + LocalTableScan(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil + case logical.Distinct(child) => throw new IllegalStateException( "logical distinct operator should have been replaced by aggregate in the optimizer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index b652530d7c78c..351ef404a8e39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -25,6 +25,8 @@ import scala.util.control.NonFatal import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext} import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LeafNode import org.apache.spark.sql.types.StructType object MemoryStream { @@ -136,3 +138,9 @@ class MemorySink(val 
schema: StructType) extends Sink with Logging { } } +/** + * Used to query the data that has been written into a [[MemorySink]]. + */ +case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { + def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 4e62fac919f5e..48a077d0e551a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.streaming.MemoryPlan abstract class QueryTest extends PlanTest { @@ -200,6 +201,7 @@ abstract class QueryTest extends PlanTest { logicalPlan.transform { case _: ObjectOperator => return case _: LogicalRelation => return + case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5a851b47caf87..2ab7c1581cfa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -56,17 +57,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { - val regex = java.util.regex.Pattern.compile(pattern) - sqlContext.sessionState.functionRegistry.listFunction() - .filter(regex.matcher(_).matches()).map(Row(_)) + StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern) + .map(Row(_)) } - checkAnswer(sql("SHOW functions"), getFunctions(".*")) + checkAnswer(sql("SHOW functions"), getFunctions("*")) Seq("^c*", "*e$", "log*", "*date*").foreach { pattern => // For the pattern part, only '*' and '|' are allowed as wildcards. // For '*', we need to replace it to '.*'. - checkAnswer( - sql(s"SHOW FUNCTIONS '$pattern'"), - getFunctions(pattern.replaceAll("\\*", ".*"))) + checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala new file mode 100644 index 0000000000000..5249aa28dd6ca --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.{AnalysisException, Row, StreamTest} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +class MemorySinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + test("registering as a table") { + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + + checkDataset( + sqlContext.table("memStream").as[Int], + 1, 2, 3) + + input.addData(4, 5, 6) + query.processAllAvailable() + checkDataset( + sqlContext.table("memStream").as[Int], + 1, 2, 3, 4, 5, 6) + + query.stop() + } + + test("error when no name is specified") { + val error = intercept[AnalysisException] { + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .startStream() + } + + assert(error.message contains "queryName must be specified") + } + + test("error if attempting to resume specific checkpoint") { + val location = Utils.createTempDir("steaming.checkpoint").getCanonicalPath + + val input = MemoryStream[Int] + val query = input.toDF().write + .format("memory") + .queryName("memStream") + .option("checkpointLocation", location) + .startStream() + input.addData(1, 2, 3) + query.processAllAvailable() + query.stop() + + intercept[AnalysisException] { + input.toDF().write + .format("memory") + .queryName("memStream") + .option("checkpointLocation", location) + .startStream() + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index b4e5d4adf1728..c5f01da4fabbb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -263,8 +263,8 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW) assert(desc.storage.locationUri.isEmpty) assert(desc.schema == Seq.empty[CatalogColumn]) - assert(desc.viewText.contains("SELECT * FROM tab1")) - assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) + assert(desc.viewText == Option("SELECT * FROM tab1")) + assert(desc.viewOriginalText == Option("SELECT * FROM tab1")) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat.isEmpty) assert(desc.storage.outputFormat.isEmpty) @@ -290,8 +290,8 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.schema == CatalogColumn("col1", null, nullable = true, None) :: CatalogColumn("col3", null, nullable = true, None) :: Nil) - assert(desc.viewText.contains("SELECT * FROM tab1")) - assert(desc.viewOriginalText.contains("SELECT * FROM tab1")) + assert(desc.viewText == Option("SELECT * FROM tab1")) + assert(desc.viewOriginalText 
== Option("SELECT * FROM tab1")) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat.isEmpty) assert(desc.storage.outputFormat.isEmpty)
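Taken together, the DataFrameWriter, MemoryPlan, and SparkStrategies hunks make format("memory") a queryable sink: startStream() registers a temp table backed by a MemoryPlan, which the planner lowers to a LocalTableScan over sink.allData. A condensed Scala sketch of the flow, following MemorySinkSuite above (it assumes a test context like that suite's, with its implicit encoders and SQLContext in scope):

  import org.apache.spark.sql.execution.streaming.MemoryStream

  val input = MemoryStream[Int]              // test source, as used in MemorySinkSuite
  val query = input.toDF().write
    .format("memory")                        // MemorySink keeps every processed batch in the driver
    .queryName("events")                     // required: becomes the temp table name
    .startStream()

  input.addData(1, 2, 3)
  query.processAllAvailable()
  sqlContext.table("events").show()          // scans the MemoryPlan over sink.allData

  query.stop()

Because the sink's contents live only in the driver and cannot be recovered, attempting to resume from an existing checkpointLocation is rejected with an AnalysisException, as the last test in MemorySinkSuite exercises.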