diff --git a/build/mvn b/build/mvn
index 58058c04b8911..41c0850ccb8fb 100755
--- a/build/mvn
+++ b/build/mvn
@@ -72,7 +72,7 @@ install_mvn() {
local MVN_VERSION="3.3.9"
install_app \
- "http://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \
+ "https://archive.apache.org/dist/maven/maven-3/${MVN_VERSION}/binaries" \
"apache-maven-${MVN_VERSION}-bin.tar.gz" \
"apache-maven-${MVN_VERSION}/bin/mvn"
@@ -84,7 +84,7 @@ install_zinc() {
local zinc_path="zinc-0.3.9/bin/zinc"
[ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1
install_app \
- "http://downloads.typesafe.com/zinc/0.3.9" \
+ "https://downloads.typesafe.com/zinc/0.3.9" \
"zinc-0.3.9.tgz" \
"${zinc_path}"
ZINC_BIN="${_DIR}/${zinc_path}"
@@ -100,7 +100,7 @@ install_scala() {
local scala_bin="${_DIR}/scala-${scala_version}/bin/scala"
install_app \
- "http://downloads.typesafe.com/scala/${scala_version}" \
+ "https://downloads.typesafe.com/scala/${scala_version}" \
"scala-${scala_version}.tgz" \
"scala-${scala_version}/bin/scala"
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
index d2e51d2ec4438..646462b4a8350 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
@@ -119,7 +119,7 @@ object Pregel extends Logging {
mergeMsg: (A, A) => A)
: Graph[VD, ED] =
{
- require(maxIterations > 0, s"Maximum of iterations must be greater than 0," +
+ require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
s" but got ${maxIterations}")
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index 663c9abe0881e..a0b819220e6d3 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -99,11 +99,26 @@ def value(self):
def unpersist(self, blocking=False):
"""
- Delete cached copies of this broadcast on the executors.
+ Delete cached copies of this broadcast on the executors. If the
+ broadcast is used after this is called, it will need to be
+ re-sent to each executor.
+
+ :param blocking: Whether to block until unpersisting has completed
"""
if self._jbroadcast is None:
raise Exception("Broadcast can only be unpersisted in driver")
self._jbroadcast.unpersist(blocking)
+
+ def destroy(self):
+ """
+ Destroy all data and metadata related to this broadcast variable.
+ Use this with caution; once a broadcast variable has been destroyed,
+ it cannot be used again. This method blocks until destroy has
+ completed.
+ """
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be destroyed in driver")
+ self._jbroadcast.destroy()
os.unlink(self._path)
def __reduce__(self):
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 40fccb8c0090c..15c87e22f98b0 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -694,6 +694,21 @@ def test_large_broadcast(self):
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEqual(N, m)
+ def test_unpersist(self):
+ N = 1000
+ data = [[float(i) for i in range(300)] for i in range(N)]
+ bdata = self.sc.broadcast(data) # 3MB
+ bdata.unpersist()
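+ # After unpersist(), the broadcast value is re-sent to each executor the next time it is used, so this job still succeeds.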
+ m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ self.assertEqual(N, m)
+ bdata.destroy()
+ try:
+ self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
+ except Exception:
+ pass
+ else:
+ raise Exception("job should fail after destroying the broadcast")
+
def test_multiple_broadcasts(self):
N = 1 << 21
b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 37d2ecf48ec02..33c2cbd293533 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -116,7 +116,7 @@ This file is divided into 3 sections:
-
+
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
deleted file mode 100644
index 01f89112a759b..0000000000000
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.parser;
-
-import java.nio.charset.StandardCharsets;
-
-/**
- * A couple of utility methods that help with parsing ASTs.
- *
- * The 'unescapeSQLString' method in this class was take from the SemanticAnalyzer in Hive:
- * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java
- */
-public final class ParseUtils {
- private ParseUtils() {
- super();
- }
-
- private static final int[] multiplier = new int[] {1000, 100, 10, 1};
-
- @SuppressWarnings("nls")
- public static String unescapeSQLString(String b) {
- Character enclosure = null;
-
- // Some of the strings can be passed in as unicode. For example, the
- // delimiter can be passed in as \002 - So, we first check if the
- // string is a unicode number, else go back to the old behavior
- StringBuilder sb = new StringBuilder(b.length());
- for (int i = 0; i < b.length(); i++) {
-
- char currentChar = b.charAt(i);
- if (enclosure == null) {
- if (currentChar == '\'' || b.charAt(i) == '\"') {
- enclosure = currentChar;
- }
- // ignore all other chars outside the enclosure
- continue;
- }
-
- if (enclosure.equals(currentChar)) {
- enclosure = null;
- continue;
- }
-
- if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') {
- int code = 0;
- int base = i + 2;
- for (int j = 0; j < 4; j++) {
- int digit = Character.digit(b.charAt(j + base), 16);
- code += digit * multiplier[j];
- }
- sb.append((char)code);
- i += 5;
- continue;
- }
-
- if (currentChar == '\\' && (i + 4 < b.length())) {
- char i1 = b.charAt(i + 1);
- char i2 = b.charAt(i + 2);
- char i3 = b.charAt(i + 3);
- if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7')
- && (i3 >= '0' && i3 <= '7')) {
- byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8));
- byte[] bValArr = new byte[1];
- bValArr[0] = bVal;
- String tmp = new String(bValArr, StandardCharsets.UTF_8);
- sb.append(tmp);
- i += 3;
- continue;
- }
- }
-
- if (currentChar == '\\' && (i + 2 < b.length())) {
- char n = b.charAt(i + 1);
- switch (n) {
- case '0':
- sb.append("\0");
- break;
- case '\'':
- sb.append("'");
- break;
- case '"':
- sb.append("\"");
- break;
- case 'b':
- sb.append("\b");
- break;
- case 'n':
- sb.append("\n");
- break;
- case 'r':
- sb.append("\r");
- break;
- case 't':
- sb.append("\t");
- break;
- case 'Z':
- sb.append("\u001A");
- break;
- case '\\':
- sb.append("\\");
- break;
- // The following 2 lines are exactly what MySQL does TODO: why do we do this?
- case '%':
- sb.append("\\%");
- break;
- case '_':
- sb.append("\\_");
- break;
- default:
- sb.append(n);
- }
- i++;
- } else {
- sb.append(currentChar);
- }
- }
- return sb.toString();
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
index 0c7cd408dfc40..186bbccef1204 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-
+import org.apache.spark.sql.catalyst.util.StringUtils
/**
* An in-memory (ephemeral) implementation of the system catalog.
@@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog {
// Database name -> description
private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc]
- private def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
- val regex = pattern.replaceAll("\\*", ".*").r
- names.filter { funcName => regex.pattern.matcher(funcName).matches() }
- }
-
private def partitionExists(db: String, table: String, spec: TablePartitionSpec): Boolean = {
requireTableExists(db, table)
catalog(db).tables(table).partitions.contains(spec)
@@ -136,7 +131,7 @@ class InMemoryCatalog extends ExternalCatalog {
}
override def listDatabases(pattern: String): Seq[String] = synchronized {
- filterPattern(listDatabases(), pattern)
+ StringUtils.filterPattern(listDatabases(), pattern)
}
override def setCurrentDatabase(db: String): Unit = { /* no-op */ }
@@ -203,7 +198,7 @@ class InMemoryCatalog extends ExternalCatalog {
}
override def listTables(db: String, pattern: String): Seq[String] = synchronized {
- filterPattern(listTables(db), pattern)
+ StringUtils.filterPattern(listTables(db), pattern)
}
// --------------------------------------------------------------------------
@@ -322,7 +317,7 @@ class InMemoryCatalog extends ExternalCatalog {
override def listFunctions(db: String, pattern: String): Seq[String] = synchronized {
requireDbExists(db)
- filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
+ StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 4825863ea92ec..7db9fd0527ec4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionE
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
-
+import org.apache.spark.sql.catalyst.util.StringUtils
/**
* An internal catalog that is used by a Spark Session. This internal catalog serves as a
@@ -297,9 +297,7 @@ class SessionCatalog(
def listTables(db: String, pattern: String): Seq[TableIdentifier] = {
val dbTables =
externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) }
- val regex = pattern.replaceAll("\\*", ".*").r
- val _tempTables = tempTables.keys.toSeq
- .filter { t => regex.pattern.matcher(t).matches() }
+ val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern)
.map { t => TableIdentifier(t) }
dbTables ++ _tempTables
}
@@ -610,9 +608,7 @@ class SessionCatalog(
def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = {
val dbFunctions =
externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) }
- val regex = pattern.replaceAll("\\*", ".*").r
- val loadedFunctions = functionRegistry.listFunction()
- .filter { f => regex.pattern.matcher(f).matches() }
+ val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern)
.map { f => FunctionIdentifier(f) }
// TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry.
// So, the returned list may have two entries for the same function.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
index 90b76dc314a54..cb9fefec8f482 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala
@@ -16,11 +16,12 @@
*/
package org.apache.spark.sql.catalyst.parser
+import scala.collection.mutable.StringBuilder
+
import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token}
import org.antlr.v4.runtime.misc.Interval
import org.antlr.v4.runtime.tree.TerminalNode
-import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
@@ -87,6 +88,81 @@ object ParserUtils {
}
}
+ /** Unescape backslash-escaped string enclosed by quotes. */
+ def unescapeSQLString(b: String): String = {
+ var enclosure: Character = null
+ val sb = new StringBuilder(b.length())
+
+ def appendEscapedChar(n: Char) {
+ n match {
+ case '0' => sb.append('\u0000')
+ case '\'' => sb.append('\'')
+ case '"' => sb.append('\"')
+ case 'b' => sb.append('\b')
+ case 'n' => sb.append('\n')
+ case 'r' => sb.append('\r')
+ case 't' => sb.append('\t')
+ case 'Z' => sb.append('\u001A')
+ case '\\' => sb.append('\\')
+ // The following 2 lines are exactly what MySQL does. TODO: why do we do this?
+ case '%' => sb.append("\\%")
+ case '_' => sb.append("\\_")
+ case _ => sb.append(n)
+ }
+ }
+
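+ // Scan the string while tracking the enclosing quote character; characters outside the quotes are dropped.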
+ var i = 0
+ val strLength = b.length
+ while (i < strLength) {
+ val currentChar = b.charAt(i)
+ if (enclosure == null) {
+ if (currentChar == '\'' || currentChar == '\"') {
+ enclosure = currentChar
+ }
+ } else if (enclosure == currentChar) {
+ enclosure = null
+ } else if (currentChar == '\\') {
+
+ if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') {
+ // \u0000 style character literals.
+
+ val base = i + 2
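+ // Combine the next four hex digits into a single 16-bit character code.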
+ val code = (0 until 4).foldLeft(0) { (mid, j) =>
+ val digit = Character.digit(b.charAt(j + base), 16)
+ (mid << 4) + digit
+ }
+ sb.append(code.asInstanceOf[Char])
+ i += 5
+ } else if (i + 4 < strLength) {
+ // \000 style character literals.
+
+ val i1 = b.charAt(i + 1)
+ val i2 = b.charAt(i + 2)
+ val i3 = b.charAt(i + 3)
+
+ if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) {
+ val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char]
+ sb.append(tmp)
+ i += 3
+ } else {
+ appendEscapedChar(i1)
+ i += 1
+ }
+ } else if (i + 2 < strLength) {
+ // escaped character literals.
+ val n = b.charAt(i + 1)
+ appendEscapedChar(n)
+ i += 1
+ }
+ } else {
+ // non-escaped character literals.
+ sb.append(currentChar)
+ }
+ i += 1
+ }
+ sb.toString()
+ }
+
/** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */
implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal {
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
index c2eeb3c5650ab..0f65028261b85 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.util
-import java.util.regex.Pattern
+import java.util.regex.{Pattern, PatternSyntaxException}
import org.apache.spark.unsafe.types.UTF8String
@@ -52,4 +52,25 @@ object StringUtils {
def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase)
def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase)
+
+ /**
+ * This utility can be used to filter patterns in the "LIKE" clause of "SHOW TABLES / FUNCTIONS" DDL
+ * @param names the names list to be filtered
+ * @param pattern the filter pattern; only '*' and '|' are allowed as wildcards, the rest
+ * follows regular expression conventions; matching is case insensitive and white space
+ * on both ends is ignored
+ * @return the filtered names list in order
+ */
+ def filterPattern(names: Seq[String], pattern: String): Seq[String] = {
+ val funcNames = scala.collection.mutable.SortedSet.empty[String]
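+ // Each '|'-separated sub-pattern becomes a case-insensitive regex with '*' expanded to '.*';
+ // invalid sub-patterns are skipped, and the sorted set de-duplicates and orders the matches.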
+ pattern.trim().split("\\|").foreach { subPattern =>
+ try {
+ val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r
+ funcNames ++= names.filter { name => regex.pattern.matcher(name).matches() }
+ } catch {
+ case _: PatternSyntaxException =>
+ }
+ }
+ funcNames.toSeq
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index a80d29ce5dcb3..6f40ec67ec6e0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -415,7 +415,7 @@ class ExpressionParserSuite extends PlanTest {
assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!")
// Unicode
- assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)")
+ assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)")
}
test("intervals") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
new file mode 100644
index 0000000000000..d090daf7b41eb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.SparkFunSuite
+
+class ParserUtilsSuite extends SparkFunSuite {
+
+ import ParserUtils._
+
+ test("unescapeSQLString") {
+ // scalastyle:off nonascii
+
+ // String not including escaped characters and enclosed by double quotes.
+ assert(unescapeSQLString(""""abcdefg"""") == "abcdefg")
+
+ // String enclosed by single quotes.
+ assert(unescapeSQLString("""'C0FFEE'""") == "C0FFEE")
+
+ // Strings including single escaped characters.
+ assert(unescapeSQLString("""'\0'""") == "\u0000")
+ assert(unescapeSQLString(""""\'"""") == "\'")
+ assert(unescapeSQLString("""'\"'""") == "\"")
+ assert(unescapeSQLString(""""\b"""") == "\b")
+ assert(unescapeSQLString("""'\n'""") == "\n")
+ assert(unescapeSQLString(""""\r"""") == "\r")
+ assert(unescapeSQLString("""'\t'""") == "\t")
+ assert(unescapeSQLString(""""\Z"""") == "\u001A")
+ assert(unescapeSQLString("""'\\'""") == "\\")
+ assert(unescapeSQLString(""""\%"""") == "\\%")
+ assert(unescapeSQLString("""'\_'""") == "\\_")
+
+ // String including '\000' style literal characters.
+ assert(unescapeSQLString("""'3 + 5 = \070'""") == "3 + 5 = \u0038")
+ assert(unescapeSQLString(""""\000"""") == "\u0000")
+
+ // String including invalid '\000' style literal characters.
+ assert(unescapeSQLString(""""\256"""") == "256")
+
+ // String including a '\u0000' style literal character (\u732B is a cat in Kanji).
+ assert(unescapeSQLString(""""How cute \u732B are"""") == "How cute \u732B are")
+
+ // String including a surrogate pair character
+ // (\uD867\uDE3D is Okhotsk atka mackerel in Kanji).
+ assert(unescapeSQLString(""""\uD867\uDE3D is a fish"""") == "\uD867\uDE3D is a fish")
+
+ // scalastyle:on nonascii
+ }
+
+ // TODO: Add test cases for other methods in ParserUtils
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
index d6f273f9e568a..2ffc18a8d14fb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala
@@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite {
assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E")
assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E")
}
+
+ test("filter pattern") {
+ val names = Seq("a1", "a2", "b2", "c3")
+ assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3"))
+ assert(filterPattern(names, "*a*") === Seq("a1", "a2"))
+ assert(filterPattern(names, " *a* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " a* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " a.* ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2"))
+ assert(filterPattern(names, " a. ") === Seq("a1", "a2"))
+ assert(filterPattern(names, " d* ") === Nil)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index 3332a997cda90..54d250867fbb3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -29,8 +29,10 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project}
import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource}
import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
-import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.HadoopFsRelation
+import org.apache.spark.util.Utils
/**
* :: Experimental ::
@@ -275,23 +277,64 @@ final class DataFrameWriter private[sql](df: DataFrame) {
* @since 2.0.0
*/
def startStream(): ContinuousQuery = {
- val dataSource =
- DataSource(
- df.sqlContext,
- className = source,
- options = extraOptions.toMap,
- partitionColumns = normalizedParCols.getOrElse(Nil))
-
- val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
- val checkpointLocation = extraOptions.getOrElse("checkpointLocation", {
- new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString
- })
- df.sqlContext.sessionState.continuousQueryManager.startQuery(
- queryName,
- checkpointLocation,
- df,
- dataSource.createSink(),
- trigger)
+ if (source == "memory") {
+ val queryName =
+ extraOptions.getOrElse(
+ "queryName", throw new AnalysisException("queryName must be specified for memory sink"))
+ val checkpointLocation = extraOptions.get("checkpointLocation").map { userSpecified =>
+ new Path(userSpecified).toUri.toString
+ }.orElse {
+ val checkpointConfig: Option[String] =
+ df.sqlContext.conf.getConf(
+ SQLConf.CHECKPOINT_LOCATION,
+ None)
+
+ checkpointConfig.map { location =>
+ new Path(location, queryName).toUri.toString
+ }
+ }.getOrElse {
+ Utils.createTempDir(namePrefix = "memory.stream").getCanonicalPath
+ }
+
+ // If offsets have already been created, we are trying to resume a query.
+ val checkpointPath = new Path(checkpointLocation, "offsets")
+ val fs = checkpointPath.getFileSystem(df.sqlContext.sparkContext.hadoopConfiguration)
+ if (fs.exists(checkpointPath)) {
+ throw new AnalysisException(
+ s"Unable to resume query written to memory sink. Delete $checkpointPath to start over.")
+ } else {
+ checkpointPath.toUri.toString
+ }
+
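+ // Buffer the streaming output in a MemorySink and expose it as a temporary table backed by a MemoryPlan.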
+ val sink = new MemorySink(df.schema)
+ val resultDf = Dataset.ofRows(df.sqlContext, new MemoryPlan(sink))
+ resultDf.registerTempTable(queryName)
+ val continuousQuery = df.sqlContext.sessionState.continuousQueryManager.startQuery(
+ queryName,
+ checkpointLocation,
+ df,
+ sink,
+ trigger)
+ continuousQuery
+ } else {
+ val dataSource =
+ DataSource(
+ df.sqlContext,
+ className = source,
+ options = extraOptions.toMap,
+ partitionColumns = normalizedParCols.getOrElse(Nil))
+
+ val queryName = extraOptions.getOrElse("queryName", StreamExecution.nextName)
+ val checkpointLocation = extraOptions.getOrElse("checkpointLocation", {
+ new Path(df.sqlContext.conf.checkpointLocation, queryName).toUri.toString
+ })
+ df.sqlContext.sessionState.continuousQueryManager.startQuery(
+ queryName,
+ checkpointLocation,
+ df,
+ dataSource.createSink(),
+ trigger)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 5f3128d8e42d3..d77aba726098e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
@@ -30,6 +31,7 @@ import org.apache.spark.sql.execution.command.{DescribeCommand => RunnableDescri
import org.apache.spark.sql.execution.datasources.{DescribeCommand => LogicalDescribeCommand, _}
import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight}
+import org.apache.spark.sql.execution.streaming.MemoryPlan
import org.apache.spark.sql.internal.SQLConf
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
@@ -332,6 +334,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case r: RunnableCommand => ExecutedCommand(r) :: Nil
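+ // Serve rows buffered in a MemorySink with a local scan, converting each Row back to an InternalRow.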
+ case MemoryPlan(sink, output) =>
+ val encoder = RowEncoder(sink.schema)
+ LocalTableScan(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil
+
case logical.Distinct(child) =>
throw new IllegalStateException(
"logical distinct operator should have been replaced by aggregate in the optimizer")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
index b652530d7c78c..351ef404a8e39 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala
@@ -25,6 +25,8 @@ import scala.util.control.NonFatal
import org.apache.spark.internal.Logging
import org.apache.spark.sql.{DataFrame, Dataset, Encoder, Row, SQLContext}
import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LeafNode
import org.apache.spark.sql.types.StructType
object MemoryStream {
@@ -136,3 +138,9 @@ class MemorySink(val schema: StructType) extends Sink with Logging {
}
}
+/**
+ * Used to query the data that has been written into a [[MemorySink]].
+ */
+case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode {
+ def this(sink: MemorySink) = this(sink, sink.schema.toAttributes)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 4e62fac919f5e..48a077d0e551a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.streaming.MemoryPlan
abstract class QueryTest extends PlanTest {
@@ -200,6 +201,7 @@ abstract class QueryTest extends PlanTest {
logicalPlan.transform {
case _: ObjectOperator => return
case _: LogicalRelation => return
+ case _: MemoryPlan => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5a851b47caf87..2ab7c1581cfa0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin}
import org.apache.spark.sql.functions._
@@ -56,17 +57,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
test("show functions") {
def getFunctions(pattern: String): Seq[Row] = {
- val regex = java.util.regex.Pattern.compile(pattern)
- sqlContext.sessionState.functionRegistry.listFunction()
- .filter(regex.matcher(_).matches()).map(Row(_))
+ StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern)
+ .map(Row(_))
}
- checkAnswer(sql("SHOW functions"), getFunctions(".*"))
+ checkAnswer(sql("SHOW functions"), getFunctions("*"))
Seq("^c*", "*e$", "log*", "*date*").foreach { pattern =>
// For the pattern part, only '*' and '|' are allowed as wildcards.
// For '*', we need to replace it to '.*'.
- checkAnswer(
- sql(s"SHOW FUNCTIONS '$pattern'"),
- getFunctions(pattern.replaceAll("\\*", ".*")))
+ checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
new file mode 100644
index 0000000000000..5249aa28dd6ca
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.spark.sql.{AnalysisException, Row, StreamTest}
+import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.util.Utils
+
+class MemorySinkSuite extends StreamTest with SharedSQLContext {
+ import testImplicits._
+
+ test("registering as a table") {
+ val input = MemoryStream[Int]
+ val query = input.toDF().write
+ .format("memory")
+ .queryName("memStream")
+ .startStream()
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+
+ checkDataset(
+ sqlContext.table("memStream").as[Int],
+ 1, 2, 3)
+
+ input.addData(4, 5, 6)
+ query.processAllAvailable()
+ checkDataset(
+ sqlContext.table("memStream").as[Int],
+ 1, 2, 3, 4, 5, 6)
+
+ query.stop()
+ }
+
+ test("error when no name is specified") {
+ val error = intercept[AnalysisException] {
+ val input = MemoryStream[Int]
+ val query = input.toDF().write
+ .format("memory")
+ .startStream()
+ }
+
+ assert(error.message contains "queryName must be specified")
+ }
+
+ test("error if attempting to resume specific checkpoint") {
+ val location = Utils.createTempDir("steaming.checkpoint").getCanonicalPath
+
+ val input = MemoryStream[Int]
+ val query = input.toDF().write
+ .format("memory")
+ .queryName("memStream")
+ .option("checkpointLocation", location)
+ .startStream()
+ input.addData(1, 2, 3)
+ query.processAllAvailable()
+ query.stop()
+
+ intercept[AnalysisException] {
+ input.toDF().write
+ .format("memory")
+ .queryName("memStream")
+ .option("checkpointLocation", location)
+ .startStream()
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
index b4e5d4adf1728..c5f01da4fabbb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala
@@ -263,8 +263,8 @@ class HiveDDLCommandSuite extends PlanTest {
assert(desc.tableType == CatalogTableType.VIRTUAL_VIEW)
assert(desc.storage.locationUri.isEmpty)
assert(desc.schema == Seq.empty[CatalogColumn])
- assert(desc.viewText.contains("SELECT * FROM tab1"))
- assert(desc.viewOriginalText.contains("SELECT * FROM tab1"))
+ assert(desc.viewText == Option("SELECT * FROM tab1"))
+ assert(desc.viewOriginalText == Option("SELECT * FROM tab1"))
assert(desc.storage.serdeProperties == Map())
assert(desc.storage.inputFormat.isEmpty)
assert(desc.storage.outputFormat.isEmpty)
@@ -290,8 +290,8 @@ class HiveDDLCommandSuite extends PlanTest {
assert(desc.schema ==
CatalogColumn("col1", null, nullable = true, None) ::
CatalogColumn("col3", null, nullable = true, None) :: Nil)
- assert(desc.viewText.contains("SELECT * FROM tab1"))
- assert(desc.viewOriginalText.contains("SELECT * FROM tab1"))
+ assert(desc.viewText == Option("SELECT * FROM tab1"))
+ assert(desc.viewOriginalText == Option("SELECT * FROM tab1"))
assert(desc.storage.serdeProperties == Map())
assert(desc.storage.inputFormat.isEmpty)
assert(desc.storage.outputFormat.isEmpty)