cmd = new ArrayList<>();
- String envJavaHome;
- if (javaHome != null) {
- cmd.add(join(File.separator, javaHome, "bin", "java"));
- } else if ((envJavaHome = System.getenv("JAVA_HOME")) != null) {
- cmd.add(join(File.separator, envJavaHome, "bin", "java"));
- } else {
- cmd.add(join(File.separator, System.getProperty("java.home"), "bin", "java"));
+ String[] candidateJavaHomes = new String[] {
+ javaHome,
+ childEnv.get("JAVA_HOME"),
+ System.getenv("JAVA_HOME"),
+ System.getProperty("java.home")
+ };
+ for (String javaHome : candidateJavaHomes) {
+ if (javaHome != null) {
+ cmd.add(join(File.separator, javaHome, "bin", "java"));
+ break;
+ }
}
// Load extra JAVA_OPTS from conf/java-opts, if it exists.
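The new lookup collapses the old if/else chain into a fixed priority order: the explicitly configured javaHome, then a JAVA_HOME entry in the child process environment, then the launcher's own JAVA_HOME, and finally the JVM's java.home property. A minimal Scala sketch of the same resolution order (a standalone helper for illustration, not the launcher API itself):

    import java.io.File

    // Return the path to the java executable from the first non-null candidate,
    // mirroring the candidate order in the diff above.
    def resolveJavaExecutable(
        javaHome: String,                              // value configured on the builder, may be null
        childEnv: java.util.Map[String, String]): String = {
      val candidates = Seq(
        javaHome,
        childEnv.get("JAVA_HOME"),
        System.getenv("JAVA_HOME"),
        System.getProperty("java.home"))
      val home = candidates.find(_ != null).get        // java.home is always defined
      Seq(home, "bin", "java").mkString(File.separator)
    }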
diff --git a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java
index 5609f8492f4f4..7dfcf0e66734a 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/ChildProcAppHandle.java
@@ -18,6 +18,7 @@
package org.apache.spark.launcher;
import java.io.InputStream;
+import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -29,7 +30,7 @@ class ChildProcAppHandle extends AbstractAppHandle {
private static final Logger LOG = Logger.getLogger(ChildProcAppHandle.class.getName());
private volatile Process childProc;
- private OutputRedirector redirector;
+ private volatile OutputRedirector redirector;
ChildProcAppHandle(LauncherServer server) {
super(server);
@@ -46,6 +47,23 @@ public synchronized void disconnect() {
}
}
+ /**
+ * Parses the logs of {@code spark-submit} and returns the last exception thrown.
+ *
+ * Since {@link SparkLauncher} runs {@code spark-submit} in a sub-process, it's difficult to
+ * accurately retrieve the full {@link Throwable} from the {@code spark-submit} process.
+ * This method parses the logs of the sub-process and provides a best-effort attempt at
+ * returning the last exception thrown by the {@code spark-submit} process. Only the exception
+ * message is parsed; the associated stack trace is not meaningful.
+ *
+ * @return an {@link Optional} containing a {@link RuntimeException} with the parsed
+ * exception, otherwise an empty {@link Optional}
+ */
+ @Override
+ public Optional<Throwable> getError() {
+ return redirector != null ? Optional.ofNullable(redirector.getError()) : Optional.empty();
+ }
+
@Override
public synchronized void kill() {
if (!isDisposed()) {
diff --git a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java
index 15fbca0facef2..ba09050c756d2 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/InProcessAppHandle.java
@@ -17,7 +17,9 @@
package org.apache.spark.launcher;
+import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
+import java.util.Optional;
import java.util.concurrent.atomic.AtomicLong;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -31,6 +33,8 @@ class InProcessAppHandle extends AbstractAppHandle {
// Avoid really long thread names.
private static final int MAX_APP_NAME_LEN = 16;
+ private volatile Throwable error;
+
private Thread app;
InProcessAppHandle(LauncherServer server) {
@@ -51,6 +55,11 @@ public synchronized void kill() {
}
}
+ @Override
+ public Optional<Throwable> getError() {
+ return Optional.ofNullable(error);
+ }
+
synchronized void start(String appName, Method main, String[] args) {
CommandBuilderUtils.checkState(app == null, "Handle already started.");
@@ -62,7 +71,11 @@ synchronized void start(String appName, Method main, String[] args) {
try {
main.invoke(null, (Object) args);
} catch (Throwable t) {
+ if (t instanceof InvocationTargetException) {
+ t = t.getCause();
+ }
LOG.log(Level.WARNING, "Application failed with exception.", t);
+ error = t;
setState(State.FAILED);
}
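Because the in-process handle launches the application's main method via reflection, any exception the application throws reaches this catch block wrapped in an InvocationTargetException; unwrapping getCause() makes getError() report the application's own exception rather than the reflection wrapper. A small Scala sketch of the wrapping behaviour (the Failing object is hypothetical):

    import java.lang.reflect.InvocationTargetException

    object Failing {
      def main(args: Array[String]): Unit = throw new IllegalStateException("boom")
    }

    object ReflectionDemo {
      def main(args: Array[String]): Unit = {
        val m = Failing.getClass.getMethod("main", classOf[Array[String]])
        val argv: AnyRef = Array.empty[String]    // pass the whole array as one reflective argument
        try m.invoke(Failing, argv) catch {
          case e: InvocationTargetException =>
            println(e.getCause.getMessage)        // prints "boom", the application's real exception
        }
      }
    }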
diff --git a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java
index 6f4b0bb38e031..0f097f8313925 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/OutputRedirector.java
@@ -37,6 +37,7 @@ class OutputRedirector {
private final ChildProcAppHandle callback;
private volatile boolean active;
+ private volatile Throwable error;
OutputRedirector(InputStream in, String loggerName, ThreadFactory tf) {
this(in, loggerName, tf, null);
@@ -61,6 +62,10 @@ private void redirect() {
while ((line = reader.readLine()) != null) {
if (active) {
sink.info(line.replaceFirst("\\s*$", ""));
+ if ((containsIgnoreCase(line, "Error") || containsIgnoreCase(line, "Exception")) &&
+ !line.contains("at ")) {
+ error = new RuntimeException(line);
+ }
}
}
} catch (IOException e) {
@@ -85,4 +90,24 @@ boolean isAlive() {
return thread.isAlive();
}
+ Throwable getError() {
+ return error;
+ }
+
+ /**
+ * Copied from Apache Commons Lang {@code StringUtils#containsIgnoreCase(String, String)}
+ */
+ private static boolean containsIgnoreCase(String str, String searchStr) {
+ if (str == null || searchStr == null) {
+ return false;
+ }
+ int len = searchStr.length();
+ int max = str.length() - len;
+ for (int i = 0; i <= max; i++) {
+ if (str.regionMatches(true, i, searchStr, 0, len)) {
+ return true;
+ }
+ }
+ return false;
+ }
}
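The redirector now remembers the last log line that looks like an error: the line must contain "Error" or "Exception" (case-insensitively) and must not look like a stack-frame line containing "at ". A hedged Scala sketch of the same heuristic with assumed sample log lines (using toLowerCase as a simpler stand-in for the copied containsIgnoreCase helper):

    // Returns true for lines the redirector would record as the latest error.
    def looksLikeErrorLine(line: String): Boolean = {
      val lower = line.toLowerCase
      (lower.contains("error") || lower.contains("exception")) && !line.contains("at ")
    }

    looksLikeErrorLine("Exception in thread \"main\" java.lang.IllegalStateException: boom")  // true
    looksLikeErrorLine("\tat org.apache.spark.deploy.SparkSubmit.main(SparkSubmit.scala)")    // false, stack frame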
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java
index cefb4d1a95fb6..afec270e2b11c 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkAppHandle.java
@@ -17,6 +17,8 @@
package org.apache.spark.launcher;
+import java.util.Optional;
+
/**
* A handle to a running Spark application.
*
@@ -100,6 +102,12 @@ public boolean isFinal() {
*/
void disconnect();
+ /**
+ * If the application failed due to an error, returns the underlying error. If the app
+ * succeeded, this method returns an empty {@link Optional}.
+ */
+ Optional<Throwable> getError();
+
/**
* Listener for updates to a handle's state. The callbacks do not receive information about
* what exactly has changed, just that an update has occurred.
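With getError() on the handle interface, callers of the launcher can surface the best-effort error after a failed run. A hedged usage sketch (the app resource and main class are placeholders, and the wait-for-final-state logic is elided):

    import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

    val handle: SparkAppHandle = new SparkLauncher()
      .setAppResource("/path/to/example-app.jar")     // placeholder artifact
      .setMainClass("com.example.Main")               // placeholder class
      .startApplication()

    // ... block until handle.getState.isFinal ...

    if (handle.getState == SparkAppHandle.State.FAILED) {
      handle.getError.ifPresent(t => println(s"spark-submit failed: ${t.getMessage}"))
    }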
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 27a7db0b2f5d4..f2a5c11a34867 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -934,8 +934,8 @@ class LogisticRegressionModel private[spark] (
@Since("2.1.0") val interceptVector: Vector,
@Since("1.3.0") override val numClasses: Int,
private val isMultinomial: Boolean)
- extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
- with LogisticRegressionParams with MLWritable {
+ extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel] with MLWritable
+ with LogisticRegressionParams with HasTrainingSummary[LogisticRegressionTrainingSummary] {
require(coefficientMatrix.numRows == interceptVector.size, s"Dimension mismatch! Expected " +
s"coefficientMatrix.numRows == interceptVector.size, but ${coefficientMatrix.numRows} != " +
@@ -1018,20 +1018,16 @@ class LogisticRegressionModel private[spark] (
@Since("1.6.0")
override val numFeatures: Int = coefficientMatrix.numCols
- private var trainingSummary: Option[LogisticRegressionTrainingSummary] = None
-
/**
* Gets summary of model on training set. An exception is thrown
- * if `trainingSummary == None`.
+ * if `hasSummary` is false.
*/
@Since("1.5.0")
- def summary: LogisticRegressionTrainingSummary = trainingSummary.getOrElse {
- throw new SparkException("No training summary available for this LogisticRegressionModel")
- }
+ override def summary: LogisticRegressionTrainingSummary = super.summary
/**
* Gets summary of model on training set. An exception is thrown
- * if `trainingSummary == None` or it is a multiclass model.
+ * if `hasSummary` is false or it is a multiclass model.
*/
@Since("2.3.0")
def binarySummary: BinaryLogisticRegressionTrainingSummary = summary match {
@@ -1062,16 +1058,6 @@ class LogisticRegressionModel private[spark] (
(model, model.getProbabilityCol, model.getPredictionCol)
}
- private[classification]
- def setSummary(summary: Option[LogisticRegressionTrainingSummary]): this.type = {
- this.trainingSummary = summary
- this
- }
-
- /** Indicates whether a training summary exists for this model instance. */
- @Since("1.5.0")
- def hasSummary: Boolean = trainingSummary.isDefined
-
/**
* Evaluates the model on a test dataset.
*
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 1a94aefa3f563..49e9f51368131 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -87,8 +87,9 @@ private[clustering] trait BisectingKMeansParams extends Params with HasMaxIter
@Since("2.0.0")
class BisectingKMeansModel private[ml] (
@Since("2.0.0") override val uid: String,
- private val parentModel: MLlibBisectingKMeansModel
- ) extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable {
+ private val parentModel: MLlibBisectingKMeansModel)
+ extends Model[BisectingKMeansModel] with BisectingKMeansParams with MLWritable
+ with HasTrainingSummary[BisectingKMeansSummary] {
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
@@ -143,28 +144,12 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
- private var trainingSummary: Option[BisectingKMeansSummary] = None
-
- private[clustering] def setSummary(summary: Option[BisectingKMeansSummary]): this.type = {
- this.trainingSummary = summary
- this
- }
-
- /**
- * Return true if there exists summary of model.
- */
- @Since("2.1.0")
- def hasSummary: Boolean = trainingSummary.nonEmpty
-
/**
* Gets summary of model on training set. An exception is
- * thrown if `trainingSummary == None`.
+ * thrown if `hasSummary` is false.
*/
@Since("2.1.0")
- def summary: BisectingKMeansSummary = trainingSummary.getOrElse {
- throw new SparkException(
- s"No training summary available for the ${this.getClass.getSimpleName}")
- }
+ override def summary: BisectingKMeansSummary = super.summary
}
object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 88abc1605d69f..bb10b3228b93f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -85,7 +85,8 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0") override val uid: String,
@Since("2.0.0") val weights: Array[Double],
@Since("2.0.0") val gaussians: Array[MultivariateGaussian])
- extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
+ extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable
+ with HasTrainingSummary[GaussianMixtureSummary] {
/** @group setParam */
@Since("2.1.0")
@@ -160,28 +161,13 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
- private var trainingSummary: Option[GaussianMixtureSummary] = None
-
- private[clustering] def setSummary(summary: Option[GaussianMixtureSummary]): this.type = {
- this.trainingSummary = summary
- this
- }
-
- /**
- * Return true if there exists summary of model.
- */
- @Since("2.0.0")
- def hasSummary: Boolean = trainingSummary.nonEmpty
-
/**
* Gets summary of model on training set. An exception is
- * thrown if `trainingSummary == None`.
+ * thrown if `hasSummary` is false.
*/
@Since("2.0.0")
- def summary: GaussianMixtureSummary = trainingSummary.getOrElse {
- throw new RuntimeException(
- s"No training summary available for the ${this.getClass.getSimpleName}")
- }
+ override def summary: GaussianMixtureSummary = super.summary
+
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 2eed84d51782a..319747d4a1930 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -107,7 +107,8 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
class KMeansModel private[ml] (
@Since("1.5.0") override val uid: String,
private[clustering] val parentModel: MLlibKMeansModel)
- extends Model[KMeansModel] with KMeansParams with GeneralMLWritable {
+ extends Model[KMeansModel] with KMeansParams with GeneralMLWritable
+ with HasTrainingSummary[KMeansSummary] {
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
@@ -153,28 +154,12 @@ class KMeansModel private[ml] (
@Since("1.6.0")
override def write: GeneralMLWriter = new GeneralMLWriter(this)
- private var trainingSummary: Option[KMeansSummary] = None
-
- private[clustering] def setSummary(summary: Option[KMeansSummary]): this.type = {
- this.trainingSummary = summary
- this
- }
-
- /**
- * Return true if there exists summary of model.
- */
- @Since("2.0.0")
- def hasSummary: Boolean = trainingSummary.nonEmpty
-
/**
* Gets summary of model on training set. An exception is
- * thrown if `trainingSummary == None`.
+ * thrown if `hasSummary` is false.
*/
@Since("2.0.0")
- def summary: KMeansSummary = trainingSummary.getOrElse {
- throw new SparkException(
- s"No training summary available for the ${this.getClass.getSimpleName}")
- }
+ override def summary: KMeansSummary = super.summary
}
/** Helper class for storing model data */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
index 1b9a3499947d9..d9a330f67e8dc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala
@@ -97,8 +97,8 @@ private[clustering] trait PowerIterationClusteringParams extends Params with Has
/**
* :: Experimental ::
* Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
- * Lin and Cohen. From the abstract:
- * PIC finds a very low-dimensional embedding of a dataset using truncated power
+ * Lin and Cohen. From
+ * the abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power
* iteration on a normalized pair-wise similarity matrix of the data.
*
* This class is not yet an Estimator/Transformer, use `assignClusters` method to run the
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 031cd0d635bf4..616569bb55e4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{Dataset, Row}
@@ -33,7 +33,8 @@ import org.apache.spark.sql.types.{DoubleType, FloatType}
@Since("1.4.0")
@Experimental
final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val uid: String)
- extends Evaluator with HasPredictionCol with HasLabelCol with DefaultParamsWritable {
+ extends Evaluator with HasPredictionCol with HasLabelCol
+ with HasWeightCol with DefaultParamsWritable {
@Since("1.4.0")
def this() = this(Identifiable.randomUID("regEval"))
@@ -69,6 +70,10 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
def setLabelCol(value: String): this.type = set(labelCol, value)
+ /** @group setParam */
+ @Since("3.0.0")
+ def setWeightCol(value: String): this.type = set(weightCol, value)
+
setDefault(metricName -> "rmse")
@Since("2.0.0")
@@ -77,11 +82,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
SchemaUtils.checkColumnTypes(schema, $(predictionCol), Seq(DoubleType, FloatType))
SchemaUtils.checkNumericType(schema, $(labelCol))
- val predictionAndLabels = dataset
- .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
+ val predictionAndLabelsWithWeights = dataset
+ .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType),
+ if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol)))
.rdd
- .map { case Row(prediction: Double, label: Double) => (prediction, label) }
- val metrics = new RegressionMetrics(predictionAndLabels)
+ .map { case Row(prediction: Double, label: Double, weight: Double) =>
+ (prediction, label, weight) }
+ val metrics = new RegressionMetrics(predictionAndLabelsWithWeights)
val metric = $(metricName) match {
case "rmse" => metrics.rootMeanSquaredError
case "mse" => metrics.meanSquaredError
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
index 2a3413553a6af..b0006a8d4a58e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/PrefixSpan.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.fpm
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
+import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.fpm.{PrefixSpan => mllibPrefixSpan}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.col
@@ -135,7 +136,10 @@ final class PrefixSpan(@Since("2.4.0") override val uid: String) extends Params
* - `freq: Long`
*/
@Since("2.4.0")
- def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = {
+ def findFrequentSequentialPatterns(dataset: Dataset[_]): DataFrame = instrumented { instr =>
+ instr.logDataset(dataset)
+ instr.logParams(this, params: _*)
+
val sequenceColParam = $(sequenceCol)
val inputType = dataset.schema(sequenceColParam).dataType
require(inputType.isInstanceOf[ArrayType] &&
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala b/mllib/src/main/scala/org/apache/spark/ml/r/PowerIterationClusteringWrapper.scala
similarity index 56%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala
rename to mllib/src/main/scala/org/apache/spark/ml/r/PowerIterationClusteringWrapper.scala
index 7f7ef216cf485..b5dfad0224ed8 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/KerberosConfigSpec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/PowerIterationClusteringWrapper.scala
@@ -14,20 +14,26 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.deploy.k8s.features.hadooputils
-import io.fabric8.kubernetes.api.model.Secret
+package org.apache.spark.ml.r
-/**
- * Represents a given configuration of the Kerberos Configuration logic
- *
- * - The secret containing a DT, either previously specified or built on the fly
- * - The name of the secret where the DT will be stored
- * - The data item-key on the secret which correlates with where the current DT data is stored
- * - The Job User's username
- */
-private[spark] case class KerberosConfigSpec(
- dtSecret: Option[Secret],
- dtSecretName: String,
- dtSecretItemKey: String,
- jobUserName: String)
+import org.apache.spark.ml.clustering.PowerIterationClustering
+
+private[r] object PowerIterationClusteringWrapper {
+ def getPowerIterationClustering(
+ k: Int,
+ initMode: String,
+ maxIter: Int,
+ srcCol: String,
+ dstCol: String,
+ weightCol: String): PowerIterationClustering = {
+ val pic = new PowerIterationClustering()
+ .setK(k)
+ .setInitMode(initMode)
+ .setMaxIter(maxIter)
+ .setSrcCol(srcCol)
+ .setDstCol(dstCol)
+ if (weightCol != null) pic.setWeightCol(weightCol)
+ pic
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index abb60ea205751..885b13bf8dac3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -1001,7 +1001,8 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0") val coefficients: Vector,
@Since("2.0.0") val intercept: Double)
extends RegressionModel[Vector, GeneralizedLinearRegressionModel]
- with GeneralizedLinearRegressionBase with MLWritable {
+ with GeneralizedLinearRegressionBase with MLWritable
+ with HasTrainingSummary[GeneralizedLinearRegressionTrainingSummary] {
/**
* Sets the link prediction (linear predictor) column name.
@@ -1054,29 +1055,12 @@ class GeneralizedLinearRegressionModel private[ml] (
output.toDF()
}
- private var trainingSummary: Option[GeneralizedLinearRegressionTrainingSummary] = None
-
/**
* Gets R-like summary of model on training set. An exception is
* thrown if there is no summary available.
*/
@Since("2.0.0")
- def summary: GeneralizedLinearRegressionTrainingSummary = trainingSummary.getOrElse {
- throw new SparkException(
- "No training summary available for this GeneralizedLinearRegressionModel")
- }
-
- /**
- * Indicates if [[summary]] is available.
- */
- @Since("2.0.0")
- def hasSummary: Boolean = trainingSummary.nonEmpty
-
- private[regression]
- def setSummary(summary: Option[GeneralizedLinearRegressionTrainingSummary]): this.type = {
- this.trainingSummary = summary
- this
- }
+ override def summary: GeneralizedLinearRegressionTrainingSummary = super.summary
/**
* Evaluate the model on the given dataset, returning a summary of the results.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index ce6c12cc368dd..197828762d160 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -647,33 +647,20 @@ class LinearRegressionModel private[ml] (
@Since("1.3.0") val intercept: Double,
@Since("2.3.0") val scale: Double)
extends RegressionModel[Vector, LinearRegressionModel]
- with LinearRegressionParams with GeneralMLWritable {
+ with LinearRegressionParams with GeneralMLWritable
+ with HasTrainingSummary[LinearRegressionTrainingSummary] {
private[ml] def this(uid: String, coefficients: Vector, intercept: Double) =
this(uid, coefficients, intercept, 1.0)
- private var trainingSummary: Option[LinearRegressionTrainingSummary] = None
-
override val numFeatures: Int = coefficients.size
/**
* Gets summary (e.g. residuals, mse, r-squared ) of model on training set. An exception is
- * thrown if `trainingSummary == None`.
+ * thrown if `hasSummary` is false.
*/
@Since("1.5.0")
- def summary: LinearRegressionTrainingSummary = trainingSummary.getOrElse {
- throw new SparkException("No training summary available for this LinearRegressionModel")
- }
-
- private[regression]
- def setSummary(summary: Option[LinearRegressionTrainingSummary]): this.type = {
- this.trainingSummary = summary
- this
- }
-
- /** Indicates whether a training summary exists for this model instance. */
- @Since("1.5.0")
- def hasSummary: Boolean = trainingSummary.isDefined
+ override def summary: LinearRegressionTrainingSummary = super.summary
/**
* Evaluates the model on a test dataset.
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
new file mode 100644
index 0000000000000..edb0208144e10
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.Since
+
+
+/**
+ * Trait for models that provide a training summary.
+ *
+ * @tparam T Summary instance type
+ */
+@Since("3.0.0")
+private[ml] trait HasTrainingSummary[T] {
+
+ private[ml] final var trainingSummary: Option[T] = None
+
+ /** Indicates whether a training summary exists for this model instance. */
+ @Since("3.0.0")
+ def hasSummary: Boolean = trainingSummary.isDefined
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `hasSummary` is false.
+ */
+ @Since("3.0.0")
+ def summary: T = trainingSummary.getOrElse {
+ throw new SparkException(
+ s"No training summary available for this ${this.getClass.getSimpleName}")
+ }
+
+ private[ml] def setSummary(summary: Option[T]): this.type = {
+ this.trainingSummary = summary
+ this
+ }
+}
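The trait centralizes the trainingSummary field, hasSummary, summary and setSummary that the individual models above previously duplicated. A hedged sketch of how a model mixes it in; because the trait is private[ml], this only compiles inside the org.apache.spark.ml package tree, and MySummary/MyModel are hypothetical:

    package org.apache.spark.ml.example

    import org.apache.spark.ml.util.HasTrainingSummary

    class MySummary(val numIterations: Int)              // hypothetical summary type

    class MyModel extends HasTrainingSummary[MySummary]  // hypothetical model

    object HasTrainingSummaryDemo {
      def main(args: Array[String]): Unit = {
        val model = new MyModel
        model.setSummary(Some(new MySummary(10)))        // normally done by the estimator's fit()
        if (model.hasSummary) {
          println(model.summary.numIterations)           // summary throws SparkException when unset
        }
      }
    }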
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
index 020676cac5a64..525047973ad5c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -27,17 +27,18 @@ import org.apache.spark.sql.DataFrame
/**
* Evaluator for regression.
*
- * @param predictionAndObservations an RDD of (prediction, observation) pairs
+ * @param predAndObsWithOptWeight an RDD of either (prediction, observation, weight)
+ * triples or (prediction, observation) pairs
* @param throughOrigin True if the regression is through the origin. For example, in linear
* regression, it will be true without fitting intercept.
*/
@Since("1.2.0")
class RegressionMetrics @Since("2.0.0") (
- predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean)
+ predAndObsWithOptWeight: RDD[_ <: Product], throughOrigin: Boolean)
extends Logging {
@Since("1.2.0")
- def this(predictionAndObservations: RDD[(Double, Double)]) =
+ def this(predictionAndObservations: RDD[_ <: Product]) =
this(predictionAndObservations, false)
/**
@@ -52,10 +53,13 @@ class RegressionMetrics @Since("2.0.0") (
* Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
*/
private lazy val summary: MultivariateStatisticalSummary = {
- val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
- case (prediction, observation) => Vectors.dense(observation, observation - prediction)
+ val summary: MultivariateStatisticalSummary = predAndObsWithOptWeight.map {
+ case (prediction: Double, observation: Double, weight: Double) =>
+ (Vectors.dense(observation, observation - prediction), weight)
+ case (prediction: Double, observation: Double) =>
+ (Vectors.dense(observation, observation - prediction), 1.0)
}.treeAggregate(new MultivariateOnlineSummarizer())(
- (summary, v) => summary.add(v),
+ (summary, sample) => summary.add(sample._1, sample._2),
(sum1, sum2) => sum1.merge(sum2)
)
summary
@@ -63,11 +67,13 @@ class RegressionMetrics @Since("2.0.0") (
private lazy val SSy = math.pow(summary.normL2(0), 2)
private lazy val SSerr = math.pow(summary.normL2(1), 2)
- private lazy val SStot = summary.variance(0) * (summary.count - 1)
+ private lazy val SStot = summary.variance(0) * (summary.weightSum - 1)
private lazy val SSreg = {
val yMean = summary.mean(0)
- predictionAndObservations.map {
- case (prediction, _) => math.pow(prediction - yMean, 2)
+ predAndObsWithOptWeight.map {
+ case (prediction: Double, _: Double, weight: Double) =>
+ math.pow(prediction - yMean, 2) * weight
+ case (prediction: Double, _: Double) => math.pow(prediction - yMean, 2)
}.sum()
}
@@ -79,7 +85,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def explainedVariance: Double = {
- SSreg / summary.count
+ SSreg / summary.weightSum
}
/**
@@ -88,7 +94,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def meanAbsoluteError: Double = {
- summary.normL1(1) / summary.count
+ summary.normL1(1) / summary.weightSum
}
/**
@@ -97,7 +103,7 @@ class RegressionMetrics @Since("2.0.0") (
*/
@Since("1.2.0")
def meanSquaredError: Double = {
- SSerr / summary.count
+ SSerr / summary.weightSum
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
index e58860fea97d0..e32d615af2a47 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala
@@ -322,7 +322,7 @@ class BlockMatrix @Since("1.3.0") (
val m = numRows().toInt
val n = numCols().toInt
val mem = m * n / 125000
- if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!")
+ if (mem > 500) logWarning(s"Storing this matrix will require $mem MiB of memory!")
val localBlocks = blocks.collect()
val values = new Array[Double](m * n)
localBlocks.foreach { case ((blockRowIndex, blockColIndex), submat) =>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 0554b6d8ff5b5..6d510e1633d67 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -52,7 +52,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
private var totalCnt: Long = 0
private var totalWeightSum: Double = 0.0
private var weightSquareSum: Double = 0.0
- private var weightSum: Array[Double] = _
+ private var currWeightSum: Array[Double] = _
private var nnz: Array[Long] = _
private var currMax: Array[Double] = _
private var currMin: Array[Double] = _
@@ -78,7 +78,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currM2n = Array.ofDim[Double](n)
currM2 = Array.ofDim[Double](n)
currL1 = Array.ofDim[Double](n)
- weightSum = Array.ofDim[Double](n)
+ currWeightSum = Array.ofDim[Double](n)
nnz = Array.ofDim[Long](n)
currMax = Array.fill[Double](n)(Double.MinValue)
currMin = Array.fill[Double](n)(Double.MaxValue)
@@ -91,7 +91,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val localCurrM2n = currM2n
val localCurrM2 = currM2
val localCurrL1 = currL1
- val localWeightSum = weightSum
+ val localWeightSum = currWeightSum
val localNumNonzeros = nnz
val localCurrMax = currMax
val localCurrMin = currMin
@@ -139,8 +139,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
weightSquareSum += other.weightSquareSum
var i = 0
while (i < n) {
- val thisNnz = weightSum(i)
- val otherNnz = other.weightSum(i)
+ val thisNnz = currWeightSum(i)
+ val otherNnz = other.currWeightSum(i)
val totalNnz = thisNnz + otherNnz
val totalCnnz = nnz(i) + other.nnz(i)
if (totalNnz != 0.0) {
@@ -157,7 +157,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currMax(i) = math.max(currMax(i), other.currMax(i))
currMin(i) = math.min(currMin(i), other.currMin(i))
}
- weightSum(i) = totalNnz
+ currWeightSum(i) = totalNnz
nnz(i) = totalCnnz
i += 1
}
@@ -170,7 +170,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
this.totalCnt = other.totalCnt
this.totalWeightSum = other.totalWeightSum
this.weightSquareSum = other.weightSquareSum
- this.weightSum = other.weightSum.clone()
+ this.currWeightSum = other.currWeightSum.clone()
this.nnz = other.nnz.clone()
this.currMax = other.currMax.clone()
this.currMin = other.currMin.clone()
@@ -189,7 +189,7 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val realMean = Array.ofDim[Double](n)
var i = 0
while (i < n) {
- realMean(i) = currMean(i) * (weightSum(i) / totalWeightSum)
+ realMean(i) = currMean(i) * (currWeightSum(i) / totalWeightSum)
i += 1
}
Vectors.dense(realMean)
@@ -214,8 +214,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
val len = currM2n.length
while (i < len) {
// We prevent variance from negative value caused by numerical error.
- realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * weightSum(i) *
- (totalWeightSum - weightSum(i)) / totalWeightSum) / denominator, 0.0)
+ realVariance(i) = math.max((currM2n(i) + deltaMean(i) * deltaMean(i) * currWeightSum(i) *
+ (totalWeightSum - currWeightSum(i)) / totalWeightSum) / denominator, 0.0)
i += 1
}
}
@@ -229,6 +229,11 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
@Since("1.1.0")
override def count: Long = totalCnt
+ /**
+ * Sum of weights.
+ */
+ override def weightSum: Double = totalWeightSum
+
/**
* Number of nonzero elements in each dimension.
*
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
index 39a16fb743d64..a4381032f8c0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateStatisticalSummary.scala
@@ -44,6 +44,12 @@ trait MultivariateStatisticalSummary {
@Since("1.0.0")
def count: Long
+ /**
+ * Sum of weights.
+ */
+ @Since("3.0.0")
+ def weightSum: Double
+
/**
* Number of nonzero elements (including explicitly presented zero values) in each column.
*/
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
index f1d517383643d..23809777f7d3a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -133,4 +133,54 @@ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
"root mean squared error mismatch")
assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch")
}
+
+ test("regression metrics with same (1.0) weight samples") {
+ val predictionAndObservationWithWeight = sc.parallelize(
+ Seq((2.25, 3.0, 1.0), (-0.25, -0.5, 1.0), (1.75, 2.0, 1.0), (7.75, 7.0, 1.0)), 2)
+ val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
+ assert(metrics.explainedVariance ~== 8.79687 absTol eps,
+ "explained variance regression score mismatch")
+ assert(metrics.meanAbsoluteError ~== 0.5 absTol eps, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.3125 absTol eps, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.55901 absTol eps,
+ "root mean squared error mismatch")
+ assert(metrics.r2 ~== 0.95717 absTol eps, "r2 score mismatch")
+ }
+
+ /**
+ * The following values are hand calculated using the formula:
+ * [[https://en.wikipedia.org/wiki/Weighted_arithmetic_mean#Reliability_weights]]
+ * preds = c(2.25, -0.25, 1.75, 7.75)
+ * obs = c(3.0, -0.5, 2.0, 7.0)
+ * weights = c(0.1, 0.2, 0.15, 0.05)
+ * count = 4
+ *
+ * Weighted metrics can be calculated with MultivariateStatisticalSummary.
+ * (observations, observations - predictions)
+ * mean (1.7, 0.05)
+ * variance (7.3, 0.3)
+ * numNonZeros (0.5, 0.5)
+ * max (7.0, 0.75)
+ * min (-0.5, -0.75)
+ * normL2 (2.0, 0.32596)
+ * normL1 (1.05, 0.2)
+ *
+ * explainedVariance: sum(pow((preds - 1.7),2)*weight) / weightedCount = 5.2425
+ * meanAbsoluteError: normL1(1) / weightedCount = 0.4
+ * meanSquaredError: pow(normL2(1),2) / weightedCount = 0.2125
+ * rootMeanSquaredError: sqrt(meanSquaredError) = 0.46098
+ * r2: 1 - pow(normL2(1),2) / (variance(0) * (weightedCount - 1)) = 1.02910
+ */
+ test("regression metrics with weighted samples") {
+ val predictionAndObservationWithWeight = sc.parallelize(
+ Seq((2.25, 3.0, 0.1), (-0.25, -0.5, 0.2), (1.75, 2.0, 0.15), (7.75, 7.0, 0.05)), 2)
+ val metrics = new RegressionMetrics(predictionAndObservationWithWeight, false)
+ assert(metrics.explainedVariance ~== 5.2425 absTol eps,
+ "explained variance regression score mismatch")
+ assert(metrics.meanAbsoluteError ~== 0.4 absTol eps, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.2125 absTol eps, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.46098 absTol eps,
+ "root mean squared error mismatch")
+ assert(metrics.r2 ~== 1.02910 absTol eps, "r2 score mismatch")
+ }
}
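As a quick arithmetic cross-check of the hand-calculated values in the comment above: the weighted count is 0.1 + 0.2 + 0.15 + 0.05 = 0.5, so meanAbsoluteError = normL1(1) / 0.5 = 0.2 / 0.5 = 0.4, meanSquaredError = normL2(1)^2 / 0.5 = 0.32596^2 / 0.5 ≈ 0.2125, and rootMeanSquaredError = sqrt(0.2125) ≈ 0.46098, matching the asserted tolerances.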
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 37eb794b0c5c9..6250b0363ee3b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -190,7 +190,7 @@ class GradientDescentClusterSuite extends SparkFunSuite with LocalClusterSparkCo
iter.map(i => (1.0, Vectors.dense(Array.fill(n)(random.nextDouble()))))
}.cache()
// If we serialize data directly in the task closure, the size of the serialized task would be
- // greater than 1MB and hence Spark would throw an error.
+ // greater than 1MiB and hence Spark would throw an error.
val (weights, loss) = GradientDescent.runMiniBatchSGD(
points,
new LogisticGradient,
diff --git a/pom.xml b/pom.xml
index 61321a1450708..310d7de955125 100644
--- a/pom.xml
+++ b/pom.xml
@@ -156,7 +156,7 @@
3.4.1
3.2.2
- 2.12.7
+ 2.12.8
2.12
--diff --test
@@ -173,7 +173,7 @@
3.8.1
3.2.10
- 3.0.10
+ 3.0.11
2.22.2
2.9.3
3.5.2
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 1c83cf5860c58..7bb70a29195d6 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -36,6 +36,9 @@ object MimaExcludes {
// Exclude rules for 3.0.x
lazy val v30excludes = v24excludes ++ Seq(
+ // [SPARK-24243][CORE] Expose exceptions from InProcessAppHandle
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.launcher.SparkAppHandle.getError"),
+
// [SPARK-25867] Remove KMeans computeCost
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.KMeansModel.computeCost"),
@@ -214,6 +217,13 @@ object MimaExcludes {
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this"),
+ // [SPARK-26139] Implement shuffle write metrics in SQL
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ShuffleDependency.this"),
+
+ // [SPARK-26362][CORE] Remove 'spark.driver.allowMultipleContexts' to disallow multiple creation of SparkContexts
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.setActiveContext"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.SparkContext.markPartiallyConstructed"),
+
// Data Source V2 API changes
(problem: Problem) => problem match {
case MissingClassProblem(cls) =>
@@ -525,7 +535,10 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseColMajor"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toDenseMatrix"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.toSparseMatrix"),
- ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes")
+ ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.ml.linalg.Matrix.getSizeInBytes"),
+
+ // [SPARK-18693] Added weightSum to trait MultivariateStatisticalSummary
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.stat.MultivariateStatisticalSummary.weightSum")
) ++ Seq(
// [SPARK-17019] Expose on-heap and off-heap memory usage in various places
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.scheduler.SparkListenerBlockManagerAdded.copy"),
diff --git a/project/build.properties b/project/build.properties
index d03985d980ec8..23aa187fb35a7 100644
--- a/project/build.properties
+++ b/project/build.properties
@@ -14,4 +14,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-sbt.version=0.13.17
+sbt.version=0.13.18
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 1180bf91baa5a..6137ed25a0dd9 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -63,6 +63,9 @@ class SparkContext(object):
Main entry point for Spark functionality. A SparkContext represents the
connection to a Spark cluster, and can be used to create L{RDD} and
broadcast variables on that cluster.
+
+ .. note:: Only one :class:`SparkContext` should be active per JVM. You must `stop()`
+ the active :class:`SparkContext` before creating a new one.
"""
_gateway = None
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index ce028512357f2..6ddfce95a3d4d 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -23,7 +23,7 @@
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import *
from pyspark.ml.regression import DecisionTreeModel, DecisionTreeRegressionModel, \
- RandomForestParams, TreeEnsembleModel, TreeEnsembleParams
+ GBTParams, HasVarianceImpurity, RandomForestParams, TreeEnsembleModel, TreeEnsembleParams
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaParams
from pyspark.ml.wrapper import JavaWrapper
@@ -895,15 +895,6 @@ def getImpurity(self):
return self.getOrDefault(self.impurity)
-class GBTParams(TreeEnsembleParams):
- """
- Private class to track supported GBT params.
-
- .. versionadded:: 1.4.0
- """
- supportedLossTypes = ["logistic"]
-
-
@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
@@ -1174,9 +1165,31 @@ def trees(self):
return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))]
+class GBTClassifierParams(GBTParams, HasVarianceImpurity):
+ """
+ Private class to track supported GBTClassifier params.
+
+ .. versionadded:: 3.0.0
+ """
+
+ supportedLossTypes = ["logistic"]
+
+ lossType = Param(Params._dummy(), "lossType",
+ "Loss function which GBT tries to minimize (case-insensitive). " +
+ "Supported options: " + ", ".join(supportedLossTypes),
+ typeConverter=TypeConverters.toString)
+
+ @since("1.4.0")
+ def getLossType(self):
+ """
+ Gets the value of lossType or its default value.
+ """
+ return self.getOrDefault(self.lossType)
+
+
@inherit_doc
-class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
- GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
+class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
+ GBTClassifierParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) `_
@@ -1242,32 +1255,28 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
[0.25..., 0.23..., 0.21..., 0.19..., 0.18...]
>>> model.numClasses
2
+ >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
+ >>> gbt.getValidationIndicatorCol()
+ 'validationIndicator'
+ >>> gbt.getValidationTol()
+ 0.01
.. versionadded:: 1.4.0
"""
- lossType = Param(Params._dummy(), "lossType",
- "Loss function which GBT tries to minimize (case-insensitive). " +
- "Supported options: " + ", ".join(GBTParams.supportedLossTypes),
- typeConverter=TypeConverters.toString)
-
- stepSize = Param(Params._dummy(), "stepSize",
- "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
- "the contribution of each estimator.",
- typeConverter=TypeConverters.toFloat)
-
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
- maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
- featureSubsetStrategy="all"):
+ maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, impurity="variance",
+ featureSubsetStrategy="all", validationTol=0.01, validationIndicatorCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
- featureSubsetStrategy="all")
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
+ validationIndicatorCol=None)
"""
super(GBTClassifier, self).__init__()
self._java_obj = self._new_java_obj(
@@ -1275,7 +1284,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0,
- featureSubsetStrategy="all")
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1285,13 +1294,15 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0,
- featureSubsetStrategy="all"):
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
+ validationIndicatorCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0, \
- featureSubsetStrategy="all")
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
+ validationIndicatorCol=None)
Sets params for Gradient Boosted Tree Classification.
"""
kwargs = self._input_kwargs
@@ -1307,13 +1318,6 @@ def setLossType(self, value):
"""
return self._set(lossType=value)
- @since("1.4.0")
- def getLossType(self):
- """
- Gets the value of lossType or its default value.
- """
- return self.getOrDefault(self.lossType)
-
@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
"""
@@ -1321,6 +1325,13 @@ def setFeatureSubsetStrategy(self, value):
"""
return self._set(featureSubsetStrategy=value)
+ @since("3.0.0")
+ def setValidationIndicatorCol(self, value):
+ """
+ Sets the value of :py:attr:`validationIndicatorCol`.
+ """
+ return self._set(validationIndicatorCol=value)
+
class GBTClassificationModel(TreeEnsembleModel, JavaClassificationModel, JavaMLWritable,
JavaMLReadable):
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index d0b507ec5dad4..d8a6dfb7d3a71 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -1193,8 +1193,8 @@ class PowerIterationClustering(HasMaxIter, HasWeightCol, JavaParams, JavaMLReada
.. note:: Experimental
Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by
- `Lin and Cohen `_. From the abstract:
- PIC finds a very low-dimensional embedding of a dataset using truncated power
+ `Lin and Cohen `_. From the
+ abstract: PIC finds a very low-dimensional embedding of a dataset using truncated power
iteration on a normalized pair-wise similarity matrix of the data.
This class is not yet an Estimator/Transformer, use :py:func:`assignClusters` method
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 6cc80e181e5e0..08ae58246adb6 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -192,6 +192,7 @@ def approxSimilarityJoin(self, datasetA, datasetB, threshold, distCol="distCol")
"datasetA" and "datasetB", and a column "distCol" is added to show the distance
between each pair.
"""
+ threshold = TypeConverters.toFloat(threshold)
return self._call_java("approxSimilarityJoin", datasetA, datasetB, threshold, distCol)
@@ -239,6 +240,16 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp
| 3| 6| 2.23606797749979|
+---+---+-----------------+
...
+ >>> model.approxSimilarityJoin(df, df2, 3, distCol="EuclideanDistance").select(
+ ... col("datasetA.id").alias("idA"),
+ ... col("datasetB.id").alias("idB"),
+ ... col("EuclideanDistance")).show()
+ +---+---+-----------------+
+ |idA|idB|EuclideanDistance|
+ +---+---+-----------------+
+ | 3| 6| 2.23606797749979|
+ +---+---+-----------------+
+ ...
>>> brpPath = temp_path + "/brp"
>>> brp.save(brpPath)
>>> brp2 = BucketedRandomProjectionLSH.load(brpPath)
@@ -1648,22 +1659,22 @@ class OneHotEncoder(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid
at most a single one-value per row that indicates the input category index.
For example with 5 categories, an input value of 2.0 would map to an output vector of
`[0.0, 0.0, 1.0, 0.0]`.
- The last category is not included by default (configurable via `dropLast`),
+ The last category is not included by default (configurable via :py:attr:`dropLast`),
because it makes the vector entries sum up to one, and hence linearly dependent.
So an input value of 4.0 maps to `[0.0, 0.0, 0.0, 0.0]`.
- Note: This is different from scikit-learn's OneHotEncoder, which keeps all categories.
- The output vectors are sparse.
+ .. note:: This is different from scikit-learn's OneHotEncoder, which keeps all categories.
+ The output vectors are sparse.
- When `handleInvalid` is configured to 'keep', an extra "category" indicating invalid values is
- added as last category. So when `dropLast` is true, invalid values are encoded as all-zeros
- vector.
+ When :py:attr:`handleInvalid` is configured to 'keep', an extra "category" indicating invalid
+ values is added as last category. So when :py:attr:`dropLast` is true, invalid values are
+ encoded as all-zeros vector.
- Note: When encoding multi-column by using `inputCols` and `outputCols` params, input/output
- cols come in pairs, specified by the order in the arrays, and each pair is treated
- independently.
+ .. note:: When encoding multi-column by using :py:attr:`inputCols` and
+ :py:attr:`outputCols` params, input/output cols come in pairs, specified by the order in
+ the arrays, and each pair is treated independently.
- See `StringIndexer` for converting categorical values into category indices
+ .. seealso:: :py:class:`StringIndexer` for converting categorical values into category indices
>>> from pyspark.ml.linalg import Vectors
>>> df = spark.createDataFrame([(0.0,), (1.0,), (2.0,)], ["input"])
@@ -1671,7 +1682,7 @@ class OneHotEncoder(JavaEstimator, HasInputCols, HasOutputCols, HasHandleInvalid
>>> model = ohe.fit(df)
>>> model.transform(df).head().output
SparseVector(2, {0: 1.0})
- >>> ohePath = temp_path + "/oheEstimator"
+ >>> ohePath = temp_path + "/ohe"
>>> ohe.save(ohePath)
>>> loadedOHE = OneHotEncoder.load(ohePath)
>>> loadedOHE.getInputCols() == ohe.getInputCols()
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index e45ba840b412b..1b0c8c5d28b78 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -164,7 +164,10 @@ def get$Name(self):
"False", "TypeConverters.toBoolean"),
("loss", "the loss function to be optimized.", None, "TypeConverters.toString"),
("distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.",
- "'euclidean'", "TypeConverters.toString")]
+ "'euclidean'", "TypeConverters.toString"),
+ ("validationIndicatorCol", "name of the column that indicates whether each row is for " +
+ "training or for validation. False indicates training; true indicates validation.",
+ None, "TypeConverters.toString")]
code = []
for name, doc, defaultValueStr, typeConverter in shared:
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 618f5bf0a8103..6405b9fce7efb 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -702,6 +702,53 @@ def getLoss(self):
return self.getOrDefault(self.loss)
+class HasDistanceMeasure(Params):
+ """
+ Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'.
+ """
+
+ distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString)
+
+ def __init__(self):
+ super(HasDistanceMeasure, self).__init__()
+ self._setDefault(distanceMeasure='euclidean')
+
+ def setDistanceMeasure(self, value):
+ """
+ Sets the value of :py:attr:`distanceMeasure`.
+ """
+ return self._set(distanceMeasure=value)
+
+ def getDistanceMeasure(self):
+ """
+ Gets the value of distanceMeasure or its default value.
+ """
+ return self.getOrDefault(self.distanceMeasure)
+
+
+class HasValidationIndicatorCol(Params):
+ """
+ Mixin for param validationIndicatorCol: name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.
+ """
+
+ validationIndicatorCol = Param(Params._dummy(), "validationIndicatorCol", "name of the column that indicates whether each row is for training or for validation. False indicates training; true indicates validation.", typeConverter=TypeConverters.toString)
+
+ def __init__(self):
+ super(HasValidationIndicatorCol, self).__init__()
+
+ def setValidationIndicatorCol(self, value):
+ """
+ Sets the value of :py:attr:`validationIndicatorCol`.
+ """
+ return self._set(validationIndicatorCol=value)
+
+ def getValidationIndicatorCol(self):
+ """
+ Gets the value of validationIndicatorCol or its default value.
+ """
+ return self.getOrDefault(self.validationIndicatorCol)
+
+
class DecisionTreeParams(Params):
"""
Mixin for Decision Tree parameters.
@@ -790,27 +837,3 @@ def getCacheNodeIds(self):
"""
return self.getOrDefault(self.cacheNodeIds)
-
-class HasDistanceMeasure(Params):
- """
- Mixin for param distanceMeasure: the distance measure. Supported options: 'euclidean' and 'cosine'.
- """
-
- distanceMeasure = Param(Params._dummy(), "distanceMeasure", "the distance measure. Supported options: 'euclidean' and 'cosine'.", typeConverter=TypeConverters.toString)
-
- def __init__(self):
- super(HasDistanceMeasure, self).__init__()
- self._setDefault(distanceMeasure='euclidean')
-
- def setDistanceMeasure(self, value):
- """
- Sets the value of :py:attr:`distanceMeasure`.
- """
- return self._set(distanceMeasure=value)
-
- def getDistanceMeasure(self):
- """
- Gets the value of distanceMeasure or its default value.
- """
- return self.getOrDefault(self.distanceMeasure)
-
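Note on the shared-param move above: HasValidationIndicatorCol (and the relocated HasDistanceMeasure) follow the standard generated-mixin pattern, so any estimator that lists the mixin among its bases gets the Param plus getter/setter for free. A minimal, self-contained sketch of that behaviour, assuming a PySpark build that already contains this change (the DemoParams class is purely illustrative, not part of Spark):

    from pyspark.ml.param.shared import HasValidationIndicatorCol

    class DemoParams(HasValidationIndicatorCol):
        """Hypothetical params holder used only to exercise the mixin."""
        pass

    p = DemoParams()
    p.setValidationIndicatorCol("isVal")           # stored via Params._set
    print(p.getValidationIndicatorCol())           # -> 'isVal'
    print(p.isDefined(p.validationIndicatorCol))   # -> True (the param has no default)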
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 98f4361351847..78cb4a6703554 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -650,19 +650,20 @@ def getFeatureSubsetStrategy(self):
return self.getOrDefault(self.featureSubsetStrategy)
-class TreeRegressorParams(Params):
+class HasVarianceImpurity(Params):
"""
Private class to track supported impurity measures.
"""
supportedImpurities = ["variance"]
+
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
"Supported options: " +
", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
def __init__(self):
- super(TreeRegressorParams, self).__init__()
+ super(HasVarianceImpurity, self).__init__()
@since("1.4.0")
def setImpurity(self, value):
@@ -679,6 +680,10 @@ def getImpurity(self):
return self.getOrDefault(self.impurity)
+class TreeRegressorParams(HasVarianceImpurity):
+ pass
+
+
class RandomForestParams(TreeEnsembleParams):
"""
Private class to track supported random forest parameters.
@@ -705,12 +710,52 @@ def getNumTrees(self):
return self.getOrDefault(self.numTrees)
-class GBTParams(TreeEnsembleParams):
+class GBTParams(TreeEnsembleParams, HasMaxIter, HasStepSize, HasValidationIndicatorCol):
"""
Private class to track supported GBT params.
"""
+
+ stepSize = Param(Params._dummy(), "stepSize",
+ "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
+ "the contribution of each estimator.",
+ typeConverter=TypeConverters.toFloat)
+
+ validationTol = Param(Params._dummy(), "validationTol",
+ "Threshold for stopping early when fit with validation is used. " +
+ "If the error rate on the validation input changes by less than the " +
+ "validationTol, then learning will stop early (before `maxIter`). " +
+ "This parameter is ignored when fit without validation is used.",
+ typeConverter=TypeConverters.toFloat)
+
+ @since("3.0.0")
+ def getValidationTol(self):
+ """
+ Gets the value of validationTol or its default value.
+ """
+ return self.getOrDefault(self.validationTol)
+
+
+class GBTRegressorParams(GBTParams, TreeRegressorParams):
+ """
+ Private class to track supported GBTRegressor params.
+
+ .. versionadded:: 3.0.0
+ """
+
supportedLossTypes = ["squared", "absolute"]
+ lossType = Param(Params._dummy(), "lossType",
+ "Loss function which GBT tries to minimize (case-insensitive). " +
+ "Supported options: " + ", ".join(supportedLossTypes),
+ typeConverter=TypeConverters.toString)
+
+ @since("1.4.0")
+ def getLossType(self):
+ """
+ Gets the value of lossType or its default value.
+ """
+ return self.getOrDefault(self.lossType)
+
@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
@@ -1030,9 +1075,9 @@ def featureImportances(self):
@inherit_doc
-class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
- GBTParams, HasCheckpointInterval, HasStepSize, HasSeed, JavaMLWritable,
- JavaMLReadable, TreeRegressorParams):
+class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
+ GBTRegressorParams, HasCheckpointInterval, HasSeed, JavaMLWritable,
+ JavaMLReadable):
"""
`Gradient-Boosted Trees (GBTs) <http://en.wikipedia.org/wiki/Gradient_boosting>`_
learning algorithm for regression.
@@ -1079,39 +1124,36 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
... ["label", "features"])
>>> model.evaluateEachIteration(validation, "squared")
[0.0, 0.0, 0.0, 0.0, 0.0]
+ >>> gbt = gbt.setValidationIndicatorCol("validationIndicator")
+ >>> gbt.getValidationIndicatorCol()
+ 'validationIndicator'
+ >>> gbt.getValidationTol()
+ 0.01
.. versionadded:: 1.4.0
"""
- lossType = Param(Params._dummy(), "lossType",
- "Loss function which GBT tries to minimize (case-insensitive). " +
- "Supported options: " + ", ".join(GBTParams.supportedLossTypes),
- typeConverter=TypeConverters.toString)
-
- stepSize = Param(Params._dummy(), "stepSize",
- "Step size (a.k.a. learning rate) in interval (0, 1] for shrinking " +
- "the contribution of each estimator.",
- typeConverter=TypeConverters.toFloat)
-
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
- impurity="variance", featureSubsetStrategy="all"):
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
+ validationIndicatorCol=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
- impurity="variance", featureSubsetStrategy="all")
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
+ validationIndicatorCol=None)
"""
super(GBTRegressor, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.regression.GBTRegressor", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1,
- impurity="variance", featureSubsetStrategy="all")
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01)
kwargs = self._input_kwargs
self.setParams(**kwargs)
@@ -1121,13 +1163,15 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0,
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None,
- impuriy="variance", featureSubsetStrategy="all"):
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01,
+ validationIndicatorCol=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, \
checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, seed=None, \
- impurity="variance", featureSubsetStrategy="all")
+ impurity="variance", featureSubsetStrategy="all", validationTol=0.01, \
+ validationIndicatorCol=None)
Sets params for Gradient Boosted Tree Regression.
"""
kwargs = self._input_kwargs
@@ -1143,13 +1187,6 @@ def setLossType(self, value):
"""
return self._set(lossType=value)
- @since("1.4.0")
- def getLossType(self):
- """
- Gets the value of lossType or its default value.
- """
- return self.getOrDefault(self.lossType)
-
@since("2.4.0")
def setFeatureSubsetStrategy(self, value):
"""
@@ -1157,6 +1194,13 @@ def setFeatureSubsetStrategy(self, value):
"""
return self._set(featureSubsetStrategy=value)
+ @since("3.0.0")
+ def setValidationIndicatorCol(self, value):
+ """
+ Sets the value of :py:attr:`validationIndicatorCol`.
+ """
+ return self._set(validationIndicatorCol=value)
+
class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable):
"""
diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py
index 4bc8904acd31c..bf2ad2d267bb2 100644
--- a/python/pyspark/mllib/tests/test_streaming_algorithms.py
+++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py
@@ -364,7 +364,7 @@ def condition():
return True
return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
- self._eventually(condition)
+ self._eventually(condition, timeout=60.0)
class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 8bd6897df925f..b6e17cab44e9c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -127,7 +127,7 @@ def __new__(cls, mean, confidence, low, high):
def _parse_memory(s):
"""
Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
- return the value in MB
+ return the value in MiB
>>> _parse_memory("256m")
256
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index ff9a612b77f61..fd4695210fb7c 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -185,6 +185,39 @@ def loads(self, obj):
raise NotImplementedError
+class ArrowCollectSerializer(Serializer):
+ """
+ Deserialize a stream of batches followed by batch order information. Used in
+ DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
+ """
+
+ def __init__(self):
+ self.serializer = ArrowStreamSerializer()
+
+ def dump_stream(self, iterator, stream):
+ return self.serializer.dump_stream(iterator, stream)
+
+ def load_stream(self, stream):
+ """
+ Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
+ a list of indices that can be used to put the RecordBatches in the correct order.
+ """
+ # load the batches
+ for batch in self.serializer.load_stream(stream):
+ yield batch
+
+ # load the batch order indices
+ num = read_int(stream)
+ batch_order = []
+ for i in xrange(num):
+ index = read_int(stream)
+ batch_order.append(index)
+ yield batch_order
+
+ def __repr__(self):
+ return "ArrowCollectSerializer(%s)" % self.serializer
+
+
class ArrowStreamSerializer(Serializer):
"""
Serializes Arrow record batches as a stream.
@@ -248,7 +281,10 @@ def create_array(s, t):
# TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
return pa.Array.from_pandas(s.apply(
lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
- return pa.Array.from_pandas(s, mask=mask, type=t)
+ elif LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
+ # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
+ return pa.Array.from_pandas(s, mask=mask, type=t)
+ return pa.Array.from_pandas(s, mask=mask, type=t, safe=False)
arrs = [create_array(s, t) for s, t in series]
return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
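As far as this hunk shows, the extra data that ArrowCollectSerializer.load_stream consumes after the Arrow batches is framed very simply: one 4-byte big-endian count followed by that many 4-byte big-endian indices, matching what read_int in pyspark.serializers expects. A small stand-alone sketch of that framing (write_batch_order/read_batch_order are hypothetical helpers, not Spark APIs):

    import struct
    from io import BytesIO

    def write_batch_order(stream, indices):
        # count, then one index per batch, all as big-endian int32 ('!i')
        stream.write(struct.pack("!i", len(indices)))
        for i in indices:
            stream.write(struct.pack("!i", i))

    def read_batch_order(stream):
        num = struct.unpack("!i", stream.read(4))[0]
        return [struct.unpack("!i", stream.read(4))[0] for _ in range(num)]

    buf = BytesIO()
    write_batch_order(buf, [2, 0, 1])
    buf.seek(0)
    print(read_batch_order(buf))  # -> [2, 0, 1]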
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index bd0ac0039ffe1..5d2d63850e9b2 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -37,7 +37,7 @@
process = None
def get_used_memory():
- """ Return the used memory in MB """
+ """ Return the used memory in MiB """
global process
if process is None or process._pid != os.getpid():
process = psutil.Process(os.getpid())
@@ -50,7 +50,7 @@ def get_used_memory():
except ImportError:
def get_used_memory():
- """ Return the used memory in MB """
+ """ Return the used memory in MiB """
if platform.system() == 'Linux':
for line in open('/proc/self/status'):
if line.startswith('VmRSS:'):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1b1092c409be0..a1056d0b787e3 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -29,7 +29,7 @@
from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -2168,7 +2168,14 @@ def _collectAsArrow(self):
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.collectAsArrowToPython()
- return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+
+ # Collect list of un-ordered batches; the last element is a list of indices giving the correct order
+ results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
+ batches = results[:-1]
+ batch_order = results[-1]
+
+ # Re-order the batch list using the correct order
+ return [batches[i] for i in batch_order]
##########################################################################################
# Pandas compatibility
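The reordering in _collectAsArrow is easiest to see with plain Python values in place of RecordBatches: the last element yielded by the serializer is a list of indices into the arrival-ordered batch list, in the order the partitions should appear. A toy illustration (strings stand in for Arrow batches):

    # Batches as they happened to arrive from the JVM (out of partition order).
    batches = ["batch for partition 2", "batch for partition 0", "batch for partition 1"]
    # Index i of the final result should come from batches[batch_order[i]].
    batch_order = [1, 2, 0]

    ordered = [batches[i] for i in batch_order]
    print(ordered)
    # ['batch for partition 0', 'batch for partition 1', 'batch for partition 2']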
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f98e550e39da8..d188de39e21c7 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2982,8 +2982,7 @@ def pandas_udf(f=None, returnType=None, functionType=None):
| 2| 6.0|
+---+-----------+
- This example shows using grouped aggregated UDFs as window functions. Note that only
- unbounded window frame is supported at the moment:
+ This example shows using grouped aggregated UDFs as window functions.
>>> from pyspark.sql.functions import pandas_udf, PandasUDFType
>>> from pyspark.sql import Window
@@ -2993,20 +2992,24 @@ def pandas_udf(f=None, returnType=None, functionType=None):
>>> @pandas_udf("double", PandasUDFType.GROUPED_AGG) # doctest: +SKIP
... def mean_udf(v):
... return v.mean()
- >>> w = Window \\
- ... .partitionBy('id') \\
- ... .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ >>> w = (Window.partitionBy('id')
+ ... .orderBy('v')
+ ... .rowsBetween(-1, 0))
>>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show() # doctest: +SKIP
+---+----+------+
| id| v|mean_v|
+---+----+------+
- | 1| 1.0| 1.5|
+ | 1| 1.0| 1.0|
| 1| 2.0| 1.5|
- | 2| 3.0| 6.0|
- | 2| 5.0| 6.0|
- | 2|10.0| 6.0|
+ | 2| 3.0| 3.0|
+ | 2| 5.0| 4.0|
+ | 2|10.0| 7.5|
+---+----+------+
+ .. note:: For performance reasons, the input series to window functions are not copied.
+ Therefore, mutating the input series is not allowed and will cause incorrect results.
+ For the same reason, users should also not rely on the index of the input series.
+
.. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
.. note:: The user-defined functions are considered deterministic by default. Due to
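Since the docstring above now allows bounded frames for grouped-aggregate pandas UDFs, a range-based frame works as well as the row-based one shown in the doctest. A hedged sketch reusing the df from the doctest above (the frame bounds are arbitrary):

    from pyspark.sql import Window
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    @pandas_udf("double", PandasUDFType.GROUPED_AGG)
    def max_udf(v):
        # Do not mutate v in place; the input series is not copied (see the note above).
        return v.max()

    # Range frame: bounds are offsets on the ordering value 'v', not row counts.
    w = Window.partitionBy("id").orderBy("v").rangeBetween(Window.currentRow, 4)
    df.withColumn("max_v_next_4", max_udf(df["v"]).over(w)).show()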
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index 1d2dd4d808930..7b10512a43294 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -211,7 +211,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
set, it uses the default value, ``PERMISSIVE``.
* ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \
- into a field configured by ``columnNameOfCorruptRecord``, and sets other \
+ into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \
fields to ``null``. To keep corrupt records, an user can set a string type \
field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \
schema does not have the field, it drops corrupt records during parsing. \
@@ -424,7 +424,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
set, it uses the default value, ``PERMISSIVE``.
* ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \
- into a field configured by ``columnNameOfCorruptRecord``, and sets other \
+ into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \
fields to ``null``. To keep corrupt records, an user can set a string type \
field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \
schema does not have the field, it drops corrupt records during parsing. \
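The corrected wording applies to both the batch and streaming readers: in PERMISSIVE mode only the malformed fields of a bad record become null, and the raw text is kept if the schema carries the corrupt-record column. A hedged sketch for the batch JSON reader, assuming an active SparkSession named spark (the input path is hypothetical):

    from pyspark.sql.types import StructType, StructField, LongType, StringType

    schema = StructType([
        StructField("a", LongType(), True),
        StructField("b", LongType(), True),
        # Extra string field that receives the raw text of malformed records.
        StructField("_corrupt_record", StringType(), True),
    ])

    df = (spark.read
          .schema(schema)
          .option("mode", "PERMISSIVE")
          .option("columnNameOfCorruptRecord", "_corrupt_record")
          .json("/tmp/people.json"))
    df.filter(df["_corrupt_record"].isNotNull()).show(truncate=False)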
diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py
index d92b0d5677e25..fc23b9d99c34a 100644
--- a/python/pyspark/sql/streaming.py
+++ b/python/pyspark/sql/streaming.py
@@ -441,7 +441,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
set, it uses the default value, ``PERMISSIVE``.
* ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \
- into a field configured by ``columnNameOfCorruptRecord``, and sets other \
+ into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \
fields to ``null``. To keep corrupt records, an user can set a string type \
field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \
schema does not have the field, it drops corrupt records during parsing. \
@@ -648,7 +648,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
set, it uses the default value, ``PERMISSIVE``.
* ``PERMISSIVE`` : when it meets a corrupted record, puts the malformed string \
- into a field configured by ``columnNameOfCorruptRecord``, and sets other \
+ into a field configured by ``columnNameOfCorruptRecord``, and sets malformed \
fields to ``null``. To keep corrupt records, an user can set a string type \
field named ``columnNameOfCorruptRecord`` in an user-defined schema. If a \
schema does not have the field, it drops corrupt records during parsing. \
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 6e75e82d58009..21fe5000df5d9 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -381,6 +381,34 @@ def test_timestamp_dst(self):
self.assertPandasEqual(pdf, df_from_python.toPandas())
self.assertPandasEqual(pdf, df_from_pandas.toPandas())
+ def test_toPandas_batch_order(self):
+
+ def delay_first_part(partition_index, iterator):
+ if partition_index == 0:
+ time.sleep(0.1)
+ return iterator
+
+ # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python
+ def run_test(num_records, num_parts, max_records, use_delay=False):
+ df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
+ if use_delay:
+ df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
+ with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
+ pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+ self.assertPandasEqual(pdf, pdf_arrow)
+
+ cases = [
+ (1024, 512, 2), # Use a large number of partitions to make out-of-order collection more likely
+ (64, 8, 2, True), # Use a delay in the first partition to force out-of-order collection
+ (64, 64, 1), # Test single batch per partition
+ (64, 1, 64), # Test single partition, single batch
+ (64, 1, 8), # Test single partition, multiple batches
+ (30, 7, 2), # Test different sized partitions
+ ]
+
+ for case in cases:
+ run_test(*case)
+
class EncryptionArrowTests(ArrowTests):
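The new test exercises the same knobs a user would: the partition count and spark.sql.execution.arrow.maxRecordsPerBatch control how many batches reach the driver, and toPandas() now reorders them before building the pandas DataFrame. A hedged interactive sketch, assuming a SparkSession named spark with PyArrow installed:

    spark.conf.set("spark.sql.execution.arrow.enabled", "true")
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "2")

    # 4 partitions x at most 2 records per Arrow batch => several batches whose
    # arrival order at the driver is not guaranteed.
    pdf = spark.range(8, numPartitions=4).toDF("a").toPandas()
    print(pdf["a"].tolist())  # still [0, 1, 2, 3, 4, 5, 6, 7] after reordering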
diff --git a/python/pyspark/sql/tests/test_pandas_udf.py b/python/pyspark/sql/tests/test_pandas_udf.py
index c4b5478a7e893..d4d9679649ee9 100644
--- a/python/pyspark/sql/tests/test_pandas_udf.py
+++ b/python/pyspark/sql/tests/test_pandas_udf.py
@@ -17,12 +17,16 @@
import unittest
+from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.sql.utils import ParseException
+from pyspark.rdd import PythonEvalType
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
+from py4j.protocol import Py4JJavaError
+
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -30,9 +34,6 @@
class PandasUDFTests(ReusedSQLTestCase):
def test_pandas_udf_basic(self):
- from pyspark.rdd import PythonEvalType
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
udf = pandas_udf(lambda x: x, DoubleType())
self.assertEqual(udf.returnType, DoubleType())
self.assertEqual(udf.evalType, PythonEvalType.SQL_SCALAR_PANDAS_UDF)
@@ -65,10 +66,6 @@ def test_pandas_udf_basic(self):
self.assertEqual(udf.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
def test_pandas_udf_decorator(self):
- from pyspark.rdd import PythonEvalType
- from pyspark.sql.functions import pandas_udf, PandasUDFType
- from pyspark.sql.types import StructType, StructField, DoubleType
-
@pandas_udf(DoubleType())
def foo(x):
return x
@@ -114,8 +111,6 @@ def foo(x):
self.assertEqual(foo.evalType, PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF)
def test_udf_wrong_arg(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
with QuietTest(self.sc):
with self.assertRaises(ParseException):
@pandas_udf('blah')
@@ -151,9 +146,6 @@ def foo(k, v, w):
return k
def test_stopiteration_in_udf(self):
- from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
- from py4j.protocol import Py4JJavaError
-
def foo(x):
raise StopIteration()
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
index 5383704434c85..18264ead2fd08 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_agg.py
@@ -17,6 +17,9 @@
import unittest
+from pyspark.rdd import PythonEvalType
+from pyspark.sql.functions import array, explode, col, lit, mean, sum, \
+ udf, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
@@ -31,7 +34,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
@property
def data(self):
- from pyspark.sql.functions import array, explode, col, lit
return self.spark.range(10).toDF('id') \
.withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))) \
@@ -40,8 +42,6 @@ def data(self):
@property
def python_plus_one(self):
- from pyspark.sql.functions import udf
-
@udf('double')
def plus_one(v):
assert isinstance(v, (int, float))
@@ -51,7 +51,6 @@ def plus_one(v):
@property
def pandas_scalar_plus_two(self):
import pandas as pd
- from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.SCALAR)
def plus_two(v):
@@ -61,8 +60,6 @@ def plus_two(v):
@property
def pandas_agg_mean_udf(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def avg(v):
return v.mean()
@@ -70,8 +67,6 @@ def avg(v):
@property
def pandas_agg_sum_udf(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def sum(v):
return v.sum()
@@ -80,7 +75,6 @@ def sum(v):
@property
def pandas_agg_weighted_mean_udf(self):
import numpy as np
- from pyspark.sql.functions import pandas_udf, PandasUDFType
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def weighted_mean(v, w):
@@ -88,8 +82,6 @@ def weighted_mean(v, w):
return weighted_mean
def test_manual(self):
- from pyspark.sql.functions import pandas_udf, array
-
df = self.data
sum_udf = self.pandas_agg_sum_udf
mean_udf = self.pandas_agg_mean_udf
@@ -118,8 +110,6 @@ def test_manual(self):
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_basic(self):
- from pyspark.sql.functions import col, lit, mean
-
df = self.data
weighted_mean_udf = self.pandas_agg_weighted_mean_udf
@@ -150,9 +140,6 @@ def test_basic(self):
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
def test_unsupported_types(self):
- from pyspark.sql.types import DoubleType, MapType
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
with QuietTest(self.sc):
with self.assertRaisesRegexp(NotImplementedError, 'not supported'):
pandas_udf(
@@ -173,8 +160,6 @@ def mean_and_std_udf(v):
return {v.mean(): v.std()}
def test_alias(self):
- from pyspark.sql.functions import mean
-
df = self.data
mean_udf = self.pandas_agg_mean_udf
@@ -187,8 +172,6 @@ def test_mixed_sql(self):
"""
Test mixing group aggregate pandas UDF with sql expression.
"""
- from pyspark.sql.functions import sum
-
df = self.data
sum_udf = self.pandas_agg_sum_udf
@@ -225,8 +208,6 @@ def test_mixed_udfs(self):
"""
Test mixing group aggregate pandas UDF with python UDF and scalar pandas UDF.
"""
- from pyspark.sql.functions import sum
-
df = self.data
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
@@ -292,8 +273,6 @@ def test_multiple_udfs(self):
"""
Test multiple group aggregate pandas UDFs in one agg function.
"""
- from pyspark.sql.functions import sum, mean
-
df = self.data
mean_udf = self.pandas_agg_mean_udf
sum_udf = self.pandas_agg_sum_udf
@@ -315,8 +294,6 @@ def test_multiple_udfs(self):
self.assertPandasEqual(expected1, result1)
def test_complex_groupby(self):
- from pyspark.sql.functions import sum
-
df = self.data
sum_udf = self.pandas_agg_sum_udf
plus_one = self.python_plus_one
@@ -359,8 +336,6 @@ def test_complex_groupby(self):
self.assertPandasEqual(expected7.toPandas(), result7.toPandas())
def test_complex_expressions(self):
- from pyspark.sql.functions import col, sum
-
df = self.data
plus_one = self.python_plus_one
plus_two = self.pandas_scalar_plus_two
@@ -434,7 +409,6 @@ def test_complex_expressions(self):
self.assertPandasEqual(expected3, result3)
def test_retain_group_columns(self):
- from pyspark.sql.functions import sum
with self.sql_conf({"spark.sql.retainGroupColumns": False}):
df = self.data
sum_udf = self.pandas_agg_sum_udf
@@ -444,8 +418,6 @@ def test_retain_group_columns(self):
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_array_type(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
df = self.data
array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array', PandasUDFType.GROUPED_AGG)
@@ -453,8 +425,6 @@ def test_array_type(self):
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
def test_invalid_args(self):
- from pyspark.sql.functions import mean
-
df = self.data
plus_one = self.python_plus_one
mean_udf = self.pandas_agg_mean_udf
@@ -478,9 +448,6 @@ def test_invalid_args(self):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
def test_register_vectorized_udf_basic(self):
- from pyspark.sql.functions import pandas_udf
- from pyspark.rdd import PythonEvalType
-
sum_pandas_udf = pandas_udf(
lambda v: v.sum(), "integer", PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF)
diff --git a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
index bfecc071386e9..80e70349b78d3 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
@@ -18,7 +18,12 @@
import datetime
import unittest
+from collections import OrderedDict
+from decimal import Decimal
+from distutils.version import LooseVersion
+
from pyspark.sql import Row
+from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
@@ -32,16 +37,12 @@ class GroupedMapPandasUDFTests(ReusedSQLTestCase):
@property
def data(self):
- from pyspark.sql.functions import array, explode, col, lit
return self.spark.range(10).toDF('id') \
.withColumn("vs", array([lit(i) for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))).drop('vs')
def test_supported_types(self):
- from decimal import Decimal
- from distutils.version import LooseVersion
import pyarrow as pa
- from pyspark.sql.functions import pandas_udf, PandasUDFType
values = [
1, 2, 3,
@@ -131,8 +132,6 @@ def test_supported_types(self):
self.assertPandasEqual(expected3, result3)
def test_array_type_correct(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
-
df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
output_schema = StructType(
@@ -151,8 +150,6 @@ def test_array_type_correct(self):
self.assertPandasEqual(expected, result)
def test_register_grouped_map_udf(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
with QuietTest(self.sc):
with self.assertRaisesRegexp(
@@ -161,7 +158,6 @@ def test_register_grouped_map_udf(self):
self.spark.catalog.registerFunction("foo_udf", foo_udf)
def test_decorator(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data
@pandas_udf(
@@ -176,7 +172,6 @@ def foo(pdf):
self.assertPandasEqual(expected, result)
def test_coerce(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data
foo = pandas_udf(
@@ -191,7 +186,6 @@ def test_coerce(self):
self.assertPandasEqual(expected, result)
def test_complex_groupby(self):
- from pyspark.sql.functions import pandas_udf, col, PandasUDFType
df = self.data
@pandas_udf(
@@ -210,7 +204,6 @@ def normalize(pdf):
self.assertPandasEqual(expected, result)
def test_empty_groupby(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data
@pandas_udf(
@@ -229,7 +222,6 @@ def normalize(pdf):
self.assertPandasEqual(expected, result)
def test_datatype_string(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
df = self.data
foo_udf = pandas_udf(
@@ -243,8 +235,6 @@ def test_datatype_string(self):
self.assertPandasEqual(expected, result)
def test_wrong_return_type(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
@@ -255,7 +245,6 @@ def test_wrong_return_type(self):
PandasUDFType.GROUPED_MAP)
def test_wrong_args(self):
- from pyspark.sql.functions import udf, pandas_udf, sum, PandasUDFType
df = self.data
with QuietTest(self.sc):
@@ -277,9 +266,7 @@ def test_wrong_args(self):
pandas_udf(lambda x, y: x, DoubleType(), PandasUDFType.SCALAR))
def test_unsupported_types(self):
- from distutils.version import LooseVersion
import pyarrow as pa
- from pyspark.sql.functions import pandas_udf, PandasUDFType
common_err_msg = 'Invalid returnType.*grouped map Pandas UDF.*'
unsupported_types = [
@@ -300,7 +287,6 @@ def test_unsupported_types(self):
# Regression test for SPARK-23314
def test_timestamp_dst(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
# Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
dt = [datetime.datetime(2015, 11, 1, 0, 30),
datetime.datetime(2015, 11, 1, 1, 30),
@@ -311,12 +297,12 @@ def test_timestamp_dst(self):
self.assertPandasEqual(df.toPandas(), result.toPandas())
def test_udf_with_key(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
+ import numpy as np
+
df = self.data
pdf = df.toPandas()
def foo1(key, pdf):
- import numpy as np
assert type(key) == tuple
assert type(key[0]) == np.int64
@@ -326,7 +312,6 @@ def foo1(key, pdf):
v4=pdf.v * pdf.id.mean())
def foo2(key, pdf):
- import numpy as np
assert type(key) == tuple
assert type(key[0]) == np.int64
assert type(key[1]) == np.int32
@@ -385,9 +370,7 @@ def foo3(key, pdf):
self.assertPandasEqual(expected4, result4)
def test_column_order(self):
- from collections import OrderedDict
import pandas as pd
- from pyspark.sql.functions import pandas_udf, PandasUDFType
# Helper function to set column names from a list
def rename_pdf(pdf, names):
@@ -468,12 +451,17 @@ def invalid_positional_types(pdf):
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, "KeyError: 'id'"):
grouped_df.apply(column_name_typo).collect()
- with self.assertRaisesRegexp(Exception, "No cast implemented"):
- grouped_df.apply(invalid_positional_types).collect()
+ import pyarrow as pa
+ if LooseVersion(pa.__version__) < LooseVersion("0.11.0"):
+ # TODO: see ARROW-1949. Remove when the minimum PyArrow version becomes 0.11.0.
+ with self.assertRaisesRegexp(Exception, "No cast implemented"):
+ grouped_df.apply(invalid_positional_types).collect()
+ else:
+ with self.assertRaisesRegexp(Exception, "an integer is required"):
+ grouped_df.apply(invalid_positional_types).collect()
def test_positional_assignment_conf(self):
import pandas as pd
- from pyspark.sql.functions import pandas_udf, PandasUDFType
with self.sql_conf({
"spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName": False}):
@@ -489,9 +477,7 @@ def foo(_):
self.assertEqual(r.b, 1)
def test_self_join_with_pandas(self):
- import pyspark.sql.functions as F
-
- @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+ @pandas_udf('key long, col string', PandasUDFType.GROUPED_MAP)
def dummy_pandas_udf(df):
return df[['key', 'col']]
@@ -501,12 +487,11 @@ def dummy_pandas_udf(df):
# this was throwing an AnalysisException before SPARK-24208
res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
- F.col('temp0.key') == F.col('temp1.key'))
+ col('temp0.key') == col('temp1.key'))
self.assertEquals(res.count(), 5)
def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
import pandas as pd
- from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
df = self.spark.range(0, 10).toDF('v1')
df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
diff --git a/python/pyspark/sql/tests/test_pandas_udf_scalar.py b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
index 2f585a3725988..6a6865a9fb16d 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_scalar.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -16,12 +16,20 @@
#
import datetime
import os
+import random
import shutil
import sys
import tempfile
import time
import unittest
+from datetime import date, datetime
+from decimal import Decimal
+from distutils.version import LooseVersion
+
+from pyspark.rdd import PythonEvalType
+from pyspark.sql import Column
+from pyspark.sql.functions import array, col, expr, lit, sum, udf, pandas_udf
from pyspark.sql.types import Row
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
@@ -59,18 +67,16 @@ def tearDownClass(cls):
@property
def nondeterministic_vectorized_udf(self):
- from pyspark.sql.functions import pandas_udf
+ import pandas as pd
+ import numpy as np
@pandas_udf('double')
def random_udf(v):
- import pandas as pd
- import numpy as np
return pd.Series(np.random.random(len(v)))
random_udf = random_udf.asNondeterministic()
return random_udf
def test_pandas_udf_tokenize(self):
- from pyspark.sql.functions import pandas_udf
tokenize = pandas_udf(lambda s: s.apply(lambda str: str.split(' ')),
ArrayType(StringType()))
self.assertEqual(tokenize.returnType, ArrayType(StringType()))
@@ -79,7 +85,6 @@ def test_pandas_udf_tokenize(self):
self.assertEqual([Row(hi=[u'hi', u'boo']), Row(hi=[u'bye', u'boo'])], result.collect())
def test_pandas_udf_nested_arrays(self):
- from pyspark.sql.functions import pandas_udf
tokenize = pandas_udf(lambda s: s.apply(lambda str: [str.split(' ')]),
ArrayType(ArrayType(StringType())))
self.assertEqual(tokenize.returnType, ArrayType(ArrayType(StringType())))
@@ -88,7 +93,6 @@ def test_pandas_udf_nested_arrays(self):
self.assertEqual([Row(hi=[[u'hi', u'boo']]), Row(hi=[[u'bye', u'boo']])], result.collect())
def test_vectorized_udf_basic(self):
- from pyspark.sql.functions import pandas_udf, col, array
df = self.spark.range(10).select(
col('id').cast('string').alias('str'),
col('id').cast('int').alias('int'),
@@ -114,9 +118,6 @@ def test_vectorized_udf_basic(self):
self.assertEquals(df.collect(), res.collect())
def test_register_nondeterministic_vectorized_udf_basic(self):
- from pyspark.sql.functions import pandas_udf
- from pyspark.rdd import PythonEvalType
- import random
random_pandas_udf = pandas_udf(
lambda x: random.randint(6, 6) + x, IntegerType()).asNondeterministic()
self.assertEqual(random_pandas_udf.deterministic, False)
@@ -129,7 +130,6 @@ def test_register_nondeterministic_vectorized_udf_basic(self):
self.assertEqual(row[0], 7)
def test_vectorized_udf_null_boolean(self):
- from pyspark.sql.functions import pandas_udf, col
data = [(True,), (True,), (None,), (False,)]
schema = StructType().add("bool", BooleanType())
df = self.spark.createDataFrame(data, schema)
@@ -138,7 +138,6 @@ def test_vectorized_udf_null_boolean(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_byte(self):
- from pyspark.sql.functions import pandas_udf, col
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("byte", ByteType())
df = self.spark.createDataFrame(data, schema)
@@ -147,7 +146,6 @@ def test_vectorized_udf_null_byte(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_short(self):
- from pyspark.sql.functions import pandas_udf, col
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("short", ShortType())
df = self.spark.createDataFrame(data, schema)
@@ -156,7 +154,6 @@ def test_vectorized_udf_null_short(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_int(self):
- from pyspark.sql.functions import pandas_udf, col
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("int", IntegerType())
df = self.spark.createDataFrame(data, schema)
@@ -165,7 +162,6 @@ def test_vectorized_udf_null_int(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_long(self):
- from pyspark.sql.functions import pandas_udf, col
data = [(None,), (2,), (3,), (4,)]
schema = StructType().add("long", LongType())
df = self.spark.createDataFrame(data, schema)
@@ -174,7 +170,6 @@ def test_vectorized_udf_null_long(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_float(self):
- from pyspark.sql.functions import pandas_udf, col
data = [(3.0,), (5.0,), (-1.0,), (None,)]
schema = StructType().add("float", FloatType())
df = self.spark.createDataFrame(data, schema)
@@ -183,7 +178,6 @@ def test_vectorized_udf_null_float(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_double(self):
- from pyspark.sql.functions import pandas_udf, col
data = [(3.0,), (5.0,), (-1.0,), (None,)]
schema = StructType().add("double", DoubleType())
df = self.spark.createDataFrame(data, schema)
@@ -192,8 +186,6 @@ def test_vectorized_udf_null_double(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_decimal(self):
- from decimal import Decimal
- from pyspark.sql.functions import pandas_udf, col
data = [(Decimal(3.0),), (Decimal(5.0),), (Decimal(-1.0),), (None,)]
schema = StructType().add("decimal", DecimalType(38, 18))
df = self.spark.createDataFrame(data, schema)
@@ -202,7 +194,6 @@ def test_vectorized_udf_null_decimal(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_string(self):
- from pyspark.sql.functions import pandas_udf, col
data = [("foo",), (None,), ("bar",), ("bar",)]
schema = StructType().add("str", StringType())
df = self.spark.createDataFrame(data, schema)
@@ -211,7 +202,6 @@ def test_vectorized_udf_null_string(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_string_in_udf(self):
- from pyspark.sql.functions import pandas_udf, col
import pandas as pd
df = self.spark.range(10)
str_f = pandas_udf(lambda x: pd.Series(map(str, x)), StringType())
@@ -220,7 +210,6 @@ def test_vectorized_udf_string_in_udf(self):
self.assertEquals(expected.collect(), actual.collect())
def test_vectorized_udf_datatype_string(self):
- from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10).select(
col('id').cast('string').alias('str'),
col('id').cast('int').alias('int'),
@@ -244,9 +233,8 @@ def test_vectorized_udf_datatype_string(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_null_binary(self):
- from distutils.version import LooseVersion
import pyarrow as pa
- from pyspark.sql.functions import pandas_udf, col
+
if LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
with QuietTest(self.sc):
with self.assertRaisesRegexp(
@@ -262,7 +250,6 @@ def test_vectorized_udf_null_binary(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_array_type(self):
- from pyspark.sql.functions import pandas_udf, col
data = [([1, 2],), ([3, 4],)]
array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
df = self.spark.createDataFrame(data, schema=array_schema)
@@ -271,7 +258,6 @@ def test_vectorized_udf_array_type(self):
self.assertEquals(df.collect(), result.collect())
def test_vectorized_udf_null_array(self):
- from pyspark.sql.functions import pandas_udf, col
data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
df = self.spark.createDataFrame(data, schema=array_schema)
@@ -280,7 +266,6 @@ def test_vectorized_udf_null_array(self):
self.assertEquals(df.collect(), result.collect())
def test_vectorized_udf_complex(self):
- from pyspark.sql.functions import pandas_udf, col, expr
df = self.spark.range(10).select(
col('id').cast('int').alias('a'),
col('id').cast('int').alias('b'),
@@ -293,7 +278,6 @@ def test_vectorized_udf_complex(self):
self.assertEquals(expected.collect(), res.collect())
def test_vectorized_udf_exception(self):
- from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
raise_exception = pandas_udf(lambda x: x * (1 / 0), LongType())
with QuietTest(self.sc):
@@ -301,8 +285,8 @@ def test_vectorized_udf_exception(self):
df.select(raise_exception(col('id'))).collect()
def test_vectorized_udf_invalid_length(self):
- from pyspark.sql.functions import pandas_udf, col
import pandas as pd
+
df = self.spark.range(10)
raise_exception = pandas_udf(lambda _: pd.Series(1), LongType())
with QuietTest(self.sc):
@@ -312,7 +296,6 @@ def test_vectorized_udf_invalid_length(self):
df.select(raise_exception(col('id'))).collect()
def test_vectorized_udf_chained(self):
- from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
f = pandas_udf(lambda x: x + 1, LongType())
g = pandas_udf(lambda x: x - 1, LongType())
@@ -320,7 +303,6 @@ def test_vectorized_udf_chained(self):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_wrong_return_type(self):
- from pyspark.sql.functions import pandas_udf
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
@@ -328,7 +310,6 @@ def test_vectorized_udf_wrong_return_type(self):
pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
def test_vectorized_udf_return_scalar(self):
- from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
f = pandas_udf(lambda x: 1.0, DoubleType())
with QuietTest(self.sc):
@@ -336,7 +317,6 @@ def test_vectorized_udf_return_scalar(self):
df.select(f(col('id'))).collect()
def test_vectorized_udf_decorator(self):
- from pyspark.sql.functions import pandas_udf, col
df = self.spark.range(10)
@pandas_udf(returnType=LongType())
@@ -346,21 +326,18 @@ def identity(x):
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_empty_partition(self):
- from pyspark.sql.functions import pandas_udf, col
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
f = pandas_udf(lambda x: x, LongType())
res = df.select(f(col('id')))
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_varargs(self):
- from pyspark.sql.functions import pandas_udf, col
df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2))
f = pandas_udf(lambda *v: v[0], LongType())
res = df.select(f(col('id')))
self.assertEquals(df.collect(), res.collect())
def test_vectorized_udf_unsupported_types(self):
- from pyspark.sql.functions import pandas_udf
with QuietTest(self.sc):
with self.assertRaisesRegexp(
NotImplementedError,
@@ -368,8 +345,6 @@ def test_vectorized_udf_unsupported_types(self):
pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
def test_vectorized_udf_dates(self):
- from pyspark.sql.functions import pandas_udf, col
- from datetime import date
schema = StructType().add("idx", LongType()).add("date", DateType())
data = [(0, date(1969, 1, 1),),
(1, date(2012, 2, 2),),
@@ -405,8 +380,6 @@ def check_data(idx, date, date_copy):
self.assertIsNone(result[i][3]) # "check_data" col
def test_vectorized_udf_timestamps(self):
- from pyspark.sql.functions import pandas_udf, col
- from datetime import datetime
schema = StructType([
StructField("idx", LongType(), True),
StructField("timestamp", TimestampType(), True)])
@@ -447,8 +420,8 @@ def check_data(idx, timestamp, timestamp_copy):
self.assertIsNone(result[i][3]) # "check_data" col
def test_vectorized_udf_return_timestamp_tz(self):
- from pyspark.sql.functions import pandas_udf, col
import pandas as pd
+
df = self.spark.range(10)
@pandas_udf(returnType=TimestampType())
@@ -465,8 +438,8 @@ def gen_timestamps(id):
self.assertEquals(expected, ts)
def test_vectorized_udf_check_config(self):
- from pyspark.sql.functions import pandas_udf, col
import pandas as pd
+
with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
df = self.spark.range(10, numPartitions=1)
@@ -479,9 +452,8 @@ def check_records_per_batch(x):
self.assertTrue(r <= 3)
def test_vectorized_udf_timestamps_respect_session_timezone(self):
- from pyspark.sql.functions import pandas_udf, col
- from datetime import datetime
import pandas as pd
+
schema = StructType([
StructField("idx", LongType(), True),
StructField("timestamp", TimestampType(), True)])
@@ -519,8 +491,6 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self):
def test_nondeterministic_vectorized_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
- from pyspark.sql.functions import pandas_udf, col
-
@pandas_udf('double')
def plus_ten(v):
return v + 10
@@ -533,8 +503,6 @@ def plus_ten(v):
self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10))
def test_nondeterministic_vectorized_udf_in_aggregate(self):
- from pyspark.sql.functions import sum
-
df = self.spark.range(10)
random_udf = self.nondeterministic_vectorized_udf
@@ -545,8 +513,6 @@ def test_nondeterministic_vectorized_udf_in_aggregate(self):
df.agg(sum(random_udf(df.id))).collect()
def test_register_vectorized_udf_basic(self):
- from pyspark.rdd import PythonEvalType
- from pyspark.sql.functions import pandas_udf, col, expr
df = self.spark.range(10).select(
col('id').cast('int').alias('a'),
col('id').cast('int').alias('b'))
@@ -563,11 +529,10 @@ def test_register_vectorized_udf_basic(self):
# Regression test for SPARK-23314
def test_timestamp_dst(self):
- from pyspark.sql.functions import pandas_udf
# Daylight saving time for Los Angeles for 2015 is Sun, Nov 1 at 2:00 am
- dt = [datetime.datetime(2015, 11, 1, 0, 30),
- datetime.datetime(2015, 11, 1, 1, 30),
- datetime.datetime(2015, 11, 1, 2, 30)]
+ dt = [datetime(2015, 11, 1, 0, 30),
+ datetime(2015, 11, 1, 1, 30),
+ datetime(2015, 11, 1, 2, 30)]
df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
foo_udf = pandas_udf(lambda x: x, 'timestamp')
result = df.withColumn('time', foo_udf(df.time))
@@ -593,7 +558,6 @@ def test_type_annotation(self):
def test_mixed_udf(self):
import pandas as pd
- from pyspark.sql.functions import col, udf, pandas_udf
df = self.spark.range(0, 1).toDF('v')
@@ -696,8 +660,6 @@ def f4(x):
def test_mixed_udf_and_sql(self):
import pandas as pd
- from pyspark.sql import Column
- from pyspark.sql.functions import udf, pandas_udf
df = self.spark.range(0, 1).toDF('v')
@@ -758,7 +720,6 @@ def test_datasource_with_udf(self):
# This needs to a separate test because Arrow dependency is optional
import pandas as pd
import numpy as np
- from pyspark.sql.functions import pandas_udf, lit, col
path = tempfile.mkdtemp()
shutil.rmtree(path)
diff --git a/python/pyspark/sql/tests/test_pandas_udf_window.py b/python/pyspark/sql/tests/test_pandas_udf_window.py
index f0e6d2696df62..3ba98e76468b3 100644
--- a/python/pyspark/sql/tests/test_pandas_udf_window.py
+++ b/python/pyspark/sql/tests/test_pandas_udf_window.py
@@ -18,6 +18,8 @@
import unittest
from pyspark.sql.utils import AnalysisException
+from pyspark.sql.functions import array, explode, col, lit, mean, min, max, rank, \
+ udf, pandas_udf, PandasUDFType
from pyspark.sql.window import Window
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
@@ -30,7 +32,6 @@
class WindowPandasUDFTests(ReusedSQLTestCase):
@property
def data(self):
- from pyspark.sql.functions import array, explode, col, lit
return self.spark.range(10).toDF('id') \
.withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
.withColumn("v", explode(col('vs'))) \
@@ -39,18 +40,23 @@ def data(self):
@property
def python_plus_one(self):
- from pyspark.sql.functions import udf
return udf(lambda v: v + 1, 'double')
@property
def pandas_scalar_time_two(self):
- from pyspark.sql.functions import pandas_udf
return pandas_udf(lambda v: v * 2, 'double')
@property
- def pandas_agg_mean_udf(self):
+ def pandas_agg_count_udf(self):
from pyspark.sql.functions import pandas_udf, PandasUDFType
+ @pandas_udf('long', PandasUDFType.GROUPED_AGG)
+ def count(v):
+ return len(v)
+ return count
+
+ @property
+ def pandas_agg_mean_udf(self):
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def avg(v):
return v.mean()
@@ -58,8 +64,6 @@ def avg(v):
@property
def pandas_agg_max_udf(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def max(v):
return v.max()
@@ -67,8 +71,6 @@ def max(v):
@property
def pandas_agg_min_udf(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
@pandas_udf('double', PandasUDFType.GROUPED_AGG)
def min(v):
return v.min()
@@ -77,7 +79,7 @@ def min(v):
@property
def unbounded_window(self):
return Window.partitionBy('id') \
- .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+ .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing).orderBy('v')
@property
def ordered_window(self):
@@ -87,9 +89,33 @@ def ordered_window(self):
def unpartitioned_window(self):
return Window.partitionBy()
- def test_simple(self):
- from pyspark.sql.functions import mean
+ @property
+ def sliding_row_window(self):
+ return Window.partitionBy('id').orderBy('v').rowsBetween(-2, 1)
+ @property
+ def sliding_range_window(self):
+ return Window.partitionBy('id').orderBy('v').rangeBetween(-2, 4)
+
+ @property
+ def growing_row_window(self):
+ return Window.partitionBy('id').orderBy('v').rowsBetween(Window.unboundedPreceding, 3)
+
+ @property
+ def growing_range_window(self):
+ return Window.partitionBy('id').orderBy('v') \
+ .rangeBetween(Window.unboundedPreceding, 4)
+
+ @property
+ def shrinking_row_window(self):
+ return Window.partitionBy('id').orderBy('v').rowsBetween(-2, Window.unboundedFollowing)
+
+ @property
+ def shrinking_range_window(self):
+ return Window.partitionBy('id').orderBy('v') \
+ .rangeBetween(-3, Window.unboundedFollowing)
+
+ def test_simple(self):
df = self.data
w = self.unbounded_window
@@ -105,24 +131,20 @@ def test_simple(self):
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
def test_multiple_udfs(self):
- from pyspark.sql.functions import max, min, mean
-
df = self.data
w = self.unbounded_window
result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
- .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
- .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
+ .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
+ .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
- .withColumn('max_v', max(df['v']).over(w)) \
- .withColumn('min_w', min(df['w']).over(w))
+ .withColumn('max_v', max(df['v']).over(w)) \
+ .withColumn('min_w', min(df['w']).over(w))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_replace_existing(self):
- from pyspark.sql.functions import mean
-
df = self.data
w = self.unbounded_window
@@ -132,8 +154,6 @@ def test_replace_existing(self):
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_mixed_sql(self):
- from pyspark.sql.functions import mean
-
df = self.data
w = self.unbounded_window
mean_udf = self.pandas_agg_mean_udf
@@ -144,8 +164,6 @@ def test_mixed_sql(self):
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
def test_mixed_udf(self):
- from pyspark.sql.functions import mean
-
df = self.data
w = self.unbounded_window
@@ -171,8 +189,6 @@ def test_mixed_udf(self):
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
def test_without_partitionBy(self):
- from pyspark.sql.functions import mean
-
df = self.data
w = self.unpartitioned_window
mean_udf = self.pandas_agg_mean_udf
@@ -187,8 +203,6 @@ def test_without_partitionBy(self):
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
def test_mixed_sql_and_udf(self):
- from pyspark.sql.functions import max, min, rank, col
-
df = self.data
w = self.unbounded_window
ow = self.ordered_window
@@ -204,16 +218,16 @@ def test_mixed_sql_and_udf(self):
# Test chaining sql aggregate function and udf
result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
- .withColumn('min_v', min(df['v']).over(w)) \
- .withColumn('v_diff', col('max_v') - col('min_v')) \
- .drop('max_v', 'min_v')
+ .withColumn('min_v', min(df['v']).over(w)) \
+ .withColumn('v_diff', col('max_v') - col('min_v')) \
+ .drop('max_v', 'min_v')
expected3 = expected1
# Test mixing sql window function and udf
result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
- .withColumn('rank', rank().over(ow))
+ .withColumn('rank', rank().over(ow))
expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
- .withColumn('rank', rank().over(ow))
+ .withColumn('rank', rank().over(ow))
self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
@@ -221,8 +235,6 @@ def test_mixed_sql_and_udf(self):
self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
def test_array_type(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
df = self.data
w = self.unbounded_window
@@ -231,12 +243,8 @@ def test_array_type(self):
self.assertEquals(result1.first()['v2'], [1.0, 2.0])
def test_invalid_args(self):
- from pyspark.sql.functions import pandas_udf, PandasUDFType
-
df = self.data
w = self.unbounded_window
- ow = self.ordered_window
- mean_udf = self.pandas_agg_mean_udf
with QuietTest(self.sc):
with self.assertRaisesRegexp(
@@ -245,11 +253,101 @@ def test_invalid_args(self):
foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
df.withColumn('v2', foo_udf(df['v']).over(w))
- with QuietTest(self.sc):
- with self.assertRaisesRegexp(
- AnalysisException,
- '.*Only unbounded window frame is supported.*'):
- df.withColumn('mean_v', mean_udf(df['v']).over(ow))
+ def test_bounded_simple(self):
+ from pyspark.sql.functions import mean, max, min, count
+
+ df = self.data
+ w1 = self.sliding_row_window
+ w2 = self.shrinking_range_window
+
+ plus_one = self.python_plus_one
+ count_udf = self.pandas_agg_count_udf
+ mean_udf = self.pandas_agg_mean_udf
+ max_udf = self.pandas_agg_max_udf
+ min_udf = self.pandas_agg_min_udf
+
+ result1 = df.withColumn('mean_v', mean_udf(plus_one(df['v'])).over(w1)) \
+ .withColumn('count_v', count_udf(df['v']).over(w2)) \
+ .withColumn('max_v', max_udf(df['v']).over(w2)) \
+ .withColumn('min_v', min_udf(df['v']).over(w1))
+
+ expected1 = df.withColumn('mean_v', mean(plus_one(df['v'])).over(w1)) \
+ .withColumn('count_v', count(df['v']).over(w2)) \
+ .withColumn('max_v', max(df['v']).over(w2)) \
+ .withColumn('min_v', min(df['v']).over(w1))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_growing_window(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w1 = self.growing_row_window
+ w2 = self.growing_range_window
+
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
+ .withColumn('m2', mean_udf(df['v']).over(w2))
+
+ expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
+ .withColumn('m2', mean(df['v']).over(w2))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_sliding_window(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w1 = self.sliding_row_window
+ w2 = self.sliding_range_window
+
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
+ .withColumn('m2', mean_udf(df['v']).over(w2))
+
+ expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
+ .withColumn('m2', mean(df['v']).over(w2))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_shrinking_window(self):
+ from pyspark.sql.functions import mean
+
+ df = self.data
+ w1 = self.shrinking_row_window
+ w2 = self.shrinking_range_window
+
+ mean_udf = self.pandas_agg_mean_udf
+
+ result1 = df.withColumn('m1', mean_udf(df['v']).over(w1)) \
+ .withColumn('m2', mean_udf(df['v']).over(w2))
+
+ expected1 = df.withColumn('m1', mean(df['v']).over(w1)) \
+ .withColumn('m2', mean(df['v']).over(w2))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+ def test_bounded_mixed(self):
+ from pyspark.sql.functions import mean, max
+
+ df = self.data
+ w1 = self.sliding_row_window
+ w2 = self.unbounded_window
+
+ mean_udf = self.pandas_agg_mean_udf
+ max_udf = self.pandas_agg_max_udf
+
+ result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w1)) \
+ .withColumn('max_v', max_udf(df['v']).over(w2)) \
+ .withColumn('mean_unbounded_v', mean_udf(df['v']).over(w1))
+
+ expected1 = df.withColumn('mean_v', mean(df['v']).over(w1)) \
+ .withColumn('max_v', max(df['v']).over(w2)) \
+ .withColumn('mean_unbounded_v', mean(df['v']).over(w1))
+
+ self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
if __name__ == "__main__":
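Note on the new bounded-window tests above: they rely on window fixtures (sliding_row_window, growing_range_window, shrinking_row_window, and so on) that are defined elsewhere in the test class and are not part of this hunk. A minimal sketch of what such specs could look like with pyspark.sql.Window, assuming a DataFrame partitioned by 'id' and ordered by 'v' (the column names here are illustrative):

    from pyspark.sql.window import Window

    # Row frames count rows around the current row; range frames use the orderBy value.
    sliding_row_window = Window.partitionBy('id').orderBy('v').rowsBetween(-2, 1)
    growing_range_window = Window.partitionBy('id').orderBy('v') \
        .rangeBetween(Window.unboundedPreceding, 3)
    shrinking_row_window = Window.partitionBy('id').orderBy('v') \
        .rowsBetween(-2, Window.unboundedFollowing)
    unbounded_window = Window.partitionBy('id') \
        .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)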
diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index ed298f724d551..12cf8c7de1dad 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -23,7 +23,7 @@
from pyspark import SparkContext
from pyspark.sql import SparkSession, Column, Row
-from pyspark.sql.functions import UserDefinedFunction
+from pyspark.sql.functions import UserDefinedFunction, udf
from pyspark.sql.types import *
from pyspark.sql.utils import AnalysisException
from pyspark.testing.sqlutils import ReusedSQLTestCase, test_compiled, test_not_compiled_message
@@ -102,7 +102,6 @@ def test_udf_registration_return_type_not_none(self):
def test_nondeterministic_udf(self):
# Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
- from pyspark.sql.functions import udf
import random
udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
self.assertEqual(udf_random_col.deterministic, False)
@@ -113,7 +112,6 @@ def test_nondeterministic_udf(self):
def test_nondeterministic_udf2(self):
import random
- from pyspark.sql.functions import udf
random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic()
self.assertEqual(random_udf.deterministic, False)
random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf)
@@ -132,7 +130,6 @@ def test_nondeterministic_udf2(self):
def test_nondeterministic_udf3(self):
# regression test for SPARK-23233
- from pyspark.sql.functions import udf
f = udf(lambda x: x)
# Here we cache the JVM UDF instance.
self.spark.range(1).select(f("id"))
@@ -144,7 +141,7 @@ def test_nondeterministic_udf3(self):
self.assertFalse(deterministic)
def test_nondeterministic_udf_in_aggregate(self):
- from pyspark.sql.functions import udf, sum
+ from pyspark.sql.functions import sum
import random
udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic()
df = self.spark.range(10)
@@ -181,7 +178,6 @@ def test_multiple_udfs(self):
self.assertEqual(tuple(row), (6, 5))
def test_udf_in_filter_on_top_of_outer_join(self):
- from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(a=1)])
df = left.join(right, on='a', how='left_outer')
@@ -190,7 +186,6 @@ def test_udf_in_filter_on_top_of_outer_join(self):
def test_udf_in_filter_on_top_of_join(self):
# regression test for SPARK-18589
- from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -199,7 +194,6 @@ def test_udf_in_filter_on_top_of_join(self):
def test_udf_in_join_condition(self):
# regression test for SPARK-25314
- from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -211,7 +205,7 @@ def test_udf_in_join_condition(self):
def test_udf_in_left_outer_join_condition(self):
# regression test for SPARK-26147
- from pyspark.sql.functions import udf, col
+ from pyspark.sql.functions import col
left = self.spark.createDataFrame([Row(a=1)])
right = self.spark.createDataFrame([Row(b=1)])
f = udf(lambda a: str(a), StringType())
@@ -223,7 +217,6 @@ def test_udf_in_left_outer_join_condition(self):
def test_udf_in_left_semi_join_condition(self):
# regression test for SPARK-25314
- from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -236,7 +229,6 @@ def test_udf_in_left_semi_join_condition(self):
def test_udf_and_common_filter_in_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
- from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -247,7 +239,6 @@ def test_udf_and_common_filter_in_join_condition(self):
def test_udf_and_common_filter_in_left_semi_join_condition(self):
# regression test for SPARK-25314
# test the complex scenario with both udf and common filter
- from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -258,7 +249,6 @@ def test_udf_and_common_filter_in_left_semi_join_condition(self):
def test_udf_not_supported_in_join_condition(self):
# regression test for SPARK-25314
# test python udf is not supported in join type besides left_semi and inner join.
- from pyspark.sql.functions import udf
left = self.spark.createDataFrame([Row(a=1, a1=1, a2=1), Row(a=2, a1=2, a2=2)])
right = self.spark.createDataFrame([Row(b=1, b1=1, b2=1), Row(b=1, b1=3, b2=1)])
f = udf(lambda a, b: a == b, BooleanType())
@@ -301,7 +291,7 @@ def test_broadcast_in_udf(self):
def test_udf_with_filter_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
- from pyspark.sql.functions import udf, col
+ from pyspark.sql.functions import col
from pyspark.sql.types import BooleanType
my_filter = udf(lambda a: a < 2, BooleanType())
@@ -310,7 +300,7 @@ def test_udf_with_filter_function(self):
def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
- from pyspark.sql.functions import udf, col, sum
+ from pyspark.sql.functions import col, sum
from pyspark.sql.types import BooleanType
my_filter = udf(lambda a: a == 1, BooleanType())
@@ -326,7 +316,7 @@ def test_udf_with_aggregate_function(self):
self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
def test_udf_in_generate(self):
- from pyspark.sql.functions import udf, explode
+ from pyspark.sql.functions import explode
df = self.spark.range(5)
f = udf(lambda x: list(range(x)), ArrayType(LongType()))
row = df.select(explode(f(*df))).groupBy().sum().first()
@@ -353,7 +343,6 @@ def test_udf_in_generate(self):
self.assertEqual(res[3][1], 1)
def test_udf_with_order_by_and_limit(self):
- from pyspark.sql.functions import udf
my_copy = udf(lambda x: x, IntegerType())
df = self.spark.range(10).orderBy("id")
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
@@ -394,14 +383,14 @@ def test_non_existed_udaf(self):
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))
def test_udf_with_input_file_name(self):
- from pyspark.sql.functions import udf, input_file_name
+ from pyspark.sql.functions import input_file_name
sourceFile = udf(lambda path: path, StringType())
filePath = "python/test_support/sql/people1.json"
row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first()
self.assertTrue(row[0].find("people1.json") != -1)
def test_udf_with_input_file_name_for_hadooprdd(self):
- from pyspark.sql.functions import udf, input_file_name
+ from pyspark.sql.functions import input_file_name
def filename(path):
return path
@@ -427,9 +416,6 @@ def test_udf_defers_judf_initialization(self):
# This is separate of UDFInitializationTests
# to avoid context initialization
# when udf is called
-
- from pyspark.sql.functions import UserDefinedFunction
-
f = UserDefinedFunction(lambda x: x, StringType())
self.assertIsNone(
@@ -445,8 +431,6 @@ def test_udf_defers_judf_initialization(self):
)
def test_udf_with_string_return_type(self):
- from pyspark.sql.functions import UserDefinedFunction
-
add_one = UserDefinedFunction(lambda x: x + 1, "integer")
make_pair = UserDefinedFunction(lambda x: (-x, x), "struct")
make_array = UserDefinedFunction(
@@ -460,13 +444,11 @@ def test_udf_with_string_return_type(self):
self.assertTupleEqual(expected, actual)
def test_udf_shouldnt_accept_noncallable_object(self):
- from pyspark.sql.functions import UserDefinedFunction
-
non_callable = None
self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType())
def test_udf_with_decorator(self):
- from pyspark.sql.functions import lit, udf
+ from pyspark.sql.functions import lit
from pyspark.sql.types import IntegerType, DoubleType
@udf(IntegerType())
@@ -523,7 +505,6 @@ def as_double(x):
)
def test_udf_wrapper(self):
- from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType
def f(x):
@@ -569,7 +550,7 @@ def test_nonparam_udf_with_aggregate(self):
# SPARK-24721
@unittest.skipIf(not test_compiled, test_not_compiled_message)
def test_datasource_with_udf(self):
- from pyspark.sql.functions import udf, lit, col
+ from pyspark.sql.functions import lit, col
path = tempfile.mkdtemp()
shutil.rmtree(path)
@@ -609,8 +590,6 @@ def test_datasource_with_udf(self):
# SPARK-25591
def test_same_accumulator_in_udfs(self):
- from pyspark.sql.functions import udf
-
data_schema = StructType([StructField("a", IntegerType(), True),
StructField("b", IntegerType(), True)])
data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
@@ -632,6 +611,15 @@ def second_udf(x):
data.collect()
self.assertEqual(test_accum.value, 101)
+ # SPARK-26293
+ def test_udf_in_subquery(self):
+ f = udf(lambda x: x, "long")
+ with self.tempView("v"):
+ self.spark.range(1).filter(f("id") >= 0).createTempView("v")
+ sql = self.spark.sql
+ result = sql("select i from values(0L) as data(i) where i in (select id from v)")
+ self.assertEqual(result.collect(), [Row(i=0)])
+
class UDFInitializationTests(unittest.TestCase):
def tearDown(self):
@@ -642,8 +630,6 @@ def tearDown(self):
SparkContext._active_spark_context.stop()
def test_udf_init_shouldnt_initialize_context(self):
- from pyspark.sql.functions import UserDefinedFunction
-
UserDefinedFunction(lambda x: x, StringType())
self.assertIsNone(
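Most of the test_udf.py changes above are mechanical: the per-test `from pyspark.sql.functions import udf` (and `UserDefinedFunction`) lines are dropped in favour of the single module-level import added at the top of the file. A small sketch of the resulting pattern; the session fixture and names below are illustrative, not part of the patch:

    from pyspark.sql.functions import udf
    from pyspark.sql.types import IntegerType

    def test_example(spark):
        # udf is available module-wide; no per-test import needed.
        double_it = udf(lambda x: x * 2, IntegerType())
        df = spark.range(3).select(double_it('id').alias('doubled'))
        assert [r.doubled for r in df.collect()] == [0, 2, 4]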
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 953b468e96519..bf007b0c62d8d 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -145,7 +145,18 @@ def wrapped(*series):
return lambda *a: (wrapped(*a), arrow_return_type)
-def wrap_window_agg_pandas_udf(f, return_type):
+def wrap_window_agg_pandas_udf(f, return_type, runner_conf, udf_index):
+ window_bound_types_str = runner_conf.get('pandas_window_bound_types')
+ window_bound_type = [t.strip().lower() for t in window_bound_types_str.split(',')][udf_index]
+ if window_bound_type == 'bounded':
+ return wrap_bounded_window_agg_pandas_udf(f, return_type)
+ elif window_bound_type == 'unbounded':
+ return wrap_unbounded_window_agg_pandas_udf(f, return_type)
+ else:
+ raise RuntimeError("Invalid window bound type: {} ".format(window_bound_type))
+
+
+def wrap_unbounded_window_agg_pandas_udf(f, return_type):
# This is similar to grouped_agg_pandas_udf, the only difference
# is that window_agg_pandas_udf needs to repeat the return value
# to match window length, where grouped_agg_pandas_udf just returns
@@ -160,7 +171,41 @@ def wrapped(*series):
return lambda *a: (wrapped(*a), arrow_return_type)
-def read_single_udf(pickleSer, infile, eval_type, runner_conf):
+def wrap_bounded_window_agg_pandas_udf(f, return_type):
+ arrow_return_type = to_arrow_type(return_type)
+
+ def wrapped(begin_index, end_index, *series):
+ import pandas as pd
+ result = []
+
+ # Index operations are faster on np.ndarray,
+ # so we turn the index series into np arrays
+ # here for performance.
+ begin_array = begin_index.values
+ end_array = end_index.values
+
+ for i in range(len(begin_array)):
+ # Note: Creating a slice from a series for each window is
+ # actually pretty expensive. However, there
+ # is no easy way to reduce the cost here.
+ # Note: s.iloc[i : j] is about 30% faster than s[i: j], with
+ # the caveat that the created slices share the same
+ # memory with s. Therefore, users are not allowed to
+ # change the value of the input series inside the window
+ # function. It is rare that a user needs to modify the
+ # input series in the window function, so this is
+ # a reasonable restriction.
+ # Note: Calling reset_index on the slices will increase the cost
+ # of creating slices by about 100%. Therefore, for performance
+ # reasons we don't do it here.
+ series_slices = [s.iloc[begin_array[i]: end_array[i]] for s in series]
+ result.append(f(*series_slices))
+ return pd.Series(result)
+
+ return lambda *a: (wrapped(*a), arrow_return_type)
+
+
+def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
num_arg = read_int(infile)
arg_offsets = [read_int(infile) for i in range(num_arg)]
row_func = None
@@ -184,7 +229,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf):
elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
- return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
+ return arg_offsets, wrap_window_agg_pandas_udf(func, return_type, runner_conf, udf_index)
elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
return arg_offsets, wrap_udf(func, return_type)
else:
@@ -226,7 +271,8 @@ def read_udfs(pickleSer, infile, eval_type):
# See FlatMapGroupsInPandasExec for how arg_offsets are used to
# distinguish between grouping attributes and data attributes
- arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf)
+ arg_offsets, udf = read_single_udf(
+ pickleSer, infile, eval_type, runner_conf, udf_index=0)
udfs['f'] = udf
split_offset = arg_offsets[0] + 1
arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
@@ -238,7 +284,8 @@ def read_udfs(pickleSer, infile, eval_type):
# In the special case of a single UDF this will return a single result rather
# than a tuple of results; this is the format that the JVM side expects.
for i in range(num_udfs):
- arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf)
+ arg_offsets, udf = read_single_udf(
+ pickleSer, infile, eval_type, runner_conf, udf_index=i)
udfs['f%d' % i] = udf
args = ["a[%d]" % o for o in arg_offsets]
call_udf.append("f%d(%s)" % (i, ", ".join(args)))
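The bounded-window wrapper added to worker.py above reduces to a simple per-row loop: for output row i, slice each input series to [begin, end) and apply the aggregating function. A standalone sketch of that loop, outside the worker's serialization protocol (function and variable names are illustrative):

    import pandas as pd

    def bounded_window_agg(f, begin_index, end_index, *series):
        # Convert the index series to np.ndarray once, as in the wrapper above.
        begin_array = begin_index.values
        end_array = end_index.values
        result = []
        for i in range(len(begin_array)):
            # Slice every input series to the window of row i and apply f.
            slices = [s.iloc[begin_array[i]:end_array[i]] for s in series]
            result.append(f(*slices))
        return pd.Series(result)

    # Example: row i sees rows [i, min(i + 2, 4)).
    v = pd.Series([1.0, 2.0, 3.0, 4.0])
    begin, end = pd.Series([0, 1, 2, 3]), pd.Series([2, 3, 4, 4])
    print(bounded_window_agg(lambda s: s.mean(), begin, end, v))  # 1.5, 2.5, 3.5, 4.0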
diff --git a/python/run-tests-with-coverage b/python/run-tests-with-coverage
index 6d74b563e9140..457821037d43c 100755
--- a/python/run-tests-with-coverage
+++ b/python/run-tests-with-coverage
@@ -50,8 +50,6 @@ export SPARK_CONF_DIR="$COVERAGE_DIR/conf"
# This environment variable enables the coverage.
export COVERAGE_PROCESS_START="$FWDIR/.coveragerc"
-# If you'd like to run a specific unittest class, you could do such as
-# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests
./run-tests "$@"
# Don't run coverage for the coverage command itself
diff --git a/python/run-tests.py b/python/run-tests.py
index 01a6e81264dd6..e45268c13769a 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -19,7 +19,7 @@
from __future__ import print_function
import logging
-from optparse import OptionParser
+from optparse import OptionParser, OptionGroup
import os
import re
import shutil
@@ -99,7 +99,7 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
try:
per_test_output = tempfile.TemporaryFile()
retcode = subprocess.Popen(
- [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
+ [os.path.join(SPARK_HOME, "bin/pyspark")] + test_name.split(),
stderr=per_test_output, stdout=per_test_output, env=env).wait()
shutil.rmtree(tmp_dir, ignore_errors=True)
except:
@@ -190,6 +190,20 @@ def parse_opts():
help="Enable additional debug logging"
)
+ group = OptionGroup(parser, "Developer Options")
+ group.add_option(
+ "--testnames", type="string",
+ default=None,
+ help=(
+ "A comma-separated list of specific modules, classes and functions of doctest "
+ "or unittest to test. "
+ "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, "
+ "'pyspark.sql.tests FooTests' to run the specific class of unittests, "
+ "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. "
+ "'--modules' option is ignored if this option is given.")
+ )
+ parser.add_option_group(group)
+
(opts, args) = parser.parse_args()
if args:
parser.error("Unsupported arguments: %s" % ' '.join(args))
@@ -213,25 +227,31 @@ def _check_coverage(python_exec):
def main():
opts = parse_opts()
- if (opts.verbose):
+ if opts.verbose:
log_level = logging.DEBUG
else:
log_level = logging.INFO
+ should_test_modules = opts.testnames is None
logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
if os.path.exists(LOG_FILE):
os.remove(LOG_FILE)
python_execs = opts.python_executables.split(',')
- modules_to_test = []
- for module_name in opts.modules.split(','):
- if module_name in python_modules:
- modules_to_test.append(python_modules[module_name])
- else:
- print("Error: unrecognized module '%s'. Supported modules: %s" %
- (module_name, ", ".join(python_modules)))
- sys.exit(-1)
LOGGER.info("Will test against the following Python executables: %s", python_execs)
- LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
+
+ if should_test_modules:
+ modules_to_test = []
+ for module_name in opts.modules.split(','):
+ if module_name in python_modules:
+ modules_to_test.append(python_modules[module_name])
+ else:
+ print("Error: unrecognized module '%s'. Supported modules: %s" %
+ (module_name, ", ".join(python_modules)))
+ sys.exit(-1)
+ LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
+ else:
+ testnames_to_test = opts.testnames.split(',')
+ LOGGER.info("Will test the following Python tests: %s", testnames_to_test)
task_queue = Queue.PriorityQueue()
for python_exec in python_execs:
@@ -246,16 +266,20 @@ def main():
LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation)
LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output(
[python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
- for module in modules_to_test:
- if python_implementation not in module.blacklisted_python_implementations:
- for test_goal in module.python_test_goals:
- heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
- 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests']
- if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
- priority = 0
- else:
- priority = 100
- task_queue.put((priority, (python_exec, test_goal)))
+ if should_test_modules:
+ for module in modules_to_test:
+ if python_implementation not in module.blacklisted_python_implementations:
+ for test_goal in module.python_test_goals:
+ heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
+ 'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests']
+ if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
+ priority = 0
+ else:
+ priority = 100
+ task_queue.put((priority, (python_exec, test_goal)))
+ else:
+ for test_goal in testnames_to_test:
+ task_queue.put((0, (python_exec, test_goal)))
# Create the target directory before starting tasks to avoid races.
target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))
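To summarize the run-tests.py changes above: when --testnames is passed, the comma-separated entries bypass module resolution and are queued directly at the highest priority, and each entry is split on whitespace so that a 'module Class.test' goal becomes separate arguments to bin/pyspark. A small sketch of that flow, reusing the example names from the option's help text:

    try:
        import Queue  # Python 2 name, as used by run-tests.py
    except ImportError:
        import queue as Queue

    testnames = 'pyspark.sql.foo,pyspark.sql.tests FooTests.test_foo'
    task_queue = Queue.PriorityQueue()
    for test_goal in testnames.split(','):
        # Explicitly requested tests always get the highest priority (0).
        task_queue.put((0, ('python', test_goal)))

    while not task_queue.empty():
        _, (python_exec, test_goal) = task_queue.get()
        # 'pyspark.sql.tests FooTests.test_foo' -> two argv entries after bin/pyspark.
        argv = ['bin/pyspark'] + test_goal.split()
        print(python_exec, argv)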
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
index 85917b88e912a..76041e7de5182 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Constants.scala
@@ -87,25 +87,22 @@ private[spark] object Constants {
val NON_JVM_MEMORY_OVERHEAD_FACTOR = 0.4d
// Hadoop Configuration
- val HADOOP_FILE_VOLUME = "hadoop-properties"
+ val HADOOP_CONF_VOLUME = "hadoop-properties"
val KRB_FILE_VOLUME = "krb5-file"
val HADOOP_CONF_DIR_PATH = "/opt/hadoop/conf"
val KRB_FILE_DIR_PATH = "/etc"
val ENV_HADOOP_CONF_DIR = "HADOOP_CONF_DIR"
val HADOOP_CONFIG_MAP_NAME =
"spark.kubernetes.executor.hadoopConfigMapName"
- val KRB5_CONFIG_MAP_NAME =
- "spark.kubernetes.executor.krb5ConfigMapName"
// Kerberos Configuration
- val KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME = "delegation-tokens"
val KERBEROS_DT_SECRET_NAME =
"spark.kubernetes.kerberos.dt-secret-name"
val KERBEROS_DT_SECRET_KEY =
"spark.kubernetes.kerberos.dt-secret-key"
- val KERBEROS_SPARK_USER_NAME =
- "spark.kubernetes.kerberos.spark-user-name"
val KERBEROS_SECRET_KEY = "hadoop-tokens"
+ val KERBEROS_KEYTAB_VOLUME = "kerberos-keytab"
+ val KERBEROS_KEYTAB_MOUNT_POINT = "/mnt/secrets/kerberos-keytab"
// Hadoop credentials secrets for the Spark app.
val SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR = "/mnt/secrets/hadoop-credentials"
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala
index a06c21b47f15e..6febad981af56 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala
@@ -42,10 +42,6 @@ private[spark] abstract class KubernetesConf(val sparkConf: SparkConf) {
def appName: String = get("spark.app.name", "spark")
- def hadoopConfigMapName: String = s"$resourceNamePrefix-hadoop-config"
-
- def krbConfigMapName: String = s"$resourceNamePrefix-krb5-file"
-
def namespace: String = get(KUBERNETES_NAMESPACE)
def imagePullPolicy: String = get(CONTAINER_IMAGE_PULL_POLICY)
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala
index 345dd117fd35f..fd1196368a7ff 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala
@@ -18,7 +18,30 @@ package org.apache.spark.deploy.k8s
import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder}
-private[spark] case class SparkPod(pod: Pod, container: Container)
+private[spark] case class SparkPod(pod: Pod, container: Container) {
+
+ /**
+ * Convenience method to apply a series of chained transformations to a pod.
+ *
+ * Use it like:
+ *
+ * original.transform { case pod =>
+ * // update pod and return new one
+ * }.transform { case pod =>
+ * // more changes that create a new pod
+ * }.transform {
+ * case pod if someCondition => // new pod
+ * }
+ *
+ * This makes it cleaner to apply multiple transformations, avoiding having to create
+ * a bunch of awkwardly-named local variables. Since the argument is a partial function,
+ * it can do matching without needing to exhaust all the possibilities. If the function
+ * is not applied, then the original pod will be kept.
+ */
+ def transform(fn: PartialFunction[SparkPod, SparkPod]): SparkPod = fn.lift(this).getOrElse(this)
+
+}
+
private[spark] object SparkPod {
def initialPod(): SparkPod = {
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala
index d8cf3653d3226..8362c14fb289d 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala
@@ -110,6 +110,10 @@ private[spark] class BasicDriverFeatureStep(conf: KubernetesDriverConf)
.withContainerPort(driverUIPort)
.withProtocol("TCP")
.endPort()
+ .addNewEnv()
+ .withName(ENV_SPARK_USER)
+ .withValue(Utils.getCurrentUserName())
+ .endEnv()
.addAllToEnv(driverCustomEnvs.asJava)
.addNewEnv()
.withName(ENV_DRIVER_BIND_ADDRESS)
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
index 8bf315248388f..c8bf7cdb4224f 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala
@@ -20,20 +20,21 @@ import scala.collection.JavaConverters._
import io.fabric8.kubernetes.api.model._
-import org.apache.spark.SparkException
+import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.k8s._
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD, PYSPARK_EXECUTOR_MEMORY}
+import org.apache.spark.internal.config._
import org.apache.spark.rpc.RpcEndpointAddress
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
import org.apache.spark.util.Utils
-private[spark] class BasicExecutorFeatureStep(kubernetesConf: KubernetesExecutorConf)
+private[spark] class BasicExecutorFeatureStep(
+ kubernetesConf: KubernetesExecutorConf,
+ secMgr: SecurityManager)
extends KubernetesFeatureConfigStep {
// Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf
- private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH)
private val executorContainerImage = kubernetesConf
.get(EXECUTOR_CONTAINER_IMAGE)
.getOrElse(throw new SparkException("Must specify the executor container image"))
@@ -87,44 +88,63 @@ private[spark] class BasicExecutorFeatureStep(kubernetesConf: KubernetesExecutor
val executorCpuQuantity = new QuantityBuilder(false)
.withAmount(executorCoresRequest)
.build()
- val executorExtraClasspathEnv = executorExtraClasspath.map { cp =>
- new EnvVarBuilder()
- .withName(ENV_CLASSPATH)
- .withValue(cp)
- .build()
- }
- val executorExtraJavaOptionsEnv = kubernetesConf
- .get(EXECUTOR_JAVA_OPTIONS)
- .map { opts =>
- val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId,
- kubernetesConf.executorId)
- val delimitedOpts = Utils.splitCommandString(subsOpts)
- delimitedOpts.zipWithIndex.map {
- case (opt, index) =>
- new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build()
+
+ val executorEnv: Seq[EnvVar] = {
+ (Seq(
+ (ENV_DRIVER_URL, driverUrl),
+ (ENV_EXECUTOR_CORES, executorCores.toString),
+ (ENV_EXECUTOR_MEMORY, executorMemoryString),
+ (ENV_APPLICATION_ID, kubernetesConf.appId),
+ // This is to set the SPARK_CONF_DIR to be /opt/spark/conf
+ (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL),
+ (ENV_EXECUTOR_ID, kubernetesConf.executorId)
+ ) ++ kubernetesConf.environment).map { case (k, v) =>
+ new EnvVarBuilder()
+ .withName(k)
+ .withValue(v)
+ .build()
}
- }.getOrElse(Seq.empty[EnvVar])
- val executorEnv = (Seq(
- (ENV_DRIVER_URL, driverUrl),
- (ENV_EXECUTOR_CORES, executorCores.toString),
- (ENV_EXECUTOR_MEMORY, executorMemoryString),
- (ENV_APPLICATION_ID, kubernetesConf.appId),
- // This is to set the SPARK_CONF_DIR to be /opt/spark/conf
- (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL),
- (ENV_EXECUTOR_ID, kubernetesConf.executorId)) ++
- kubernetesConf.environment)
- .map(env => new EnvVarBuilder()
- .withName(env._1)
- .withValue(env._2)
- .build()
- ) ++ Seq(
- new EnvVarBuilder()
- .withName(ENV_EXECUTOR_POD_IP)
- .withValueFrom(new EnvVarSourceBuilder()
- .withNewFieldRef("v1", "status.podIP")
+ } ++ {
+ Seq(new EnvVarBuilder()
+ .withName(ENV_EXECUTOR_POD_IP)
+ .withValueFrom(new EnvVarSourceBuilder()
+ .withNewFieldRef("v1", "status.podIP")
+ .build())
.build())
- .build()
- ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq
+ } ++ {
+ if (kubernetesConf.get(AUTH_SECRET_FILE_EXECUTOR).isEmpty) {
+ Option(secMgr.getSecretKey()).map { authSecret =>
+ new EnvVarBuilder()
+ .withName(SecurityManager.ENV_AUTH_SECRET)
+ .withValue(authSecret)
+ .build()
+ }
+ } else None
+ } ++ {
+ kubernetesConf.get(EXECUTOR_CLASS_PATH).map { cp =>
+ new EnvVarBuilder()
+ .withName(ENV_CLASSPATH)
+ .withValue(cp)
+ .build()
+ }
+ } ++ {
+ val userOpts = kubernetesConf.get(EXECUTOR_JAVA_OPTIONS).toSeq.flatMap { opts =>
+ val subsOpts = Utils.substituteAppNExecIds(opts, kubernetesConf.appId,
+ kubernetesConf.executorId)
+ Utils.splitCommandString(subsOpts)
+ }
+
+ val sparkOpts = Utils.sparkJavaOpts(kubernetesConf.sparkConf,
+ SparkConf.isExecutorStartupConf)
+
+ (userOpts ++ sparkOpts).zipWithIndex.map { case (opt, index) =>
+ new EnvVarBuilder()
+ .withName(s"$ENV_JAVA_OPT_PREFIX$index")
+ .withValue(opt)
+ .build()
+ }
+ }
+
val requiredPorts = Seq(
(BLOCK_MANAGER_PORT_NAME, blockManagerPort))
.map { case (name, port) =>
@@ -143,6 +163,10 @@ private[spark] class BasicExecutorFeatureStep(kubernetesConf: KubernetesExecutor
.addToLimits("memory", executorMemoryQuantity)
.addToRequests("cpu", executorCpuQuantity)
.endResources()
+ .addNewEnv()
+ .withName(ENV_SPARK_USER)
+ .withValue(Utils.getCurrentUserName())
+ .endEnv()
.addAllToEnv(executorEnv.asJava)
.withPorts(requiredPorts.asJava)
.addToArgs("executor")
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala
new file mode 100644
index 0000000000000..d602ed5481e65
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStep.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import scala.collection.JavaConverters._
+
+import com.google.common.io.Files
+import io.fabric8.kubernetes.api.model._
+
+import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesUtils, SparkPod}
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+
+/**
+ * Mounts the Hadoop configuration - either a pre-defined config map, or a local configuration
+ * directory - on the driver pod.
+ */
+private[spark] class HadoopConfDriverFeatureStep(conf: KubernetesConf)
+ extends KubernetesFeatureConfigStep {
+
+ private val confDir = Option(conf.sparkConf.getenv(ENV_HADOOP_CONF_DIR))
+ private val existingConfMap = conf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP)
+
+ KubernetesUtils.requireNandDefined(
+ confDir,
+ existingConfMap,
+ "Do not specify both the `HADOOP_CONF_DIR` in your ENV and the ConfigMap, " +
+ "as creating an additional ConfigMap when one is already specified is extraneous")
+
+ private lazy val confFiles: Seq[File] = {
+ val dir = new File(confDir.get)
+ if (dir.isDirectory) {
+ dir.listFiles.filter(_.isFile).toSeq
+ } else {
+ Nil
+ }
+ }
+
+ private def newConfigMapName: String = s"${conf.resourceNamePrefix}-hadoop-config"
+
+ private def hasHadoopConf: Boolean = confDir.isDefined || existingConfMap.isDefined
+
+ override def configurePod(original: SparkPod): SparkPod = {
+ original.transform { case pod if hasHadoopConf =>
+ val confVolume = if (confDir.isDefined) {
+ val keyPaths = confFiles.map { file =>
+ new KeyToPathBuilder()
+ .withKey(file.getName())
+ .withPath(file.getName())
+ .build()
+ }
+ new VolumeBuilder()
+ .withName(HADOOP_CONF_VOLUME)
+ .withNewConfigMap()
+ .withName(newConfigMapName)
+ .withItems(keyPaths.asJava)
+ .endConfigMap()
+ .build()
+ } else {
+ new VolumeBuilder()
+ .withName(HADOOP_CONF_VOLUME)
+ .withNewConfigMap()
+ .withName(existingConfMap.get)
+ .endConfigMap()
+ .build()
+ }
+
+ val podWithConf = new PodBuilder(pod.pod)
+ .editSpec()
+ .addNewVolumeLike(confVolume)
+ .endVolume()
+ .endSpec()
+ .build()
+
+ val containerWithMount = new ContainerBuilder(pod.container)
+ .addNewVolumeMount()
+ .withName(HADOOP_CONF_VOLUME)
+ .withMountPath(HADOOP_CONF_DIR_PATH)
+ .endVolumeMount()
+ .addNewEnv()
+ .withName(ENV_HADOOP_CONF_DIR)
+ .withValue(HADOOP_CONF_DIR_PATH)
+ .endEnv()
+ .build()
+
+ SparkPod(podWithConf, containerWithMount)
+ }
+ }
+
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = {
+ if (confDir.isDefined) {
+ val fileMap = confFiles.map { file =>
+ (file.getName(), Files.toString(file, StandardCharsets.UTF_8))
+ }.toMap.asJava
+
+ Seq(new ConfigMapBuilder()
+ .withNewMetadata()
+ .withName(newConfigMapName)
+ .endMetadata()
+ .addToData(fileMap)
+ .build())
+ } else {
+ Nil
+ }
+ }
+
+}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala
deleted file mode 100644
index bca66759d586e..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopConfExecutorFeatureStep.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.features
-
-import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod}
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil
-import org.apache.spark.internal.Logging
-
-/**
- * This step is responsible for bootstraping the container with ConfigMaps
- * containing Hadoop config files mounted as volumes and an ENV variable
- * pointed to the mounted file directory.
- */
-private[spark] class HadoopConfExecutorFeatureStep(conf: KubernetesExecutorConf)
- extends KubernetesFeatureConfigStep with Logging {
-
- override def configurePod(pod: SparkPod): SparkPod = {
- val hadoopConfDirCMapName = conf.getOption(HADOOP_CONFIG_MAP_NAME)
- require(hadoopConfDirCMapName.isDefined,
- "Ensure that the env `HADOOP_CONF_DIR` is defined either in the client or " +
- " using pre-existing ConfigMaps")
- logInfo("HADOOP_CONF_DIR defined")
- HadoopBootstrapUtil.bootstrapHadoopConfDir(None, None, hadoopConfDirCMapName, pod)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala
index c6d5a866fa7bc..721d7e97b21f8 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStep.scala
@@ -16,31 +16,40 @@
*/
package org.apache.spark.deploy.k8s.features
-import io.fabric8.kubernetes.api.model.{HasMetadata, Secret, SecretBuilder}
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import scala.collection.JavaConverters._
+
+import com.google.common.io.Files
+import io.fabric8.kubernetes.api.model._
import org.apache.commons.codec.binary.Base64
-import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.k8s.{KubernetesDriverConf, KubernetesUtils, SparkPod}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.features.hadooputils._
import org.apache.spark.deploy.security.HadoopDelegationTokenManager
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
/**
- * Runs the necessary Hadoop-based logic based on Kerberos configs and the presence of the
- * HADOOP_CONF_DIR. This runs various bootstrap methods defined in HadoopBootstrapUtil.
+ * Provide kerberos / service credentials to the Spark driver.
+ *
+ * There are three use cases, in order of precedence:
+ *
+ * - keytab: if a kerberos keytab is defined, it is provided to the driver, and the driver will
+ * manage the kerberos login and the creation of delegation tokens.
+ * - existing tokens: if a secret containing delegation tokens is provided, it will be mounted
+ * on the driver pod, and the driver will handle distribution of those tokens to executors.
+ * - tgt only: if Hadoop security is enabled, the local TGT will be used to create delegation
+ * tokens which will be provided to the driver. The driver will handle distribution of the
+ * tokens to executors.
*/
private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDriverConf)
- extends KubernetesFeatureConfigStep {
-
- private val hadoopConfDir = Option(kubernetesConf.sparkConf.getenv(ENV_HADOOP_CONF_DIR))
- private val hadoopConfigMapName = kubernetesConf.get(KUBERNETES_HADOOP_CONF_CONFIG_MAP)
- KubernetesUtils.requireNandDefined(
- hadoopConfDir,
- hadoopConfigMapName,
- "Do not specify both the `HADOOP_CONF_DIR` in your ENV and the ConfigMap " +
- "as the creation of an additional ConfigMap, when one is already specified is extraneous")
+ extends KubernetesFeatureConfigStep with Logging {
private val principal = kubernetesConf.get(org.apache.spark.internal.config.PRINCIPAL)
private val keytab = kubernetesConf.get(org.apache.spark.internal.config.KEYTAB)
@@ -49,15 +58,6 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri
private val krb5File = kubernetesConf.get(KUBERNETES_KERBEROS_KRB5_FILE)
private val krb5CMap = kubernetesConf.get(KUBERNETES_KERBEROS_KRB5_CONFIG_MAP)
private val hadoopConf = SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf)
- private val tokenManager = new HadoopDelegationTokenManager(kubernetesConf.sparkConf, hadoopConf)
- private val isKerberosEnabled =
- (hadoopConfDir.isDefined && UserGroupInformation.isSecurityEnabled) ||
- (hadoopConfigMapName.isDefined && (krb5File.isDefined || krb5CMap.isDefined))
- require(keytab.isEmpty || isKerberosEnabled,
- "You must enable Kerberos support if you are specifying a Kerberos Keytab")
-
- require(existingSecretName.isEmpty || isKerberosEnabled,
- "You must enable Kerberos support if you are specifying a Kerberos Secret")
KubernetesUtils.requireNandDefined(
krb5File,
@@ -79,128 +79,183 @@ private[spark] class KerberosConfDriverFeatureStep(kubernetesConf: KubernetesDri
"If a secret storing a Kerberos Delegation Token is specified you must also" +
" specify the item-key where the data is stored")
- private val hadoopConfigurationFiles = hadoopConfDir.map { hConfDir =>
- HadoopBootstrapUtil.getHadoopConfFiles(hConfDir)
+ if (!hasKerberosConf) {
+ logInfo("You have not specified a krb5.conf file locally or via a ConfigMap. " +
+ "Make sure that you have the krb5.conf locally on the driver image.")
}
- private val newHadoopConfigMapName =
- if (hadoopConfigMapName.isEmpty) {
- Some(kubernetesConf.hadoopConfigMapName)
- } else {
- None
- }
- // Either use pre-existing secret or login to create new Secret with DT stored within
- private val kerberosConfSpec: Option[KerberosConfigSpec] = (for {
- secretName <- existingSecretName
- secretItemKey <- existingSecretItemKey
- } yield {
- KerberosConfigSpec(
- dtSecret = None,
- dtSecretName = secretName,
- dtSecretItemKey = secretItemKey,
- jobUserName = UserGroupInformation.getCurrentUser.getShortUserName)
- }).orElse(
- if (isKerberosEnabled) {
- Some(buildKerberosSpec())
+ // Create delegation tokens if needed. This is a lazy val so that it's not populated
+ // unnecessarily. But it needs to be accessible to different methods in this class,
+ // since it's not clear based solely on available configuration options that delegation
+ // tokens are needed when other credentials are not available.
+ private lazy val delegationTokens: Array[Byte] = {
+ if (keytab.isEmpty && existingSecretName.isEmpty) {
+ val tokenManager = new HadoopDelegationTokenManager(kubernetesConf.sparkConf,
+ SparkHadoopUtil.get.newConfiguration(kubernetesConf.sparkConf))
+ val creds = UserGroupInformation.getCurrentUser().getCredentials()
+ tokenManager.obtainDelegationTokens(creds)
+ // If no tokens and no secrets are stored in the credentials, make sure nothing is returned,
+ // to avoid creating an unnecessary secret.
+ if (creds.numberOfTokens() > 0 || creds.numberOfSecretKeys() > 0) {
+ SparkHadoopUtil.get.serialize(creds)
+ } else {
+ null
+ }
} else {
- None
+ null
}
- )
+ }
- override def configurePod(pod: SparkPod): SparkPod = {
- if (!isKerberosEnabled) {
- return pod
- }
+ private def needKeytabUpload: Boolean = keytab.exists(!Utils.isLocalUri(_))
- val hadoopBasedSparkPod = HadoopBootstrapUtil.bootstrapHadoopConfDir(
- hadoopConfDir,
- newHadoopConfigMapName,
- hadoopConfigMapName,
- pod)
- kerberosConfSpec.map { hSpec =>
- HadoopBootstrapUtil.bootstrapKerberosPod(
- hSpec.dtSecretName,
- hSpec.dtSecretItemKey,
- hSpec.jobUserName,
- krb5File,
- Some(kubernetesConf.krbConfigMapName),
- krb5CMap,
- hadoopBasedSparkPod)
- }.getOrElse(
- HadoopBootstrapUtil.bootstrapSparkUserPod(
- UserGroupInformation.getCurrentUser.getShortUserName,
- hadoopBasedSparkPod))
- }
+ private def dtSecretName: String = s"${kubernetesConf.resourceNamePrefix}-delegation-tokens"
- override def getAdditionalPodSystemProperties(): Map[String, String] = {
- if (!isKerberosEnabled) {
- return Map.empty
- }
+ private def ktSecretName: String = s"${kubernetesConf.resourceNamePrefix}-kerberos-keytab"
- val resolvedConfValues = kerberosConfSpec.map { hSpec =>
- Map(KERBEROS_DT_SECRET_NAME -> hSpec.dtSecretName,
- KERBEROS_DT_SECRET_KEY -> hSpec.dtSecretItemKey,
- KERBEROS_SPARK_USER_NAME -> hSpec.jobUserName,
- KRB5_CONFIG_MAP_NAME -> krb5CMap.getOrElse(kubernetesConf.krbConfigMapName))
- }.getOrElse(
- Map(KERBEROS_SPARK_USER_NAME ->
- UserGroupInformation.getCurrentUser.getShortUserName))
- Map(HADOOP_CONFIG_MAP_NAME ->
- hadoopConfigMapName.getOrElse(kubernetesConf.hadoopConfigMapName)) ++ resolvedConfValues
- }
+ private def hasKerberosConf: Boolean = krb5CMap.isDefined || krb5File.isDefined
- override def getAdditionalKubernetesResources(): Seq[HasMetadata] = {
- if (!isKerberosEnabled) {
- return Seq.empty
- }
+ private def newConfigMapName: String = s"${kubernetesConf.resourceNamePrefix}-krb5-file"
- val hadoopConfConfigMap = for {
- hName <- newHadoopConfigMapName
- hFiles <- hadoopConfigurationFiles
- } yield {
- HadoopBootstrapUtil.buildHadoopConfigMap(hName, hFiles)
- }
+ override def configurePod(original: SparkPod): SparkPod = {
+ original.transform { case pod if hasKerberosConf =>
+ val configMapVolume = if (krb5CMap.isDefined) {
+ new VolumeBuilder()
+ .withName(KRB_FILE_VOLUME)
+ .withNewConfigMap()
+ .withName(krb5CMap.get)
+ .endConfigMap()
+ .build()
+ } else {
+ val krb5Conf = new File(krb5File.get)
+ new VolumeBuilder()
+ .withName(KRB_FILE_VOLUME)
+ .withNewConfigMap()
+ .withName(newConfigMapName)
+ .withItems(new KeyToPathBuilder()
+ .withKey(krb5Conf.getName())
+ .withPath(krb5Conf.getName())
+ .build())
+ .endConfigMap()
+ .build()
+ }
- val krb5ConfigMap = krb5File.map { fileLocation =>
- HadoopBootstrapUtil.buildkrb5ConfigMap(
- kubernetesConf.krbConfigMapName,
- fileLocation)
- }
+ val podWithVolume = new PodBuilder(pod.pod)
+ .editSpec()
+ .addNewVolumeLike(configMapVolume)
+ .endVolume()
+ .endSpec()
+ .build()
+
+ val containerWithMount = new ContainerBuilder(pod.container)
+ .addNewVolumeMount()
+ .withName(KRB_FILE_VOLUME)
+ .withMountPath(KRB_FILE_DIR_PATH + "/krb5.conf")
+ .withSubPath("krb5.conf")
+ .endVolumeMount()
+ .build()
+
+ SparkPod(podWithVolume, containerWithMount)
+ }.transform {
+ case pod if needKeytabUpload =>
+ // If keytab is defined and is a submission-local file (not local: URI), then create a
+ // secret for it. The keytab data will be stored in this secret below.
+ val podWithKeytab = new PodBuilder(pod.pod)
+ .editOrNewSpec()
+ .addNewVolume()
+ .withName(KERBEROS_KEYTAB_VOLUME)
+ .withNewSecret()
+ .withSecretName(ktSecretName)
+ .endSecret()
+ .endVolume()
+ .endSpec()
+ .build()
+
+ val containerWithKeytab = new ContainerBuilder(pod.container)
+ .addNewVolumeMount()
+ .withName(KERBEROS_KEYTAB_VOLUME)
+ .withMountPath(KERBEROS_KEYTAB_MOUNT_POINT)
+ .endVolumeMount()
+ .build()
+
+ SparkPod(podWithKeytab, containerWithKeytab)
+
+ case pod if existingSecretName.isDefined || delegationTokens != null =>
+ val secretName = existingSecretName.getOrElse(dtSecretName)
+ val itemKey = existingSecretItemKey.getOrElse(KERBEROS_SECRET_KEY)
+
+ val podWithTokens = new PodBuilder(pod.pod)
+ .editOrNewSpec()
+ .addNewVolume()
+ .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME)
+ .withNewSecret()
+ .withSecretName(secretName)
+ .endSecret()
+ .endVolume()
+ .endSpec()
+ .build()
- val kerberosDTSecret = kerberosConfSpec.flatMap(_.dtSecret)
+ val containerWithTokens = new ContainerBuilder(pod.container)
+ .addNewVolumeMount()
+ .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME)
+ .withMountPath(SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR)
+ .endVolumeMount()
+ .addNewEnv()
+ .withName(ENV_HADOOP_TOKEN_FILE_LOCATION)
+ .withValue(s"$SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR/$itemKey")
+ .endEnv()
+ .build()
- hadoopConfConfigMap.toSeq ++
- krb5ConfigMap.toSeq ++
- kerberosDTSecret.toSeq
+ SparkPod(podWithTokens, containerWithTokens)
+ }
}
- private def buildKerberosSpec(): KerberosConfigSpec = {
- // The JobUserUGI will be taken fom the Local Ticket Cache or via keytab+principal
- // The login happens in the SparkSubmit so login logic is not necessary to include
- val jobUserUGI = UserGroupInformation.getCurrentUser
- val creds = jobUserUGI.getCredentials
- tokenManager.obtainDelegationTokens(creds)
- val tokenData = SparkHadoopUtil.get.serialize(creds)
- require(tokenData.nonEmpty, "Did not obtain any delegation tokens")
- val newSecretName =
- s"${kubernetesConf.resourceNamePrefix}-$KERBEROS_DELEGEGATION_TOKEN_SECRET_NAME"
- val secretDT =
- new SecretBuilder()
- .withNewMetadata()
- .withName(newSecretName)
- .endMetadata()
- .addToData(KERBEROS_SECRET_KEY, Base64.encodeBase64String(tokenData))
- .build()
- KerberosConfigSpec(
- dtSecret = Some(secretDT),
- dtSecretName = newSecretName,
- dtSecretItemKey = KERBEROS_SECRET_KEY,
- jobUserName = jobUserUGI.getShortUserName)
+ override def getAdditionalPodSystemProperties(): Map[String, String] = {
+ // If a submission-local keytab is provided, update the Spark config so that it knows the
+ // path of the keytab in the driver container.
+ if (needKeytabUpload) {
+ val ktName = new File(keytab.get).getName()
+ Map(KEYTAB.key -> s"$KERBEROS_KEYTAB_MOUNT_POINT/$ktName")
+ } else {
+ Map.empty
+ }
}
- private case class KerberosConfigSpec(
- dtSecret: Option[Secret],
- dtSecretName: String,
- dtSecretItemKey: String,
- jobUserName: String)
+ override def getAdditionalKubernetesResources(): Seq[HasMetadata] = {
+ Seq[HasMetadata]() ++ {
+ krb5File.map { path =>
+ val file = new File(path)
+ new ConfigMapBuilder()
+ .withNewMetadata()
+ .withName(newConfigMapName)
+ .endMetadata()
+ .addToData(
+ Map(file.getName() -> Files.toString(file, StandardCharsets.UTF_8)).asJava)
+ .build()
+ }
+ } ++ {
+ // If a submission-local keytab is provided, stash it in a secret.
+ if (needKeytabUpload) {
+ val kt = new File(keytab.get)
+ Seq(new SecretBuilder()
+ .withNewMetadata()
+ .withName(ktSecretName)
+ .endMetadata()
+ .addToData(kt.getName(), Base64.encodeBase64String(Files.toByteArray(kt)))
+ .build())
+ } else {
+ Nil
+ }
+ } ++ {
+ if (delegationTokens != null) {
+ Seq(new SecretBuilder()
+ .withNewMetadata()
+ .withName(dtSecretName)
+ .endMetadata()
+ .addToData(KERBEROS_SECRET_KEY, Base64.encodeBase64String(delegationTokens))
+ .build())
+ } else {
+ Nil
+ }
+ }
+ }
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala
deleted file mode 100644
index 32bb6a5d2bcbb..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KerberosConfExecutorFeatureStep.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.features
-
-import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod}
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil
-import org.apache.spark.internal.Logging
-
-/**
- * This step is responsible for mounting the DT secret for the executors
- */
-private[spark] class KerberosConfExecutorFeatureStep(conf: KubernetesExecutorConf)
- extends KubernetesFeatureConfigStep with Logging {
-
- private val maybeKrb5CMap = conf.getOption(KRB5_CONFIG_MAP_NAME)
- require(maybeKrb5CMap.isDefined, "HADOOP_CONF_DIR ConfigMap not found")
-
- override def configurePod(pod: SparkPod): SparkPod = {
- logInfo(s"Mounting Resources for Kerberos")
- HadoopBootstrapUtil.bootstrapKerberosPod(
- conf.get(KERBEROS_DT_SECRET_NAME),
- conf.get(KERBEROS_DT_SECRET_KEY),
- conf.get(KERBEROS_SPARK_USER_NAME),
- None,
- None,
- maybeKrb5CMap,
- pod)
- }
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala
index 09dcf93a54f8e..7f41ca43589b6 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStep.scala
@@ -28,44 +28,60 @@ import org.apache.spark.deploy.k8s.Constants._
private[spark] class PodTemplateConfigMapStep(conf: KubernetesConf)
extends KubernetesFeatureConfigStep {
+
+ private val hasTemplate = conf.contains(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE)
+
def configurePod(pod: SparkPod): SparkPod = {
- val podWithVolume = new PodBuilder(pod.pod)
- .editSpec()
- .addNewVolume()
- .withName(POD_TEMPLATE_VOLUME)
- .withNewConfigMap()
- .withName(POD_TEMPLATE_CONFIGMAP)
- .addNewItem()
- .withKey(POD_TEMPLATE_KEY)
- .withPath(EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)
- .endItem()
- .endConfigMap()
- .endVolume()
- .endSpec()
- .build()
+ if (hasTemplate) {
+ val podWithVolume = new PodBuilder(pod.pod)
+ .editSpec()
+ .addNewVolume()
+ .withName(POD_TEMPLATE_VOLUME)
+ .withNewConfigMap()
+ .withName(POD_TEMPLATE_CONFIGMAP)
+ .addNewItem()
+ .withKey(POD_TEMPLATE_KEY)
+ .withPath(EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME)
+ .endItem()
+ .endConfigMap()
+ .endVolume()
+ .endSpec()
+ .build()
- val containerWithVolume = new ContainerBuilder(pod.container)
- .addNewVolumeMount()
- .withName(POD_TEMPLATE_VOLUME)
- .withMountPath(EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH)
- .endVolumeMount()
- .build()
- SparkPod(podWithVolume, containerWithVolume)
+ val containerWithVolume = new ContainerBuilder(pod.container)
+ .addNewVolumeMount()
+ .withName(POD_TEMPLATE_VOLUME)
+ .withMountPath(EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH)
+ .endVolumeMount()
+ .build()
+ SparkPod(podWithVolume, containerWithVolume)
+ } else {
+ pod
+ }
}
- override def getAdditionalPodSystemProperties(): Map[String, String] = Map[String, String](
- KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key ->
- (EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME))
+ override def getAdditionalPodSystemProperties(): Map[String, String] = {
+ if (hasTemplate) {
+ Map[String, String](
+ KUBERNETES_EXECUTOR_PODTEMPLATE_FILE.key ->
+ (EXECUTOR_POD_SPEC_TEMPLATE_MOUNTPATH + "/" + EXECUTOR_POD_SPEC_TEMPLATE_FILE_NAME))
+ } else {
+ Map.empty
+ }
+ }
override def getAdditionalKubernetesResources(): Seq[HasMetadata] = {
- require(conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined)
- val podTemplateFile = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get
- val podTemplateString = Files.toString(new File(podTemplateFile), StandardCharsets.UTF_8)
- Seq(new ConfigMapBuilder()
- .withNewMetadata()
- .withName(POD_TEMPLATE_CONFIGMAP)
- .endMetadata()
- .addToData(POD_TEMPLATE_KEY, podTemplateString)
- .build())
+ if (hasTemplate) {
+ val podTemplateFile = conf.get(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).get
+ val podTemplateString = Files.toString(new File(podTemplateFile), StandardCharsets.UTF_8)
+ Seq(new ConfigMapBuilder()
+ .withNewMetadata()
+ .withName(POD_TEMPLATE_CONFIGMAP)
+ .endMetadata()
+ .addToData(POD_TEMPLATE_KEY, podTemplateString)
+ .build())
+ } else {
+ Nil
+ }
}
}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala
deleted file mode 100644
index 5bee766caf2be..0000000000000
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/hadooputils/HadoopBootstrapUtil.scala
+++ /dev/null
@@ -1,283 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.features.hadooputils
-
-import java.io.File
-import java.nio.charset.StandardCharsets
-
-import scala.collection.JavaConverters._
-
-import com.google.common.io.Files
-import io.fabric8.kubernetes.api.model._
-
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.SparkPod
-import org.apache.spark.internal.Logging
-
-private[spark] object HadoopBootstrapUtil extends Logging {
-
- /**
- * Mounting the DT secret for both the Driver and the executors
- *
- * @param dtSecretName Name of the secret that stores the Delegation Token
- * @param dtSecretItemKey Name of the Item Key storing the Delegation Token
- * @param userName Name of the SparkUser to set SPARK_USER
- * @param fileLocation Optional Location of the krb5 file
- * @param newKrb5ConfName Optional location of the ConfigMap for Krb5
- * @param existingKrb5ConfName Optional name of ConfigMap for Krb5
- * @param pod Input pod to be appended to
- * @return a modified SparkPod
- */
- def bootstrapKerberosPod(
- dtSecretName: String,
- dtSecretItemKey: String,
- userName: String,
- fileLocation: Option[String],
- newKrb5ConfName: Option[String],
- existingKrb5ConfName: Option[String],
- pod: SparkPod): SparkPod = {
-
- val preConfigMapVolume = existingKrb5ConfName.map { kconf =>
- new VolumeBuilder()
- .withName(KRB_FILE_VOLUME)
- .withNewConfigMap()
- .withName(kconf)
- .endConfigMap()
- .build()
- }
-
- val createConfigMapVolume = for {
- fLocation <- fileLocation
- krb5ConfName <- newKrb5ConfName
- } yield {
- val krb5File = new File(fLocation)
- val fileStringPath = krb5File.toPath.getFileName.toString
- new VolumeBuilder()
- .withName(KRB_FILE_VOLUME)
- .withNewConfigMap()
- .withName(krb5ConfName)
- .withItems(new KeyToPathBuilder()
- .withKey(fileStringPath)
- .withPath(fileStringPath)
- .build())
- .endConfigMap()
- .build()
- }
-
- // Breaking up Volume creation for clarity
- val configMapVolume = preConfigMapVolume.orElse(createConfigMapVolume)
- if (configMapVolume.isEmpty) {
- logInfo("You have not specified a krb5.conf file locally or via a ConfigMap. " +
- "Make sure that you have the krb5.conf locally on the Driver and Executor images")
- }
-
- val kerberizedPodWithDTSecret = new PodBuilder(pod.pod)
- .editOrNewSpec()
- .addNewVolume()
- .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME)
- .withNewSecret()
- .withSecretName(dtSecretName)
- .endSecret()
- .endVolume()
- .endSpec()
- .build()
-
- // Optionally add the krb5.conf ConfigMap
- val kerberizedPod = configMapVolume.map { cmVolume =>
- new PodBuilder(kerberizedPodWithDTSecret)
- .editSpec()
- .addNewVolumeLike(cmVolume)
- .endVolume()
- .endSpec()
- .build()
- }.getOrElse(kerberizedPodWithDTSecret)
-
- val kerberizedContainerWithMounts = new ContainerBuilder(pod.container)
- .addNewVolumeMount()
- .withName(SPARK_APP_HADOOP_SECRET_VOLUME_NAME)
- .withMountPath(SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR)
- .endVolumeMount()
- .addNewEnv()
- .withName(ENV_HADOOP_TOKEN_FILE_LOCATION)
- .withValue(s"$SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR/$dtSecretItemKey")
- .endEnv()
- .addNewEnv()
- .withName(ENV_SPARK_USER)
- .withValue(userName)
- .endEnv()
- .build()
-
- // Optionally add the krb5.conf Volume Mount
- val kerberizedContainer =
- if (configMapVolume.isDefined) {
- new ContainerBuilder(kerberizedContainerWithMounts)
- .addNewVolumeMount()
- .withName(KRB_FILE_VOLUME)
- .withMountPath(KRB_FILE_DIR_PATH + "/krb5.conf")
- .withSubPath("krb5.conf")
- .endVolumeMount()
- .build()
- } else {
- kerberizedContainerWithMounts
- }
-
- SparkPod(kerberizedPod, kerberizedContainer)
- }
-
- /**
- * setting ENV_SPARK_USER when HADOOP_FILES are detected
- *
- * @param sparkUserName Name of the SPARK_USER
- * @param pod Input pod to be appended to
- * @return a modified SparkPod
- */
- def bootstrapSparkUserPod(sparkUserName: String, pod: SparkPod): SparkPod = {
- val envModifiedContainer = new ContainerBuilder(pod.container)
- .addNewEnv()
- .withName(ENV_SPARK_USER)
- .withValue(sparkUserName)
- .endEnv()
- .build()
- SparkPod(pod.pod, envModifiedContainer)
- }
-
- /**
- * Grabbing files in the HADOOP_CONF_DIR
- *
- * @param path location of HADOOP_CONF_DIR
- * @return a list of File object
- */
- def getHadoopConfFiles(path: String): Seq[File] = {
- val dir = new File(path)
- if (dir.isDirectory) {
- dir.listFiles.filter(_.isFile).toSeq
- } else {
- Seq.empty[File]
- }
- }
-
- /**
- * Bootstraping the container with ConfigMaps that store
- * Hadoop configuration files
- *
- * @param hadoopConfDir directory location of HADOOP_CONF_DIR env
- * @param newHadoopConfigMapName name of the new configMap for HADOOP_CONF_DIR
- * @param existingHadoopConfigMapName name of the pre-defined configMap for HADOOP_CONF_DIR
- * @param pod Input pod to be appended to
- * @return a modified SparkPod
- */
- def bootstrapHadoopConfDir(
- hadoopConfDir: Option[String],
- newHadoopConfigMapName: Option[String],
- existingHadoopConfigMapName: Option[String],
- pod: SparkPod): SparkPod = {
- val preConfigMapVolume = existingHadoopConfigMapName.map { hConf =>
- new VolumeBuilder()
- .withName(HADOOP_FILE_VOLUME)
- .withNewConfigMap()
- .withName(hConf)
- .endConfigMap()
- .build() }
-
- val createConfigMapVolume = for {
- dirLocation <- hadoopConfDir
- hConfName <- newHadoopConfigMapName
- } yield {
- val hadoopConfigFiles = getHadoopConfFiles(dirLocation)
- val keyPaths = hadoopConfigFiles.map { file =>
- val fileStringPath = file.toPath.getFileName.toString
- new KeyToPathBuilder()
- .withKey(fileStringPath)
- .withPath(fileStringPath)
- .build()
- }
- new VolumeBuilder()
- .withName(HADOOP_FILE_VOLUME)
- .withNewConfigMap()
- .withName(hConfName)
- .withItems(keyPaths.asJava)
- .endConfigMap()
- .build()
- }
-
- // Breaking up Volume Creation for clarity
- val configMapVolume = preConfigMapVolume.getOrElse(createConfigMapVolume.get)
-
- val hadoopSupportedPod = new PodBuilder(pod.pod)
- .editSpec()
- .addNewVolumeLike(configMapVolume)
- .endVolume()
- .endSpec()
- .build()
-
- val hadoopSupportedContainer = new ContainerBuilder(pod.container)
- .addNewVolumeMount()
- .withName(HADOOP_FILE_VOLUME)
- .withMountPath(HADOOP_CONF_DIR_PATH)
- .endVolumeMount()
- .addNewEnv()
- .withName(ENV_HADOOP_CONF_DIR)
- .withValue(HADOOP_CONF_DIR_PATH)
- .endEnv()
- .build()
- SparkPod(hadoopSupportedPod, hadoopSupportedContainer)
- }
-
- /**
- * Builds ConfigMap given the file location of the
- * krb5.conf file
- *
- * @param configMapName name of configMap for krb5
- * @param fileLocation location of krb5 file
- * @return a ConfigMap
- */
- def buildkrb5ConfigMap(
- configMapName: String,
- fileLocation: String): ConfigMap = {
- val file = new File(fileLocation)
- new ConfigMapBuilder()
- .withNewMetadata()
- .withName(configMapName)
- .endMetadata()
- .addToData(Map(file.toPath.getFileName.toString ->
- Files.toString(file, StandardCharsets.UTF_8)).asJava)
- .build()
- }
-
- /**
- * Builds ConfigMap given the ConfigMap name
- * and a list of Hadoop Conf files
- *
- * @param hadoopConfigMapName name of hadoopConfigMap
- * @param hadoopConfFiles list of hadoopFiles
- * @return a ConfigMap
- */
- def buildHadoopConfigMap(
- hadoopConfigMapName: String,
- hadoopConfFiles: Seq[File]): ConfigMap = {
- new ConfigMapBuilder()
- .withNewMetadata()
- .withName(hadoopConfigMapName)
- .endMetadata()
- .addToData(hadoopConfFiles.map { file =>
- (file.toPath.getFileName.toString,
- Files.toString(file, StandardCharsets.UTF_8))
- }.toMap.asJava)
- .build()
- }
-
-}
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala
index 70a93c968795e..3888778bf84ca 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala
@@ -104,7 +104,7 @@ private[spark] class Client(
watcher: LoggingPodStatusWatcher) extends Logging {
def run(): Unit = {
- val resolvedDriverSpec = builder.buildFromFeatures(conf)
+ val resolvedDriverSpec = builder.buildFromFeatures(conf, kubernetesClient)
val configMapName = s"${conf.resourceNamePrefix}-driver-conf-map"
val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties)
// The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the
@@ -232,7 +232,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication {
None)) { kubernetesClient =>
val client = new Client(
kubernetesConf,
- KubernetesDriverBuilder(kubernetesClient, kubernetesConf.sparkConf),
+ new KubernetesDriverBuilder(),
kubernetesClient,
waitForAppCompletion,
watcher)
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
index a5ad9729aee9a..57e4060bc85b9 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala
@@ -20,90 +20,50 @@ import java.io.File
import io.fabric8.kubernetes.client.KubernetesClient
-import org.apache.spark.SparkConf
import org.apache.spark.deploy.k8s._
import org.apache.spark.deploy.k8s.features._
-private[spark] class KubernetesDriverBuilder(
- provideBasicStep: (KubernetesDriverConf => BasicDriverFeatureStep) =
- new BasicDriverFeatureStep(_),
- provideCredentialsStep: (KubernetesDriverConf => DriverKubernetesCredentialsFeatureStep) =
- new DriverKubernetesCredentialsFeatureStep(_),
- provideServiceStep: (KubernetesDriverConf => DriverServiceFeatureStep) =
- new DriverServiceFeatureStep(_),
- provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) =
- new MountSecretsFeatureStep(_),
- provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) =
- new EnvSecretsFeatureStep(_),
- provideLocalDirsStep: (KubernetesConf => LocalDirsFeatureStep) =
- new LocalDirsFeatureStep(_),
- provideVolumesStep: (KubernetesConf => MountVolumesFeatureStep) =
- new MountVolumesFeatureStep(_),
- provideDriverCommandStep: (KubernetesDriverConf => DriverCommandFeatureStep) =
- new DriverCommandFeatureStep(_),
- provideHadoopGlobalStep: (KubernetesDriverConf => KerberosConfDriverFeatureStep) =
- new KerberosConfDriverFeatureStep(_),
- providePodTemplateConfigMapStep: (KubernetesConf => PodTemplateConfigMapStep) =
- new PodTemplateConfigMapStep(_),
- provideInitialPod: () => SparkPod = () => SparkPod.initialPod) {
+private[spark] class KubernetesDriverBuilder {
- def buildFromFeatures(kubernetesConf: KubernetesDriverConf): KubernetesDriverSpec = {
- val baseFeatures = Seq(
- provideBasicStep(kubernetesConf),
- provideCredentialsStep(kubernetesConf),
- provideServiceStep(kubernetesConf),
- provideLocalDirsStep(kubernetesConf))
-
- val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) {
- Seq(provideSecretsStep(kubernetesConf))
- } else Nil
- val envSecretFeature = if (kubernetesConf.secretEnvNamesToKeyRefs.nonEmpty) {
- Seq(provideEnvSecretsStep(kubernetesConf))
- } else Nil
- val volumesFeature = if (kubernetesConf.volumes.nonEmpty) {
- Seq(provideVolumesStep(kubernetesConf))
- } else Nil
- val podTemplateFeature = if (
- kubernetesConf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE).isDefined) {
- Seq(providePodTemplateConfigMapStep(kubernetesConf))
- } else Nil
-
- val driverCommandStep = provideDriverCommandStep(kubernetesConf)
-
- val hadoopConfigStep = Some(provideHadoopGlobalStep(kubernetesConf))
+ def buildFromFeatures(
+ conf: KubernetesDriverConf,
+ client: KubernetesClient): KubernetesDriverSpec = {
+ val initialPod = conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE)
+ .map { file =>
+ KubernetesUtils.loadPodFromTemplate(
+ client,
+ new File(file),
+ conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_CONTAINER_NAME))
+ }
+ .getOrElse(SparkPod.initialPod())
- val allFeatures: Seq[KubernetesFeatureConfigStep] =
- baseFeatures ++ Seq(driverCommandStep) ++
- secretFeature ++ envSecretFeature ++ volumesFeature ++
- hadoopConfigStep ++ podTemplateFeature
+ val features = Seq(
+ new BasicDriverFeatureStep(conf),
+ new DriverKubernetesCredentialsFeatureStep(conf),
+ new DriverServiceFeatureStep(conf),
+ new MountSecretsFeatureStep(conf),
+ new EnvSecretsFeatureStep(conf),
+ new LocalDirsFeatureStep(conf),
+ new MountVolumesFeatureStep(conf),
+ new DriverCommandFeatureStep(conf),
+ new HadoopConfDriverFeatureStep(conf),
+ new KerberosConfDriverFeatureStep(conf),
+ new PodTemplateConfigMapStep(conf))
- var spec = KubernetesDriverSpec(
- provideInitialPod(),
+ val spec = KubernetesDriverSpec(
+ initialPod,
driverKubernetesResources = Seq.empty,
- kubernetesConf.sparkConf.getAll.toMap)
- for (feature <- allFeatures) {
+ conf.sparkConf.getAll.toMap)
+
+ features.foldLeft(spec) { case (spec, feature) =>
val configuredPod = feature.configurePod(spec.pod)
val addedSystemProperties = feature.getAdditionalPodSystemProperties()
val addedResources = feature.getAdditionalKubernetesResources()
- spec = KubernetesDriverSpec(
+ KubernetesDriverSpec(
configuredPod,
spec.driverKubernetesResources ++ addedResources,
spec.systemProperties ++ addedSystemProperties)
}
- spec
}
-}
-private[spark] object KubernetesDriverBuilder {
- def apply(kubernetesClient: KubernetesClient, conf: SparkConf): KubernetesDriverBuilder = {
- conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE)
- .map(new File(_))
- .map(file => new KubernetesDriverBuilder(provideInitialPod = () =>
- KubernetesUtils.loadPodFromTemplate(
- kubernetesClient,
- file,
- conf.get(Config.KUBERNETES_DRIVER_PODTEMPLATE_CONTAINER_NAME))
- ))
- .getOrElse(new KubernetesDriverBuilder())
- }
}
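
Editor's sketch (not part of the patch): the rewritten buildFromFeatures above drops the injected step factories and the mutable `var spec` loop in favor of a fixed list of feature steps folded into an immutable spec. The sketch below shows that foldLeft accumulation with simplified stand-in types; it illustrates the pattern only and is not the real KubernetesDriverSpec API.

    object FeatureFoldSketch {
      // Stand-ins for SparkPod, KubernetesDriverSpec and the feature step trait.
      final case class Pod(labels: Map[String, String])
      final case class DriverSpec(pod: Pod, resources: Seq[String], props: Map[String, String])

      trait FeatureStep {
        def configurePod(pod: Pod): Pod
        def additionalProps: Map[String, String] = Map.empty
        def additionalResources: Seq[String] = Nil
      }

      class LabelStep(key: String, value: String) extends FeatureStep {
        override def configurePod(pod: Pod): Pod = pod.copy(labels = pod.labels + (key -> value))
      }

      class ConfigMapStep(name: String) extends FeatureStep {
        override def configurePod(pod: Pod): Pod = pod
        override def additionalResources: Seq[String] = Seq(name)
      }

      def buildSpec(initialPod: Pod, features: Seq[FeatureStep]): DriverSpec = {
        val initial = DriverSpec(initialPod, Seq.empty, Map.empty)
        // Each step sees the pod produced by its predecessor and appends its own
        // resources and properties; no mutable accumulator is needed.
        features.foldLeft(initial) { case (spec, feature) =>
          DriverSpec(
            feature.configurePod(spec.pod),
            spec.resources ++ feature.additionalResources,
            spec.props ++ feature.additionalProps)
        }
      }

      def main(args: Array[String]): Unit = {
        val spec = buildSpec(Pod(Map.empty),
          Seq(new LabelStep("role", "driver"), new ConfigMapStep("driver-conf-map")))
        assert(spec.pod.labels("role") == "driver" && spec.resources == Seq("driver-conf-map"))
      }
    }
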
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala
index 2f0f949566d6a..da3edfeca9b1f 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocator.scala
@@ -22,7 +22,7 @@ import io.fabric8.kubernetes.api.model.PodBuilder
import io.fabric8.kubernetes.client.KubernetesClient
import scala.collection.mutable
-import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.deploy.k8s.KubernetesConf
@@ -31,6 +31,7 @@ import org.apache.spark.util.{Clock, Utils}
private[spark] class ExecutorPodsAllocator(
conf: SparkConf,
+ secMgr: SecurityManager,
executorBuilder: KubernetesExecutorBuilder,
kubernetesClient: KubernetesClient,
snapshotsStore: ExecutorPodsSnapshotsStore,
@@ -135,7 +136,8 @@ private[spark] class ExecutorPodsAllocator(
newExecutorId.toString,
applicationId,
driverPod)
- val executorPod = executorBuilder.buildFromFeatures(executorConf)
+ val executorPod = executorBuilder.buildFromFeatures(executorConf, secMgr,
+ kubernetesClient)
val podWithAttachedContainer = new PodBuilder(executorPod.pod)
.editOrNewSpec()
.addToContainers(executorPod.container)
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
index ce10f766334ff..809bdf8ca8c27 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala
@@ -94,7 +94,8 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit
val executorPodsAllocator = new ExecutorPodsAllocator(
sc.conf,
- KubernetesExecutorBuilder(kubernetesClient, sc.conf),
+ sc.env.securityManager,
+ new KubernetesExecutorBuilder(),
kubernetesClient,
snapshotsStore,
new SystemClock())
@@ -110,7 +111,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit
new KubernetesClusterSchedulerBackend(
scheduler.asInstanceOf[TaskSchedulerImpl],
- sc.env.rpcEnv,
+ sc,
kubernetesClient,
requestExecutorsService,
snapshotsStore,
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
index 6356b58645806..cd298971e02a7 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala
@@ -18,11 +18,14 @@ package org.apache.spark.scheduler.cluster.k8s
import java.util.concurrent.ExecutorService
-import io.fabric8.kubernetes.client.KubernetesClient
import scala.concurrent.{ExecutionContext, Future}
+import io.fabric8.kubernetes.client.KubernetesClient
+
+import org.apache.spark.SparkContext
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.security.HadoopDelegationTokenManager
import org.apache.spark.rpc.{RpcAddress, RpcEnv}
import org.apache.spark.scheduler.{ExecutorLossReason, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils}
@@ -30,7 +33,7 @@ import org.apache.spark.util.{ThreadUtils, Utils}
private[spark] class KubernetesClusterSchedulerBackend(
scheduler: TaskSchedulerImpl,
- rpcEnv: RpcEnv,
+ sc: SparkContext,
kubernetesClient: KubernetesClient,
requestExecutorsService: ExecutorService,
snapshotsStore: ExecutorPodsSnapshotsStore,
@@ -38,10 +41,10 @@ private[spark] class KubernetesClusterSchedulerBackend(
lifecycleEventHandler: ExecutorPodsLifecycleManager,
watchEvents: ExecutorPodsWatchSnapshotSource,
pollEvents: ExecutorPodsPollingSnapshotSource)
- extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
- private implicit val requestExecutorContext = ExecutionContext.fromExecutorService(
- requestExecutorsService)
+ private implicit val requestExecutorContext =
+ ExecutionContext.fromExecutorService(requestExecutorsService)
protected override val minRegisteredRatio =
if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
@@ -59,6 +62,17 @@ private[spark] class KubernetesClusterSchedulerBackend(
removeExecutor(executorId, reason)
}
+ /**
+ * Get an application ID associated with the job.
+ * This returns the string value of spark.app.id if set, otherwise
+ * the locally-generated ID from the superclass.
+ *
+ * @return The application ID
+ */
+ override def applicationId(): String = {
+ conf.getOption("spark.app.id").map(_.toString).getOrElse(super.applicationId)
+ }
+
override def start(): Unit = {
super.start()
if (!Utils.isDynamicAllocationEnabled(conf)) {
@@ -87,7 +101,8 @@ private[spark] class KubernetesClusterSchedulerBackend(
if (shouldDeleteExecutors) {
Utils.tryLogNonFatalError {
- kubernetesClient.pods()
+ kubernetesClient
+ .pods()
.withLabel(SPARK_APP_ID_LABEL, applicationId())
.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)
.delete()
@@ -119,7 +134,8 @@ private[spark] class KubernetesClusterSchedulerBackend(
}
override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] {
- kubernetesClient.pods()
+ kubernetesClient
+ .pods()
.withLabel(SPARK_APP_ID_LABEL, applicationId())
.withLabel(SPARK_ROLE_LABEL, SPARK_POD_EXECUTOR_ROLE)
.withLabelIn(SPARK_EXECUTOR_ID_LABEL, executorIds: _*)
@@ -128,11 +144,15 @@ private[spark] class KubernetesClusterSchedulerBackend(
}
override def createDriverEndpoint(properties: Seq[(String, String)]): DriverEndpoint = {
- new KubernetesDriverEndpoint(rpcEnv, properties)
+ new KubernetesDriverEndpoint(sc.env.rpcEnv, properties)
+ }
+
+ override protected def createTokenManager(): Option[HadoopDelegationTokenManager] = {
+ Some(new HadoopDelegationTokenManager(conf, sc.hadoopConfiguration))
}
private class KubernetesDriverEndpoint(rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)])
- extends DriverEndpoint(rpcEnv, sparkProperties) {
+ extends DriverEndpoint(rpcEnv, sparkProperties) {
override def onDisconnected(rpcAddress: RpcAddress): Unit = {
// Don't do anything besides disabling the executor - allow the Kubernetes API events to
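
Editor's sketch (not part of the patch): the new applicationId override above prefers an explicitly configured spark.app.id and otherwise falls back to the id generated by the superclass. A small self-contained illustration of that fallback pattern; SchedulerBackendBase is a stand-in, not Spark's CoarseGrainedSchedulerBackend.

    object AppIdOverrideSketch {
      class SchedulerBackendBase {
        def applicationId(): String = "spark-application-" + System.currentTimeMillis()
      }

      class K8sBackend(conf: Map[String, String]) extends SchedulerBackendBase {
        // Prefer the configured id, fall back to the locally generated one.
        override def applicationId(): String =
          conf.get("spark.app.id").getOrElse(super.applicationId())
      }

      def main(args: Array[String]): Unit = {
        assert(new K8sBackend(Map("spark.app.id" -> "app-123")).applicationId() == "app-123")
        assert(new K8sBackend(Map.empty).applicationId().startsWith("spark-application-"))
      }
    }
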
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
index d24ff0d1e6600..48aa2c56d4d69 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala
@@ -20,83 +20,33 @@ import java.io.File
import io.fabric8.kubernetes.client.KubernetesClient
-import org.apache.spark.SparkConf
+import org.apache.spark.SecurityManager
import org.apache.spark.deploy.k8s._
-import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.deploy.k8s.features._
-private[spark] class KubernetesExecutorBuilder(
- provideBasicStep: (KubernetesExecutorConf => BasicExecutorFeatureStep) =
- new BasicExecutorFeatureStep(_),
- provideSecretsStep: (KubernetesConf => MountSecretsFeatureStep) =
- new MountSecretsFeatureStep(_),
- provideEnvSecretsStep: (KubernetesConf => EnvSecretsFeatureStep) =
- new EnvSecretsFeatureStep(_),
- provideLocalDirsStep: (KubernetesConf => LocalDirsFeatureStep) =
- new LocalDirsFeatureStep(_),
- provideVolumesStep: (KubernetesConf => MountVolumesFeatureStep) =
- new MountVolumesFeatureStep(_),
- provideHadoopConfStep: (KubernetesExecutorConf => HadoopConfExecutorFeatureStep) =
- new HadoopConfExecutorFeatureStep(_),
- provideKerberosConfStep: (KubernetesExecutorConf => KerberosConfExecutorFeatureStep) =
- new KerberosConfExecutorFeatureStep(_),
- provideHadoopSparkUserStep: (KubernetesExecutorConf => HadoopSparkUserExecutorFeatureStep) =
- new HadoopSparkUserExecutorFeatureStep(_),
- provideInitialPod: () => SparkPod = () => SparkPod.initialPod()) {
-
- def buildFromFeatures(kubernetesConf: KubernetesExecutorConf): SparkPod = {
- val sparkConf = kubernetesConf.sparkConf
- val maybeHadoopConfigMap = sparkConf.getOption(HADOOP_CONFIG_MAP_NAME)
- val maybeDTSecretName = sparkConf.getOption(KERBEROS_DT_SECRET_NAME)
- val maybeDTDataItem = sparkConf.getOption(KERBEROS_DT_SECRET_KEY)
-
- val baseFeatures = Seq(provideBasicStep(kubernetesConf), provideLocalDirsStep(kubernetesConf))
- val secretFeature = if (kubernetesConf.secretNamesToMountPaths.nonEmpty) {
- Seq(provideSecretsStep(kubernetesConf))
- } else Nil
- val secretEnvFeature = if (kubernetesConf.secretEnvNamesToKeyRefs.nonEmpty) {
- Seq(provideEnvSecretsStep(kubernetesConf))
- } else Nil
- val volumesFeature = if (kubernetesConf.volumes.nonEmpty) {
- Seq(provideVolumesStep(kubernetesConf))
- } else Nil
-
- val maybeHadoopConfFeatureSteps = maybeHadoopConfigMap.map { _ =>
- val maybeKerberosStep =
- if (maybeDTSecretName.isDefined && maybeDTDataItem.isDefined) {
- provideKerberosConfStep(kubernetesConf)
- } else {
- provideHadoopSparkUserStep(kubernetesConf)
- }
- Seq(provideHadoopConfStep(kubernetesConf)) :+
- maybeKerberosStep
- }.getOrElse(Seq.empty[KubernetesFeatureConfigStep])
-
- val allFeatures: Seq[KubernetesFeatureConfigStep] =
- baseFeatures ++
- secretFeature ++
- secretEnvFeature ++
- volumesFeature ++
- maybeHadoopConfFeatureSteps
-
- var executorPod = provideInitialPod()
- for (feature <- allFeatures) {
- executorPod = feature.configurePod(executorPod)
- }
- executorPod
+private[spark] class KubernetesExecutorBuilder {
+
+ def buildFromFeatures(
+ conf: KubernetesExecutorConf,
+ secMgr: SecurityManager,
+ client: KubernetesClient): SparkPod = {
+ val initialPod = conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE)
+ .map { file =>
+ KubernetesUtils.loadPodFromTemplate(
+ client,
+ new File(file),
+ conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME))
+ }
+ .getOrElse(SparkPod.initialPod())
+
+ val features = Seq(
+ new BasicExecutorFeatureStep(conf, secMgr),
+ new MountSecretsFeatureStep(conf),
+ new EnvSecretsFeatureStep(conf),
+ new LocalDirsFeatureStep(conf),
+ new MountVolumesFeatureStep(conf))
+
+ features.foldLeft(initialPod) { case (pod, feature) => feature.configurePod(pod) }
}
-}
-private[spark] object KubernetesExecutorBuilder {
- def apply(kubernetesClient: KubernetesClient, conf: SparkConf): KubernetesExecutorBuilder = {
- conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE)
- .map(new File(_))
- .map(file => new KubernetesExecutorBuilder(provideInitialPod = () =>
- KubernetesUtils.loadPodFromTemplate(
- kubernetesClient,
- file,
- conf.get(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_CONTAINER_NAME))
- ))
- .getOrElse(new KubernetesExecutorBuilder())
- }
}
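
Editor's sketch (not part of the patch): the new executor builder above starts from an optional pod template (loaded through the Kubernetes client when configured, otherwise an empty initial pod) and then folds the feature steps over it. The sketch below illustrates that option-or-default plus foldLeft shape with stand-in types; loadTemplate is a hypothetical placeholder for KubernetesUtils.loadPodFromTemplate.

    object TemplateOrDefaultSketch {
      final case class Pod(containers: List[String])

      // Placeholder loader standing in for loading a pod spec from a template file.
      def loadTemplate(path: String): Pod =
        Pod(containers = List(s"container-from-$path"))

      def buildPod(templateFile: Option[String], steps: Seq[Pod => Pod]): Pod = {
        val initialPod = templateFile.map(loadTemplate).getOrElse(Pod(Nil))
        // Feature steps refine whichever starting pod was chosen.
        steps.foldLeft(initialPod) { (pod, step) => step(pod) }
      }

      def main(args: Array[String]): Unit = {
        val addSidecar: Pod => Pod = p => p.copy(containers = p.containers :+ "sidecar")
        assert(buildPod(None, Seq(addSidecar)) == Pod(List("sidecar")))
        assert(buildPod(Some("tmpl.yml"), Nil) == Pod(List("container-from-tmpl.yml")))
      }
    }
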
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala
new file mode 100644
index 0000000000000..7dde0c1377168
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/PodBuilderSuite.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s
+
+import java.io.File
+
+import io.fabric8.kubernetes.api.model.{Config => _, _}
+import io.fabric8.kubernetes.client.KubernetesClient
+import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource}
+import org.mockito.Matchers.any
+import org.mockito.Mockito.{mock, never, verify, when}
+import scala.collection.JavaConverters._
+
+import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.deploy.k8s._
+import org.apache.spark.internal.config.ConfigEntry
+
+abstract class PodBuilderSuite extends SparkFunSuite {
+
+ protected def templateFileConf: ConfigEntry[_]
+
+ protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod
+
+ private val baseConf = new SparkConf(false)
+ .set(Config.CONTAINER_IMAGE, "spark-executor:latest")
+
+ test("use empty initial pod if template is not specified") {
+ val client = mock(classOf[KubernetesClient])
+ buildPod(baseConf.clone(), client)
+ verify(client, never()).pods()
+ }
+
+ test("load pod template if specified") {
+ val client = mockKubernetesClient()
+ val sparkConf = baseConf.clone().set(templateFileConf.key, "template-file.yaml")
+ val pod = buildPod(sparkConf, client)
+ verifyPod(pod)
+ }
+
+ test("complain about misconfigured pod template") {
+ val client = mockKubernetesClient(
+ new PodBuilder()
+ .withNewMetadata()
+ .addToLabels("test-label-key", "test-label-value")
+ .endMetadata()
+ .build())
+ val sparkConf = baseConf.clone().set(templateFileConf.key, "template-file.yaml")
+ val exception = intercept[SparkException] {
+ buildPod(sparkConf, client)
+ }
+ assert(exception.getMessage.contains("Could not load pod from template file."))
+ }
+
+ private def mockKubernetesClient(pod: Pod = podWithSupportedFeatures()): KubernetesClient = {
+ val kubernetesClient = mock(classOf[KubernetesClient])
+ val pods =
+ mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]])
+ val podResource = mock(classOf[PodResource[Pod, DoneablePod]])
+ when(kubernetesClient.pods()).thenReturn(pods)
+ when(pods.load(any(classOf[File]))).thenReturn(podResource)
+ when(podResource.get()).thenReturn(pod)
+ kubernetesClient
+ }
+
+ private def verifyPod(pod: SparkPod): Unit = {
+ val metadata = pod.pod.getMetadata
+ assert(metadata.getLabels.containsKey("test-label-key"))
+ assert(metadata.getAnnotations.containsKey("test-annotation-key"))
+ assert(metadata.getNamespace === "namespace")
+ assert(metadata.getOwnerReferences.asScala.exists(_.getName == "owner-reference"))
+ val spec = pod.pod.getSpec
+ assert(!spec.getContainers.asScala.exists(_.getName == "executor-container"))
+ assert(spec.getDnsPolicy === "dns-policy")
+ assert(spec.getHostAliases.asScala.exists(_.getHostnames.asScala.exists(_ == "hostname")))
+ assert(spec.getImagePullSecrets.asScala.exists(_.getName == "local-reference"))
+ assert(spec.getInitContainers.asScala.exists(_.getName == "init-container"))
+ assert(spec.getNodeName == "node-name")
+ assert(spec.getNodeSelector.get("node-selector-key") === "node-selector-value")
+ assert(spec.getSchedulerName === "scheduler")
+ assert(spec.getSecurityContext.getRunAsUser === 1000L)
+ assert(spec.getServiceAccount === "service-account")
+ assert(spec.getSubdomain === "subdomain")
+ assert(spec.getTolerations.asScala.exists(_.getKey == "toleration-key"))
+ assert(spec.getVolumes.asScala.exists(_.getName == "test-volume"))
+ val container = pod.container
+ assert(container.getName === "executor-container")
+ assert(container.getArgs.contains("arg"))
+ assert(container.getCommand.equals(List("command").asJava))
+ assert(container.getEnv.asScala.exists(_.getName == "env-key"))
+ assert(container.getResources.getLimits.get("gpu") ===
+ new QuantityBuilder().withAmount("1").build())
+ assert(container.getSecurityContext.getRunAsNonRoot)
+ assert(container.getStdin)
+ assert(container.getTerminationMessagePath === "termination-message-path")
+ assert(container.getTerminationMessagePolicy === "termination-message-policy")
+ assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "test-volume"))
+ }
+
+ private def podWithSupportedFeatures(): Pod = {
+ new PodBuilder()
+ .withNewMetadata()
+ .addToLabels("test-label-key", "test-label-value")
+ .addToAnnotations("test-annotation-key", "test-annotation-value")
+ .withNamespace("namespace")
+ .addNewOwnerReference()
+ .withController(true)
+ .withName("owner-reference")
+ .endOwnerReference()
+ .endMetadata()
+ .withNewSpec()
+ .withDnsPolicy("dns-policy")
+ .withHostAliases(new HostAliasBuilder().withHostnames("hostname").build())
+ .withImagePullSecrets(
+ new LocalObjectReferenceBuilder().withName("local-reference").build())
+ .withInitContainers(new ContainerBuilder().withName("init-container").build())
+ .withNodeName("node-name")
+ .withNodeSelector(Map("node-selector-key" -> "node-selector-value").asJava)
+ .withSchedulerName("scheduler")
+ .withNewSecurityContext()
+ .withRunAsUser(1000L)
+ .endSecurityContext()
+ .withServiceAccount("service-account")
+ .withSubdomain("subdomain")
+ .withTolerations(new TolerationBuilder()
+ .withKey("toleration-key")
+ .withOperator("Equal")
+ .withEffect("NoSchedule")
+ .build())
+ .addNewVolume()
+ .withNewHostPath()
+ .withPath("/test")
+ .endHostPath()
+ .withName("test-volume")
+ .endVolume()
+ .addNewContainer()
+ .withArgs("arg")
+ .withCommand("command")
+ .addNewEnv()
+ .withName("env-key")
+ .withValue("env-value")
+ .endEnv()
+ .withImagePullPolicy("Always")
+ .withName("executor-container")
+ .withNewResources()
+ .withLimits(Map("gpu" -> new QuantityBuilder().withAmount("1").build()).asJava)
+ .endResources()
+ .withNewSecurityContext()
+ .withRunAsNonRoot(true)
+ .endSecurityContext()
+ .withStdin(true)
+ .withTerminationMessagePath("termination-message-path")
+ .withTerminationMessagePolicy("termination-message-policy")
+ .addToVolumeMounts(
+ new VolumeMountBuilder()
+ .withName("test-volume")
+ .withMountPath("/test")
+ .build())
+ .endContainer()
+ .endSpec()
+ .build()
+ }
+
+}
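
Editor's sketch (not part of the patch): mockKubernetesClient() in the new suite above stubs the fluent fabric8 chain client.pods().load(file).get() one link at a time, so the builder can load a template without a real API server. Below is a minimal illustration of that technique with simplified stand-in interfaces, assuming mockito-core is on the classpath.

    import org.mockito.Mockito.{mock, when}

    object FluentChainMockSketch {
      // Stand-ins for the fabric8 client interfaces.
      trait PodResource { def get(): String }
      trait PodsOperation { def load(path: String): PodResource }
      trait KubeClient { def pods(): PodsOperation }

      def mockedClient(pod: String): KubeClient = {
        val client = mock(classOf[KubeClient])
        val pods = mock(classOf[PodsOperation])
        val resource = mock(classOf[PodResource])
        // Stub each link of the chain separately.
        when(client.pods()).thenReturn(pods)
        when(pods.load("template-file.yaml")).thenReturn(resource)
        when(resource.get()).thenReturn(pod)
        client
      }

      def main(args: Array[String]): Unit = {
        val client = mockedClient("templated-pod")
        assert(client.pods().load("template-file.yaml").get() == "templated-pod")
      }
    }
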
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
index e4951bc1e69ed..5ceb9d6d6fcd0 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.deploy.k8s.submit._
import org.apache.spark.internal.config._
import org.apache.spark.ui.SparkUI
+import org.apache.spark.util.Utils
class BasicDriverFeatureStepSuite extends SparkFunSuite {
@@ -73,7 +74,6 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
val foundPortNames = configuredPod.container.getPorts.asScala.toSet
assert(expectedPortNames === foundPortNames)
- assert(configuredPod.container.getEnv.size === 3)
val envs = configuredPod.container
.getEnv
.asScala
@@ -82,6 +82,7 @@ class BasicDriverFeatureStepSuite extends SparkFunSuite {
DRIVER_ENVS.foreach { case (k, v) =>
assert(envs(v) === v)
}
+ assert(envs(ENV_SPARK_USER) === Utils.getCurrentUserName())
assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala ===
TEST_IMAGE_PULL_SECRET_OBJECTS)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
index d6003c977937c..c2efab01e4248 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala
@@ -16,18 +16,23 @@
*/
package org.apache.spark.deploy.k8s.features
+import java.io.File
+import java.nio.charset.StandardCharsets
+import java.nio.file.Files
+
import scala.collection.JavaConverters._
import io.fabric8.kubernetes.api.model._
import org.scalatest.BeforeAndAfter
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.internal.config._
import org.apache.spark.rpc.RpcEndpointAddress
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.util.Utils
class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
@@ -63,7 +68,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
private var baseConf: SparkConf = _
before {
- baseConf = new SparkConf()
+ baseConf = new SparkConf(false)
.set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME)
.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX)
.set(CONTAINER_IMAGE, EXECUTOR_IMAGE)
@@ -84,7 +89,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
}
test("basic executor pod has reasonable defaults") {
- val step = new BasicExecutorFeatureStep(newExecutorConf())
+ val step = new BasicExecutorFeatureStep(newExecutorConf(), new SecurityManager(baseConf))
val executor = step.configurePod(SparkPod.initialPod())
// The executor pod name and default labels.
@@ -106,7 +111,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
assert(executor.pod.getSpec.getNodeSelector.isEmpty)
assert(executor.pod.getSpec.getVolumes.isEmpty)
- checkEnv(executor, Map())
+ checkEnv(executor, baseConf, Map())
checkOwnerReferences(executor.pod, DRIVER_POD_UID)
}
@@ -114,7 +119,7 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple"
baseConf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, longPodNamePrefix)
- val step = new BasicExecutorFeatureStep(newExecutorConf())
+ val step = new BasicExecutorFeatureStep(newExecutorConf(), new SecurityManager(baseConf))
assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63)
}
@@ -122,10 +127,10 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
baseConf.set(EXECUTOR_JAVA_OPTIONS, "foo=bar")
baseConf.set(EXECUTOR_CLASS_PATH, "bar=baz")
val kconf = newExecutorConf(environment = Map("qux" -> "quux"))
- val step = new BasicExecutorFeatureStep(kconf)
+ val step = new BasicExecutorFeatureStep(kconf, new SecurityManager(baseConf))
val executor = step.configurePod(SparkPod.initialPod())
- checkEnv(executor,
+ checkEnv(executor, baseConf,
Map("SPARK_JAVA_OPT_0" -> "foo=bar",
ENV_CLASSPATH -> "bar=baz",
"qux" -> "quux"))
@@ -136,12 +141,46 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
baseConf.set("spark.kubernetes.resource.type", "python")
baseConf.set(PYSPARK_EXECUTOR_MEMORY, 42L)
- val step = new BasicExecutorFeatureStep(newExecutorConf())
+ val step = new BasicExecutorFeatureStep(newExecutorConf(), new SecurityManager(baseConf))
val executor = step.configurePod(SparkPod.initialPod())
// This is checking that basic executor + executorMemory = 1408 + 42 = 1450
assert(executor.container.getResources.getRequests.get("memory").getAmount === "1450Mi")
}
+ test("auth secret propagation") {
+ val conf = baseConf.clone()
+ .set(NETWORK_AUTH_ENABLED, true)
+ .set("spark.master", "k8s://127.0.0.1")
+
+ val secMgr = new SecurityManager(conf)
+ secMgr.initializeAuth()
+
+ val step = new BasicExecutorFeatureStep(KubernetesTestConf.createExecutorConf(sparkConf = conf),
+ secMgr)
+
+ val executor = step.configurePod(SparkPod.initialPod())
+ checkEnv(executor, conf, Map(SecurityManager.ENV_AUTH_SECRET -> secMgr.getSecretKey()))
+ }
+
+ test("Auth secret shouldn't propagate if files are loaded.") {
+ val secretDir = Utils.createTempDir("temp-secret")
+ val secretFile = new File(secretDir, "secret-file.txt")
+ Files.write(secretFile.toPath, "some-secret".getBytes(StandardCharsets.UTF_8))
+ val conf = baseConf.clone()
+ .set(NETWORK_AUTH_ENABLED, true)
+ .set(AUTH_SECRET_FILE, secretFile.getAbsolutePath)
+ .set("spark.master", "k8s://127.0.0.1")
+ val secMgr = new SecurityManager(conf)
+ secMgr.initializeAuth()
+
+ val step = new BasicExecutorFeatureStep(KubernetesTestConf.createExecutorConf(sparkConf = conf),
+ secMgr)
+
+ val executor = step.configurePod(SparkPod.initialPod())
+ assert(!KubernetesFeaturesTestUtils.containerHasEnvVar(
+ executor.container, SecurityManager.ENV_AUTH_SECRET))
+ }
+
// There is always exactly one controller reference, and it points to the driver pod.
private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = {
assert(executor.getMetadata.getOwnerReferences.size() === 1)
@@ -150,7 +189,10 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
}
// Check that the expected environment variables are present.
- private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = {
+ private def checkEnv(
+ executorPod: SparkPod,
+ conf: SparkConf,
+ additionalEnvVars: Map[String, String]): Unit = {
val defaultEnvs = Map(
ENV_EXECUTOR_ID -> "1",
ENV_DRIVER_URL -> DRIVER_ADDRESS.toString,
@@ -158,12 +200,20 @@ class BasicExecutorFeatureStepSuite extends SparkFunSuite with BeforeAndAfter {
ENV_EXECUTOR_MEMORY -> "1g",
ENV_APPLICATION_ID -> KubernetesTestConf.APP_ID,
ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL,
- ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars
+ ENV_EXECUTOR_POD_IP -> null,
+ ENV_SPARK_USER -> Utils.getCurrentUserName())
- assert(executorPod.container.getEnv.size() === defaultEnvs.size)
- val mapEnvs = executorPod.container.getEnv.asScala.map {
+ val extraJavaOptsStart = additionalEnvVars.keys.count(_.startsWith(ENV_JAVA_OPT_PREFIX))
+ val extraJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf)
+ val extraJavaOptsEnvs = extraJavaOpts.zipWithIndex.map { case (opt, ind) =>
+ s"$ENV_JAVA_OPT_PREFIX${ind + extraJavaOptsStart}" -> opt
+ }.toMap
+
+ val containerEnvs = executorPod.container.getEnv.asScala.map {
x => (x.getName, x.getValue)
}.toMap
- assert(defaultEnvs === mapEnvs)
+
+ val expectedEnvs = defaultEnvs ++ additionalEnvVars ++ extraJavaOptsEnvs
+ assert(containerEnvs === expectedEnvs)
}
}
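
Editor's sketch (not part of the patch): the updated checkEnv above no longer pins an exact env count; instead it derives the expected SPARK_JAVA_OPT_<n> entries by numbering the executor startup java options with zipWithIndex, offset by however many such entries the caller already supplied explicitly. A tiny illustration of that numbering; the prefix value is an assumption standing in for ENV_JAVA_OPT_PREFIX.

    object JavaOptEnvSketch {
      val prefix = "SPARK_JAVA_OPT_" // assumed to mirror ENV_JAVA_OPT_PREFIX

      def optEnvs(alreadyPresent: Int, opts: Seq[String]): Map[String, String] =
        opts.zipWithIndex.map { case (opt, i) => s"$prefix${i + alreadyPresent}" -> opt }.toMap

      def main(args: Array[String]): Unit = {
        // One SPARK_JAVA_OPT_* entry was passed explicitly, so numbering starts at 1.
        assert(optEnvs(1, Seq("-Dspark.test=true")) ==
          Map("SPARK_JAVA_OPT_1" -> "-Dspark.test=true"))
      }
    }
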
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala
new file mode 100644
index 0000000000000..e1c01dbdc7358
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/HadoopConfDriverFeatureStepSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import java.io.File
+import java.nio.charset.StandardCharsets.UTF_8
+
+import scala.collection.JavaConverters._
+
+import com.google.common.io.Files
+import io.fabric8.kubernetes.api.model.ConfigMap
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.k8s._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.submit.JavaMainAppResource
+import org.apache.spark.util.{SparkConfWithEnv, Utils}
+
+class HadoopConfDriverFeatureStepSuite extends SparkFunSuite {
+
+ import KubernetesFeaturesTestUtils._
+ import SecretVolumeUtils._
+
+ test("mount hadoop config map if defined") {
+ val sparkConf = new SparkConf(false)
+ .set(Config.KUBERNETES_HADOOP_CONF_CONFIG_MAP, "testConfigMap")
+ val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf)
+ val step = new HadoopConfDriverFeatureStep(conf)
+ checkPod(step.configurePod(SparkPod.initialPod()))
+ assert(step.getAdditionalKubernetesResources().isEmpty)
+ }
+
+ test("create hadoop config map if config dir is defined") {
+ val confDir = Utils.createTempDir()
+ val confFiles = Set("core-site.xml", "hdfs-site.xml")
+
+ confFiles.foreach { f =>
+ Files.write("some data", new File(confDir, f), UTF_8)
+ }
+
+ val sparkConf = new SparkConfWithEnv(Map(ENV_HADOOP_CONF_DIR -> confDir.getAbsolutePath()))
+ val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf)
+
+ val step = new HadoopConfDriverFeatureStep(conf)
+ checkPod(step.configurePod(SparkPod.initialPod()))
+
+ val hadoopConfMap = filter[ConfigMap](step.getAdditionalKubernetesResources()).head
+ assert(hadoopConfMap.getData().keySet().asScala === confFiles)
+ }
+
+ private def checkPod(pod: SparkPod): Unit = {
+ assert(podHasVolume(pod.pod, HADOOP_CONF_VOLUME))
+ assert(containerHasVolume(pod.container, HADOOP_CONF_VOLUME, HADOOP_CONF_DIR_PATH))
+ assert(containerHasEnvVar(pod.container, ENV_HADOOP_CONF_DIR))
+ }
+
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala
new file mode 100644
index 0000000000000..41ca3a94ce7a7
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KerberosConfDriverFeatureStepSuite.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.deploy.k8s.features
+
+import java.io.File
+import java.nio.charset.StandardCharsets.UTF_8
+import java.security.PrivilegedExceptionAction
+
+import scala.collection.JavaConverters._
+
+import com.google.common.io.Files
+import io.fabric8.kubernetes.api.model.{ConfigMap, Secret}
+import org.apache.commons.codec.binary.Base64
+import org.apache.hadoop.io.Text
+import org.apache.hadoop.security.{Credentials, UserGroupInformation}
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.deploy.k8s._
+import org.apache.spark.deploy.k8s.Config._
+import org.apache.spark.deploy.k8s.Constants._
+import org.apache.spark.deploy.k8s.submit.JavaMainAppResource
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
+
+class KerberosConfDriverFeatureStepSuite extends SparkFunSuite {
+
+ import KubernetesFeaturesTestUtils._
+ import SecretVolumeUtils._
+
+ private val tmpDir = Utils.createTempDir()
+
+ test("mount krb5 config map if defined") {
+ val configMap = "testConfigMap"
+ val step = createStep(
+ new SparkConf(false).set(KUBERNETES_KERBEROS_KRB5_CONFIG_MAP, configMap))
+
+ checkPodForKrbConf(step.configurePod(SparkPod.initialPod()), configMap)
+ assert(step.getAdditionalPodSystemProperties().isEmpty)
+ assert(filter[ConfigMap](step.getAdditionalKubernetesResources()).isEmpty)
+ }
+
+ test("create krb5.conf config map if local config provided") {
+ val krbConf = File.createTempFile("krb5", ".conf", tmpDir)
+ Files.write("some data", krbConf, UTF_8)
+
+ val sparkConf = new SparkConf(false)
+ .set(KUBERNETES_KERBEROS_KRB5_FILE, krbConf.getAbsolutePath())
+ val step = createStep(sparkConf)
+
+ val confMap = filter[ConfigMap](step.getAdditionalKubernetesResources()).head
+ assert(confMap.getData().keySet().asScala === Set(krbConf.getName()))
+
+ checkPodForKrbConf(step.configurePod(SparkPod.initialPod()), confMap.getMetadata().getName())
+ assert(step.getAdditionalPodSystemProperties().isEmpty)
+ }
+
+ test("create keytab secret if client keytab file used") {
+ val keytab = File.createTempFile("keytab", ".bin", tmpDir)
+ Files.write("some data", keytab, UTF_8)
+
+ val sparkConf = new SparkConf(false)
+ .set(KEYTAB, keytab.getAbsolutePath())
+ .set(PRINCIPAL, "alice")
+ val step = createStep(sparkConf)
+
+ val pod = step.configurePod(SparkPod.initialPod())
+ assert(podHasVolume(pod.pod, KERBEROS_KEYTAB_VOLUME))
+ assert(containerHasVolume(pod.container, KERBEROS_KEYTAB_VOLUME, KERBEROS_KEYTAB_MOUNT_POINT))
+
+ assert(step.getAdditionalPodSystemProperties().keys === Set(KEYTAB.key))
+
+ val secret = filter[Secret](step.getAdditionalKubernetesResources()).head
+ assert(secret.getData().keySet().asScala === Set(keytab.getName()))
+ }
+
+ test("do nothing if container-local keytab used") {
+ val sparkConf = new SparkConf(false)
+ .set(KEYTAB, "local:/my.keytab")
+ .set(PRINCIPAL, "alice")
+ val step = createStep(sparkConf)
+
+ val initial = SparkPod.initialPod()
+ assert(step.configurePod(initial) === initial)
+ assert(step.getAdditionalPodSystemProperties().isEmpty)
+ assert(step.getAdditionalKubernetesResources().isEmpty)
+ }
+
+ test("mount delegation tokens if provided") {
+ val dtSecret = "tokenSecret"
+ val sparkConf = new SparkConf(false)
+ .set(KUBERNETES_KERBEROS_DT_SECRET_NAME, dtSecret)
+ .set(KUBERNETES_KERBEROS_DT_SECRET_ITEM_KEY, "dtokens")
+ val step = createStep(sparkConf)
+
+ checkPodForTokens(step.configurePod(SparkPod.initialPod()), dtSecret)
+ assert(step.getAdditionalPodSystemProperties().isEmpty)
+ assert(step.getAdditionalKubernetesResources().isEmpty)
+ }
+
+ test("create delegation tokens if needed") {
+ // Since HadoopDelegationTokenManager does not create any tokens without proper configs and
+ // services, start with a test user that already has some tokens that will just be piped
+ // through to the driver.
+ val testUser = UserGroupInformation.createUserForTesting("k8s", Array())
+ testUser.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ val creds = testUser.getCredentials()
+ creds.addSecretKey(new Text("K8S_TEST_KEY"), Array[Byte](0x4, 0x2))
+ testUser.addCredentials(creds)
+
+ val tokens = SparkHadoopUtil.get.serialize(creds)
+
+ val step = createStep(new SparkConf(false))
+
+ val dtSecret = filter[Secret](step.getAdditionalKubernetesResources()).head
+ assert(dtSecret.getData().get(KERBEROS_SECRET_KEY) === Base64.encodeBase64String(tokens))
+
+ checkPodForTokens(step.configurePod(SparkPod.initialPod()),
+ dtSecret.getMetadata().getName())
+
+ assert(step.getAdditionalPodSystemProperties().isEmpty)
+ }
+ })
+ }
+
+ test("do nothing if no config and no tokens") {
+ val step = createStep(new SparkConf(false))
+ val initial = SparkPod.initialPod()
+ assert(step.configurePod(initial) === initial)
+ assert(step.getAdditionalPodSystemProperties().isEmpty)
+ assert(step.getAdditionalKubernetesResources().isEmpty)
+ }
+
+ private def checkPodForKrbConf(pod: SparkPod, confMapName: String): Unit = {
+ val podVolume = pod.pod.getSpec().getVolumes().asScala.find(_.getName() == KRB_FILE_VOLUME)
+ assert(podVolume.isDefined)
+ assert(containerHasVolume(pod.container, KRB_FILE_VOLUME, KRB_FILE_DIR_PATH + "/krb5.conf"))
+ assert(podVolume.get.getConfigMap().getName() === confMapName)
+ }
+
+ private def checkPodForTokens(pod: SparkPod, dtSecretName: String): Unit = {
+ val podVolume = pod.pod.getSpec().getVolumes().asScala
+ .find(_.getName() == SPARK_APP_HADOOP_SECRET_VOLUME_NAME)
+ assert(podVolume.isDefined)
+ assert(containerHasVolume(pod.container, SPARK_APP_HADOOP_SECRET_VOLUME_NAME,
+ SPARK_APP_HADOOP_CREDENTIALS_BASE_DIR))
+ assert(containerHasEnvVar(pod.container, ENV_HADOOP_TOKEN_FILE_LOCATION))
+ assert(podVolume.get.getSecret().getSecretName() === dtSecretName)
+ }
+
+ private def createStep(conf: SparkConf): KerberosConfDriverFeatureStep = {
+ val kconf = KubernetesTestConf.createDriverConf(sparkConf = conf)
+ new KerberosConfDriverFeatureStep(kconf)
+ }
+
+}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala
index f90380e30e52a..076b681be2397 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy.k8s.features
import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
import io.fabric8.kubernetes.api.model.{Container, HasMetadata, PodBuilder, SecretBuilder}
import org.mockito.Matchers
@@ -63,4 +64,9 @@ object KubernetesFeaturesTestUtils {
def containerHasEnvVar(container: Container, envVarName: String): Boolean = {
container.getEnv.asScala.exists(envVar => envVar.getName == envVarName)
}
+
+ def filter[T: ClassTag](list: Seq[HasMetadata]): Seq[T] = {
+ val desired = implicitly[ClassTag[T]].runtimeClass
+ list.filter(_.getClass() == desired).map(_.asInstanceOf[T]).toSeq
+ }
}
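
Editor's sketch (not part of the patch): the filter[T: ClassTag] helper added above keeps only the resources whose runtime class is exactly T and returns them already cast. A self-contained illustration of the same ClassTag technique applied to plain values:

    import scala.reflect.ClassTag

    object ClassTagFilterSketch {
      def filter[T: ClassTag](list: Seq[Any]): Seq[T] = {
        // Compare against the runtime class captured by the ClassTag, then cast.
        val desired = implicitly[ClassTag[T]].runtimeClass
        list.filter(_.getClass == desired).map(_.asInstanceOf[T])
      }

      def main(args: Array[String]): Unit = {
        val mixed: Seq[Any] = Seq("config-map", 42, "secret")
        // Picks out only the Strings, already typed as String.
        assert(filter[String](mixed) == Seq("config-map", "secret"))
      }
    }
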
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala
index 7295b82ca4799..5e7388dc8e672 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/PodTemplateConfigMapStepSuite.scala
@@ -20,25 +20,32 @@ import java.io.{File, PrintWriter}
import java.nio.file.Files
import io.fabric8.kubernetes.api.model.ConfigMap
-import org.scalatest.BeforeAndAfter
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.deploy.k8s._
-class PodTemplateConfigMapStepSuite extends SparkFunSuite with BeforeAndAfter {
- private var kubernetesConf : KubernetesConf = _
- private var templateFile: File = _
+class PodTemplateConfigMapStepSuite extends SparkFunSuite {
- before {
- templateFile = Files.createTempFile("pod-template", "yml").toFile
+ test("Do nothing when executor template is not specified") {
+ val conf = KubernetesTestConf.createDriverConf()
+ val step = new PodTemplateConfigMapStep(conf)
+
+ val initialPod = SparkPod.initialPod()
+ val configuredPod = step.configurePod(initialPod)
+ assert(configuredPod === initialPod)
+
+ assert(step.getAdditionalKubernetesResources().isEmpty)
+ assert(step.getAdditionalPodSystemProperties().isEmpty)
+ }
+
+ test("Mounts executor template volume if config specified") {
+ val templateFile = Files.createTempFile("pod-template", "yml").toFile
templateFile.deleteOnExit()
val sparkConf = new SparkConf(false)
.set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, templateFile.getAbsolutePath)
- kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf)
- }
+ val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf)
- test("Mounts executor template volume if config specified") {
val writer = new PrintWriter(templateFile)
writer.write("pod-template-contents")
writer.close()
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
index e9c05fef6f5db..1bb926cbca23d 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala
@@ -126,7 +126,7 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter {
MockitoAnnotations.initMocks(this)
kconf = KubernetesTestConf.createDriverConf(
resourceNamePrefix = Some(KUBERNETES_RESOURCE_PREFIX))
- when(driverBuilder.buildFromFeatures(kconf)).thenReturn(BUILT_KUBERNETES_SPEC)
+ when(driverBuilder.buildFromFeatures(kconf, kubernetesClient)).thenReturn(BUILT_KUBERNETES_SPEC)
when(kubernetesClient.pods()).thenReturn(podOperations)
when(podOperations.withName(POD_NAME)).thenReturn(namedPods)
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
index 7e7dc4763c2e7..6518c91a1a1fd 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala
@@ -16,201 +16,21 @@
*/
package org.apache.spark.deploy.k8s.submit
-import io.fabric8.kubernetes.api.model.PodBuilder
import io.fabric8.kubernetes.client.KubernetesClient
-import org.mockito.Mockito._
-import org.apache.spark.{SparkConf, SparkException, SparkFunSuite}
+import org.apache.spark.SparkConf
import org.apache.spark.deploy.k8s._
-import org.apache.spark.deploy.k8s.Config.{CONTAINER_IMAGE, KUBERNETES_DRIVER_PODTEMPLATE_FILE, KUBERNETES_EXECUTOR_PODTEMPLATE_FILE}
-import org.apache.spark.deploy.k8s.features._
+import org.apache.spark.internal.config.ConfigEntry
-class KubernetesDriverBuilderSuite extends SparkFunSuite {
+class KubernetesDriverBuilderSuite extends PodBuilderSuite {
- private val BASIC_STEP_TYPE = "basic"
- private val CREDENTIALS_STEP_TYPE = "credentials"
- private val SERVICE_STEP_TYPE = "service"
- private val LOCAL_DIRS_STEP_TYPE = "local-dirs"
- private val SECRETS_STEP_TYPE = "mount-secrets"
- private val DRIVER_CMD_STEP_TYPE = "driver-command"
- private val ENV_SECRETS_STEP_TYPE = "env-secrets"
- private val HADOOP_GLOBAL_STEP_TYPE = "hadoop-global"
- private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes"
- private val TEMPLATE_VOLUME_STEP_TYPE = "template-volume"
-
- private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep])
-
- private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep])
-
- private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep])
-
- private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep])
-
- private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep])
-
- private val driverCommandStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- DRIVER_CMD_STEP_TYPE, classOf[DriverCommandFeatureStep])
-
- private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep])
-
- private val hadoopGlobalStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- HADOOP_GLOBAL_STEP_TYPE, classOf[KerberosConfDriverFeatureStep])
-
- private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep])
-
- private val templateVolumeStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- TEMPLATE_VOLUME_STEP_TYPE, classOf[PodTemplateConfigMapStep]
- )
-
- private val builderUnderTest: KubernetesDriverBuilder =
- new KubernetesDriverBuilder(
- _ => basicFeatureStep,
- _ => credentialsStep,
- _ => serviceStep,
- _ => secretsStep,
- _ => envSecretsStep,
- _ => localDirsStep,
- _ => mountVolumesStep,
- _ => driverCommandStep,
- _ => hadoopGlobalStep,
- _ => templateVolumeStep)
-
- test("Apply fundamental steps all the time.") {
- val conf = KubernetesTestConf.createDriverConf()
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- CREDENTIALS_STEP_TYPE,
- SERVICE_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- DRIVER_CMD_STEP_TYPE,
- HADOOP_GLOBAL_STEP_TYPE)
+ override protected def templateFileConf: ConfigEntry[_] = {
+ Config.KUBERNETES_DRIVER_PODTEMPLATE_FILE
}
- test("Apply secrets step if secrets are present.") {
- val conf = KubernetesTestConf.createDriverConf(
- secretEnvNamesToKeyRefs = Map("EnvName" -> "SecretName:secretKey"),
- secretNamesToMountPaths = Map("secret" -> "secretMountPath"))
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- CREDENTIALS_STEP_TYPE,
- SERVICE_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- SECRETS_STEP_TYPE,
- ENV_SECRETS_STEP_TYPE,
- DRIVER_CMD_STEP_TYPE,
- HADOOP_GLOBAL_STEP_TYPE)
- }
-
- test("Apply volumes step if mounts are present.") {
- val volumeSpec = KubernetesVolumeSpec(
- "volume",
- "/tmp",
- "",
- false,
- KubernetesHostPathVolumeConf("/path"))
- val conf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeSpec))
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- CREDENTIALS_STEP_TYPE,
- SERVICE_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- MOUNT_VOLUMES_STEP_TYPE,
- DRIVER_CMD_STEP_TYPE,
- HADOOP_GLOBAL_STEP_TYPE)
- }
-
- test("Apply volumes step if a mount subpath is present.") {
- val volumeSpec = KubernetesVolumeSpec(
- "volume",
- "/tmp",
- "foo",
- false,
- KubernetesHostPathVolumeConf("/path"))
- val conf = KubernetesTestConf.createDriverConf(volumes = Seq(volumeSpec))
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- CREDENTIALS_STEP_TYPE,
- SERVICE_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- MOUNT_VOLUMES_STEP_TYPE,
- DRIVER_CMD_STEP_TYPE,
- HADOOP_GLOBAL_STEP_TYPE)
- }
-
- test("Apply template volume step if executor template is present.") {
- val sparkConf = new SparkConf(false)
- .set(KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "filename")
+ override protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod = {
val conf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf)
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- CREDENTIALS_STEP_TYPE,
- SERVICE_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- DRIVER_CMD_STEP_TYPE,
- HADOOP_GLOBAL_STEP_TYPE,
- TEMPLATE_VOLUME_STEP_TYPE)
- }
-
- private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*)
- : Unit = {
- val addedProperties = resolvedSpec.systemProperties
- .filter { case (k, _) => !k.startsWith("spark.") }
- .toMap
- assert(addedProperties.keys.toSet === stepTypes.toSet)
- stepTypes.foreach { stepType =>
- assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType)
- assert(resolvedSpec.driverKubernetesResources.containsSlice(
- KubernetesFeaturesTestUtils.getSecretsForStepType(stepType)))
- assert(resolvedSpec.systemProperties(stepType) === stepType)
- }
- }
-
- test("Start with empty pod if template is not specified") {
- val kubernetesClient = mock(classOf[KubernetesClient])
- val driverBuilder = KubernetesDriverBuilder.apply(kubernetesClient, new SparkConf())
- verify(kubernetesClient, never()).pods()
+ new KubernetesDriverBuilder().buildFromFeatures(conf, client).pod
}
- test("Starts with template if specified") {
- val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient()
- val sparkConf = new SparkConf(false)
- .set(CONTAINER_IMAGE, "spark-driver:latest")
- .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml")
- val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf)
- val driverSpec = KubernetesDriverBuilder
- .apply(kubernetesClient, sparkConf)
- .buildFromFeatures(kubernetesConf)
- PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(driverSpec.pod)
- }
-
- test("Throws on misconfigured pod template") {
- val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient(
- new PodBuilder()
- .withNewMetadata()
- .addToLabels("test-label-key", "test-label-value")
- .endMetadata()
- .build())
- val sparkConf = new SparkConf(false)
- .set(CONTAINER_IMAGE, "spark-driver:latest")
- .set(KUBERNETES_DRIVER_PODTEMPLATE_FILE, "template-file.yaml")
- val kubernetesConf = KubernetesTestConf.createDriverConf(sparkConf = sparkConf)
- val exception = intercept[SparkException] {
- KubernetesDriverBuilder
- .apply(kubernetesClient, sparkConf)
- .buildFromFeatures(kubernetesConf)
- }
- assert(exception.getMessage.contains("Could not load pod from template file."))
- }
}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala
deleted file mode 100644
index c92e9e6e3b6b3..0000000000000
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/PodBuilderSuiteUtils.scala
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.deploy.k8s.submit
-
-import java.io.File
-
-import io.fabric8.kubernetes.api.model._
-import io.fabric8.kubernetes.client.KubernetesClient
-import io.fabric8.kubernetes.client.dsl.{MixedOperation, PodResource}
-import org.mockito.Matchers.any
-import org.mockito.Mockito.{mock, when}
-import org.scalatest.FlatSpec
-import scala.collection.JavaConverters._
-
-import org.apache.spark.deploy.k8s.SparkPod
-
-object PodBuilderSuiteUtils extends FlatSpec {
-
- def loadingMockKubernetesClient(pod: Pod = podWithSupportedFeatures()): KubernetesClient = {
- val kubernetesClient = mock(classOf[KubernetesClient])
- val pods =
- mock(classOf[MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]])
- val podResource = mock(classOf[PodResource[Pod, DoneablePod]])
- when(kubernetesClient.pods()).thenReturn(pods)
- when(pods.load(any(classOf[File]))).thenReturn(podResource)
- when(podResource.get()).thenReturn(pod)
- kubernetesClient
- }
-
- def verifyPodWithSupportedFeatures(pod: SparkPod): Unit = {
- val metadata = pod.pod.getMetadata
- assert(metadata.getLabels.containsKey("test-label-key"))
- assert(metadata.getAnnotations.containsKey("test-annotation-key"))
- assert(metadata.getNamespace === "namespace")
- assert(metadata.getOwnerReferences.asScala.exists(_.getName == "owner-reference"))
- val spec = pod.pod.getSpec
- assert(!spec.getContainers.asScala.exists(_.getName == "executor-container"))
- assert(spec.getDnsPolicy === "dns-policy")
- assert(spec.getHostAliases.asScala.exists(_.getHostnames.asScala.exists(_ == "hostname")))
- assert(spec.getImagePullSecrets.asScala.exists(_.getName == "local-reference"))
- assert(spec.getInitContainers.asScala.exists(_.getName == "init-container"))
- assert(spec.getNodeName == "node-name")
- assert(spec.getNodeSelector.get("node-selector-key") === "node-selector-value")
- assert(spec.getSchedulerName === "scheduler")
- assert(spec.getSecurityContext.getRunAsUser === 1000L)
- assert(spec.getServiceAccount === "service-account")
- assert(spec.getSubdomain === "subdomain")
- assert(spec.getTolerations.asScala.exists(_.getKey == "toleration-key"))
- assert(spec.getVolumes.asScala.exists(_.getName == "test-volume"))
- val container = pod.container
- assert(container.getName === "executor-container")
- assert(container.getArgs.contains("arg"))
- assert(container.getCommand.equals(List("command").asJava))
- assert(container.getEnv.asScala.exists(_.getName == "env-key"))
- assert(container.getResources.getLimits.get("gpu") ===
- new QuantityBuilder().withAmount("1").build())
- assert(container.getSecurityContext.getRunAsNonRoot)
- assert(container.getStdin)
- assert(container.getTerminationMessagePath === "termination-message-path")
- assert(container.getTerminationMessagePolicy === "termination-message-policy")
- assert(pod.container.getVolumeMounts.asScala.exists(_.getName == "test-volume"))
-
- }
-
-
- def podWithSupportedFeatures(): Pod = new PodBuilder()
- .withNewMetadata()
- .addToLabels("test-label-key", "test-label-value")
- .addToAnnotations("test-annotation-key", "test-annotation-value")
- .withNamespace("namespace")
- .addNewOwnerReference()
- .withController(true)
- .withName("owner-reference")
- .endOwnerReference()
- .endMetadata()
- .withNewSpec()
- .withDnsPolicy("dns-policy")
- .withHostAliases(new HostAliasBuilder().withHostnames("hostname").build())
- .withImagePullSecrets(
- new LocalObjectReferenceBuilder().withName("local-reference").build())
- .withInitContainers(new ContainerBuilder().withName("init-container").build())
- .withNodeName("node-name")
- .withNodeSelector(Map("node-selector-key" -> "node-selector-value").asJava)
- .withSchedulerName("scheduler")
- .withNewSecurityContext()
- .withRunAsUser(1000L)
- .endSecurityContext()
- .withServiceAccount("service-account")
- .withSubdomain("subdomain")
- .withTolerations(new TolerationBuilder()
- .withKey("toleration-key")
- .withOperator("Equal")
- .withEffect("NoSchedule")
- .build())
- .addNewVolume()
- .withNewHostPath()
- .withPath("/test")
- .endHostPath()
- .withName("test-volume")
- .endVolume()
- .addNewContainer()
- .withArgs("arg")
- .withCommand("command")
- .addNewEnv()
- .withName("env-key")
- .withValue("env-value")
- .endEnv()
- .withImagePullPolicy("Always")
- .withName("executor-container")
- .withNewResources()
- .withLimits(Map("gpu" -> new QuantityBuilder().withAmount("1").build()).asJava)
- .endResources()
- .withNewSecurityContext()
- .withRunAsNonRoot(true)
- .endSecurityContext()
- .withStdin(true)
- .withTerminationMessagePath("termination-message-path")
- .withTerminationMessagePolicy("termination-message-policy")
- .addToVolumeMounts(
- new VolumeMountBuilder()
- .withName("test-volume")
- .withMountPath("/test")
- .build())
- .endContainer()
- .endSpec()
- .build()
-
-}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala
index 303e24b8f4977..278a3821a6f3d 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodsAllocatorSuite.scala
@@ -20,13 +20,13 @@ import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder}
import io.fabric8.kubernetes.client.KubernetesClient
import io.fabric8.kubernetes.client.dsl.PodResource
import org.mockito.{ArgumentMatcher, Matchers, Mock, MockitoAnnotations}
-import org.mockito.Matchers.any
+import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito.{never, times, verify, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
import org.scalatest.BeforeAndAfter
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, KubernetesTestConf, SparkPod}
import org.apache.spark.deploy.k8s.Config._
import org.apache.spark.deploy.k8s.Constants._
@@ -52,6 +52,7 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter {
private val podAllocationSize = conf.get(KUBERNETES_ALLOCATION_BATCH_SIZE)
private val podAllocationDelay = conf.get(KUBERNETES_ALLOCATION_BATCH_DELAY)
private val podCreationTimeout = math.max(podAllocationDelay * 5, 60000L)
+ private val secMgr = new SecurityManager(conf)
private var waitForExecutorPodsClock: ManualClock = _
@@ -79,12 +80,12 @@ class ExecutorPodsAllocatorSuite extends SparkFunSuite with BeforeAndAfter {
when(kubernetesClient.pods()).thenReturn(podOperations)
when(podOperations.withName(driverPodName)).thenReturn(driverPodOperations)
when(driverPodOperations.get).thenReturn(driverPod)
- when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf])))
- .thenAnswer(executorPodAnswer())
+ when(executorBuilder.buildFromFeatures(any(classOf[KubernetesExecutorConf]), meq(secMgr),
+ meq(kubernetesClient))).thenAnswer(executorPodAnswer())
snapshotsStore = new DeterministicExecutorPodsSnapshotsStore()
waitForExecutorPodsClock = new ManualClock(0L)
podsAllocatorUnderTest = new ExecutorPodsAllocator(
- conf, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock)
+ conf, secMgr, executorBuilder, kubernetesClient, snapshotsStore, waitForExecutorPodsClock)
podsAllocatorUnderTest.start(TEST_SPARK_APP_ID)
}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
index 52e7a12dbaf06..6e182bed459f8 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala
@@ -23,7 +23,7 @@ import org.mockito.Matchers.{eq => mockitoEq}
import org.mockito.Mockito.{never, verify, when}
import org.scalatest.BeforeAndAfter
-import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite}
import org.apache.spark.deploy.k8s.Constants._
import org.apache.spark.deploy.k8s.Fabric8Aliases._
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
@@ -37,10 +37,14 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn
private val requestExecutorsService = new DeterministicScheduler()
private val sparkConf = new SparkConf(false)
.set("spark.executor.instances", "3")
+ .set("spark.app.id", TEST_SPARK_APP_ID)
@Mock
private var sc: SparkContext = _
+ @Mock
+ private var env: SparkEnv = _
+
@Mock
private var rpcEnv: RpcEnv = _
@@ -81,23 +85,25 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn
MockitoAnnotations.initMocks(this)
when(taskScheduler.sc).thenReturn(sc)
when(sc.conf).thenReturn(sparkConf)
+ when(sc.env).thenReturn(env)
+ when(env.rpcEnv).thenReturn(rpcEnv)
driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint])
- when(rpcEnv.setupEndpoint(
- mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture()))
+ when(
+ rpcEnv.setupEndpoint(
+ mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME),
+ driverEndpoint.capture()))
.thenReturn(driverEndpointRef)
when(kubernetesClient.pods()).thenReturn(podOperations)
schedulerBackendUnderTest = new KubernetesClusterSchedulerBackend(
taskScheduler,
- rpcEnv,
+ sc,
kubernetesClient,
requestExecutorsService,
eventQueue,
podAllocator,
lifecycleEventHandler,
watchEvents,
- pollEvents) {
- override def applicationId(): String = TEST_SPARK_APP_ID
- }
+ pollEvents)
}
test("Start all components") {
@@ -122,8 +128,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn
test("Remove executor") {
schedulerBackendUnderTest.start()
- schedulerBackendUnderTest.doRemoveExecutor(
- "1", ExecutorKilled)
+ schedulerBackendUnderTest.doRemoveExecutor("1", ExecutorKilled)
verify(driverEndpointRef).send(RemoveExecutor("1", ExecutorKilled))
}
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
index b6a75b15af85a..bd716174a8271 100644
--- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala
@@ -16,145 +16,23 @@
*/
package org.apache.spark.scheduler.cluster.k8s
-import scala.collection.JavaConverters._
-
-import io.fabric8.kubernetes.api.model.{Config => _, _}
import io.fabric8.kubernetes.client.KubernetesClient
-import org.mockito.Mockito.{mock, never, verify}
-import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.k8s._
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.features._
-import org.apache.spark.deploy.k8s.submit.PodBuilderSuiteUtils
-import org.apache.spark.util.SparkConfWithEnv
-
-class KubernetesExecutorBuilderSuite extends SparkFunSuite {
- private val BASIC_STEP_TYPE = "basic"
- private val SECRETS_STEP_TYPE = "mount-secrets"
- private val ENV_SECRETS_STEP_TYPE = "env-secrets"
- private val LOCAL_DIRS_STEP_TYPE = "local-dirs"
- private val HADOOP_CONF_STEP_TYPE = "hadoop-conf-step"
- private val HADOOP_SPARK_USER_STEP_TYPE = "hadoop-spark-user"
- private val KERBEROS_CONF_STEP_TYPE = "kerberos-step"
- private val MOUNT_VOLUMES_STEP_TYPE = "mount-volumes"
-
- private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep])
- private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep])
- private val envSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- ENV_SECRETS_STEP_TYPE, classOf[EnvSecretsFeatureStep])
- private val localDirsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- LOCAL_DIRS_STEP_TYPE, classOf[LocalDirsFeatureStep])
- private val hadoopConfStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- HADOOP_CONF_STEP_TYPE, classOf[HadoopConfExecutorFeatureStep])
- private val hadoopSparkUser = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- HADOOP_SPARK_USER_STEP_TYPE, classOf[HadoopSparkUserExecutorFeatureStep])
- private val kerberosConf = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- KERBEROS_CONF_STEP_TYPE, classOf[KerberosConfExecutorFeatureStep])
- private val mountVolumesStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType(
- MOUNT_VOLUMES_STEP_TYPE, classOf[MountVolumesFeatureStep])
-
- private val builderUnderTest = new KubernetesExecutorBuilder(
- _ => basicFeatureStep,
- _ => mountSecretsStep,
- _ => envSecretsStep,
- _ => localDirsStep,
- _ => mountVolumesStep,
- _ => hadoopConfStep,
- _ => kerberosConf,
- _ => hadoopSparkUser)
-
- test("Basic steps are consistently applied.") {
- val conf = KubernetesTestConf.createExecutorConf()
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE, LOCAL_DIRS_STEP_TYPE)
- }
-
- test("Apply secrets step if secrets are present.") {
- val conf = KubernetesTestConf.createExecutorConf(
- secretEnvNamesToKeyRefs = Map("secret-name" -> "secret-key"),
- secretNamesToMountPaths = Map("secret" -> "secretMountPath"))
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- SECRETS_STEP_TYPE,
- ENV_SECRETS_STEP_TYPE)
- }
-
- test("Apply volumes step if mounts are present.") {
- val volumeSpec = KubernetesVolumeSpec(
- "volume",
- "/tmp",
- "",
- false,
- KubernetesHostPathVolumeConf("/checkpoint"))
- val conf = KubernetesTestConf.createExecutorConf(
- volumes = Seq(volumeSpec))
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- MOUNT_VOLUMES_STEP_TYPE)
- }
+import org.apache.spark.internal.config.ConfigEntry
- test("Apply basicHadoop step if HADOOP_CONF_DIR is defined") {
- // HADOOP_DELEGATION_TOKEN
- val conf = KubernetesTestConf.createExecutorConf(
- sparkConf = new SparkConfWithEnv(Map("HADOOP_CONF_DIR" -> "/var/hadoop-conf"))
- .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name")
- .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name"))
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- HADOOP_CONF_STEP_TYPE,
- HADOOP_SPARK_USER_STEP_TYPE)
- }
+class KubernetesExecutorBuilderSuite extends PodBuilderSuite {
- test("Apply kerberos step if DT secrets created") {
- val conf = KubernetesTestConf.createExecutorConf(
- sparkConf = new SparkConf(false)
- .set(HADOOP_CONFIG_MAP_NAME, "hadoop-conf-map-name")
- .set(KRB5_CONFIG_MAP_NAME, "krb5-conf-map-name")
- .set(KERBEROS_SPARK_USER_NAME, "spark-user")
- .set(KERBEROS_DT_SECRET_NAME, "dt-secret")
- .set(KERBEROS_DT_SECRET_KEY, "dt-key" ))
- validateStepTypesApplied(
- builderUnderTest.buildFromFeatures(conf),
- BASIC_STEP_TYPE,
- LOCAL_DIRS_STEP_TYPE,
- HADOOP_CONF_STEP_TYPE,
- KERBEROS_CONF_STEP_TYPE)
+ override protected def templateFileConf: ConfigEntry[_] = {
+ Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE
}
- private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = {
- assert(resolvedPod.pod.getMetadata.getLabels.asScala.keys.toSet === stepTypes.toSet)
+ override protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod = {
+ sparkConf.set("spark.driver.host", "https://driver.host.com")
+ val conf = KubernetesTestConf.createExecutorConf(sparkConf = sparkConf)
+ val secMgr = new SecurityManager(sparkConf)
+ new KubernetesExecutorBuilder().buildFromFeatures(conf, secMgr, client)
}
- test("Starts with empty executor pod if template is not specified") {
- val kubernetesClient = mock(classOf[KubernetesClient])
- val executorBuilder = KubernetesExecutorBuilder.apply(kubernetesClient, new SparkConf())
- verify(kubernetesClient, never()).pods()
- }
-
- test("Starts with executor template if specified") {
- val kubernetesClient = PodBuilderSuiteUtils.loadingMockKubernetesClient()
- val sparkConf = new SparkConf(false)
- .set("spark.driver.host", "https://driver.host.com")
- .set(Config.CONTAINER_IMAGE, "spark-executor:latest")
- .set(Config.KUBERNETES_EXECUTOR_PODTEMPLATE_FILE, "template-file.yaml")
- val kubernetesConf = KubernetesTestConf.createExecutorConf(
- sparkConf = sparkConf,
- driverPod = Some(new PodBuilder()
- .withNewMetadata()
- .withName("driver")
- .endMetadata()
- .build()))
- val sparkPod = KubernetesExecutorBuilder(kubernetesClient, sparkConf)
- .buildFromFeatures(kubernetesConf)
- PodBuilderSuiteUtils.verifyPodWithSupportedFeatures(sparkPod)
- }
}
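Both builder suites now defer their template-handling tests to a shared `PodBuilderSuite` base class added elsewhere in this patch; judging only from the two overrides above, its contract presumably looks roughly like this:

```scala
// Sketch of the shared contract, inferred from the overrides in the two suites; the actual
// base class added in this patch may differ in details.
abstract class PodBuilderSuite extends SparkFunSuite {
  protected def templateFileConf: ConfigEntry[_]
  protected def buildPod(sparkConf: SparkConf, client: KubernetesClient): SparkPod
  // The common "empty pod", "load template" and "misconfigured template" tests are then
  // written once against these two hooks.
}
```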
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
index b746a01eb5294..f8f4b4177f3bd 100644
--- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/k8s/integrationtest/KubernetesSuite.scala
@@ -36,6 +36,7 @@ import org.apache.spark.{SPARK_VERSION, SparkFunSuite}
import org.apache.spark.deploy.k8s.integrationtest.TestConstants._
import org.apache.spark.deploy.k8s.integrationtest.backend.{IntegrationTestBackend, IntegrationTestBackendFactory}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
class KubernetesSuite extends SparkFunSuite
with BeforeAndAfterAll with BeforeAndAfter with BasicTestsSuite with SecretsTestsSuite
@@ -138,6 +139,7 @@ class KubernetesSuite extends SparkFunSuite
.set("spark.kubernetes.driver.pod.name", driverPodName)
.set("spark.kubernetes.driver.label.spark-app-locator", appLocator)
.set("spark.kubernetes.executor.label.spark-app-locator", appLocator)
+ .set(NETWORK_AUTH_ENABLED.key, "true")
if (!kubernetesTestComponents.hasUserSpecifiedNamespace) {
kubernetesTestComponents.createNamespace()
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index c1f3211bcab29..e46c4f970c4a3 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -449,7 +449,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends
val ms = MetricsSystem.createMetricsSystem("applicationMaster", sparkConf, securityMgr)
val prefix = _sparkConf.get(YARN_METRICS_NAMESPACE).getOrElse(appId)
ms.registerSource(new ApplicationMasterSource(prefix, allocator))
- ms.start()
+ // Do not register static sources in this case, as per SPARK-25277.
+ ms.start(false)
metricsSystem = Some(ms)
reporterThread = launchReporterThread()
}
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 6240f7b68d2c8..184fb6a8ad13e 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -116,6 +116,8 @@ private[spark] class Client(
}
}
+ require(keytab == null || !Utils.isLocalUri(keytab), "Keytab should reference a local file.")
+
private val launcherBackend = new LauncherBackend() {
override protected def conf: SparkConf = sparkConf
@@ -472,7 +474,7 @@ private[spark] class Client(
appMasterOnly: Boolean = false): (Boolean, String) = {
val trimmedPath = path.trim()
val localURI = Utils.resolveURI(trimmedPath)
- if (localURI.getScheme != LOCAL_SCHEME) {
+ if (localURI.getScheme != Utils.LOCAL_SCHEME) {
if (addDistributedUri(localURI)) {
val localPath = getQualifiedLocalPath(localURI, hadoopConf)
val linkname = targetDir.map(_ + "/").getOrElse("") +
@@ -515,7 +517,7 @@ private[spark] class Client(
val sparkArchive = sparkConf.get(SPARK_ARCHIVE)
if (sparkArchive.isDefined) {
val archive = sparkArchive.get
- require(!isLocalUri(archive), s"${SPARK_ARCHIVE.key} cannot be a local URI.")
+ require(!Utils.isLocalUri(archive), s"${SPARK_ARCHIVE.key} cannot be a local URI.")
distribute(Utils.resolveURI(archive).toString,
resType = LocalResourceType.ARCHIVE,
destName = Some(LOCALIZED_LIB_DIR))
@@ -525,7 +527,7 @@ private[spark] class Client(
// Break the list of jars to upload, and resolve globs.
val localJars = new ArrayBuffer[String]()
jars.foreach { jar =>
- if (!isLocalUri(jar)) {
+ if (!Utils.isLocalUri(jar)) {
val path = getQualifiedLocalPath(Utils.resolveURI(jar), hadoopConf)
val pathFs = FileSystem.get(path.toUri(), hadoopConf)
pathFs.globStatus(path).filter(_.isFile()).foreach { entry =>
@@ -814,7 +816,7 @@ private[spark] class Client(
}
(pySparkArchives ++ pyArchives).foreach { path =>
val uri = Utils.resolveURI(path)
- if (uri.getScheme != LOCAL_SCHEME) {
+ if (uri.getScheme != Utils.LOCAL_SCHEME) {
pythonPath += buildPath(Environment.PWD.$$(), new Path(uri).getName())
} else {
pythonPath += uri.getPath()
@@ -1183,9 +1185,6 @@ private object Client extends Logging {
// Alias for the user jar
val APP_JAR_NAME: String = "__app__.jar"
- // URI scheme that identifies local resources
- val LOCAL_SCHEME = "local"
-
// Staging directory for any temporary jars or files
val SPARK_STAGING: String = ".sparkStaging"
@@ -1307,7 +1306,7 @@ private object Client extends Logging {
addClasspathEntry(buildPath(Environment.PWD.$$(), LOCALIZED_LIB_DIR, "*"), env)
if (sparkConf.get(SPARK_ARCHIVE).isEmpty) {
sparkConf.get(SPARK_JARS).foreach { jars =>
- jars.filter(isLocalUri).foreach { jar =>
+ jars.filter(Utils.isLocalUri).foreach { jar =>
val uri = new URI(jar)
addClasspathEntry(getClusterPath(sparkConf, uri.getPath()), env)
}
@@ -1340,7 +1339,7 @@ private object Client extends Logging {
private def getMainJarUri(mainJar: Option[String]): Option[URI] = {
mainJar.flatMap { path =>
val uri = Utils.resolveURI(path)
- if (uri.getScheme == LOCAL_SCHEME) Some(uri) else None
+ if (uri.getScheme == Utils.LOCAL_SCHEME) Some(uri) else None
}.orElse(Some(new URI(APP_JAR_NAME)))
}
@@ -1368,7 +1367,7 @@ private object Client extends Logging {
uri: URI,
fileName: String,
env: HashMap[String, String]): Unit = {
- if (uri != null && uri.getScheme == LOCAL_SCHEME) {
+ if (uri != null && uri.getScheme == Utils.LOCAL_SCHEME) {
addClasspathEntry(getClusterPath(conf, uri.getPath), env)
} else if (fileName != null) {
addClasspathEntry(buildPath(Environment.PWD.$$(), fileName), env)
@@ -1489,11 +1488,6 @@ private object Client extends Logging {
components.mkString(Path.SEPARATOR)
}
- /** Returns whether the URI is a "local:" URI. */
- def isLocalUri(uri: String): Boolean = {
- uri.startsWith(s"$LOCAL_SCHEME:")
- }
-
def createAppReport(report: ApplicationReport): YarnAppReport = {
val diags = report.getDiagnostics()
val diagsOpt = if (diags != null && diags.nonEmpty) Some(diags) else None
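The `LOCAL_SCHEME` constant and `isLocalUri` helper removed above move into `org.apache.spark.util.Utils`; inferred from the deleted lines alone, the relocated code presumably looks like this (its exact placement inside Utils is not shown in this patch):

```scala
// URI scheme that identifies local resources.
val LOCAL_SCHEME = "local"

// Returns whether the URI is a "local:" URI.
def isLocalUri(uri: String): Boolean = {
  uri.startsWith(s"$LOCAL_SCHEME:")
}
```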
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 9497530805c1a..54b1ec266113f 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -127,7 +127,7 @@ private[yarn] class YarnAllocator(
private var numUnexpectedContainerRelease = 0L
private val containerIdToExecutorId = new HashMap[ContainerId, String]
- // Executor memory in MB.
+ // Executor memory in MiB.
protected val executorMemory = sparkConf.get(EXECUTOR_MEMORY).toInt
// Additional memory overhead.
protected val memoryOverhead: Int = sparkConf.get(EXECUTOR_MEMORY_OVERHEAD).getOrElse(
@@ -294,6 +294,15 @@ private[yarn] class YarnAllocator(
s"pending: $numPendingAllocate, running: ${runningExecutors.size}, " +
s"executorsStarting: ${numExecutorsStarting.get}")
+ // Split the pending container request into three groups: locality matched list, locality
+ // unmatched list and non-locality list. Take the locality matched container request into
+ // consideration of container placement, treat as allocated containers.
+ // For locality unmatched and locality free container requests, cancel these container
+ // requests, since required locality preference has been changed, recalculating using
+ // container placement strategy.
+ val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality(
+ hostToLocalTaskCounts, pendingAllocate)
+
if (missing > 0) {
if (log.isInfoEnabled()) {
var requestContainerMessage = s"Will request $missing executor container(s), each with " +
@@ -306,15 +315,6 @@ private[yarn] class YarnAllocator(
logInfo(requestContainerMessage)
}
- // Split the pending container request into three groups: locality matched list, locality
- // unmatched list and non-locality list. Take the locality matched container request into
- // consideration of container placement, treat as allocated containers.
- // For locality unmatched and locality free container requests, cancel these container
- // requests, since required locality preference has been changed, recalculating using
- // container placement strategy.
- val (localRequests, staleRequests, anyHostRequests) = splitPendingAllocationsByLocality(
- hostToLocalTaskCounts, pendingAllocate)
-
// cancel "stale" requests for locations that are no longer needed
staleRequests.foreach { stale =>
amClient.removeContainerRequest(stale)
@@ -374,14 +374,9 @@ private[yarn] class YarnAllocator(
val numToCancel = math.min(numPendingAllocate, -missing)
logInfo(s"Canceling requests for $numToCancel executor container(s) to have a new desired " +
s"total $targetNumExecutors executors.")
-
- val matchingRequests = amClient.getMatchingRequests(RM_REQUEST_PRIORITY, ANY_HOST, resource)
- if (!matchingRequests.isEmpty) {
- matchingRequests.iterator().next().asScala
- .take(numToCancel).foreach(amClient.removeContainerRequest)
- } else {
- logWarning("Expected to find pending requests, but found none.")
- }
+ // cancel pending allocate requests by taking locality preference into account
+ val cancelRequests = (staleRequests ++ anyHostRequests ++ localRequests).take(numToCancel)
+ cancelRequests.foreach(amClient.removeContainerRequest)
}
}
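A simplified sketch of the new cancellation order (the real code operates on the AMRMClient container requests already split by `splitPendingAllocationsByLocality`):

```scala
// Simplified sketch: prefer cancelling requests whose locality preference is stale,
// then locality-free requests, and only then requests that still match a preferred host.
def requestsToCancel[T](
    staleRequests: Seq[T],
    anyHostRequests: Seq[T],
    localRequests: Seq[T],
    numToCancel: Int): Seq[T] = {
  (staleRequests ++ anyHostRequests ++ localRequests).take(numToCancel)
}
```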
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 67c36aac49266..1289d4be79ea4 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -168,8 +168,10 @@ private[spark] abstract class YarnSchedulerBackend(
filterName != null && filterName.nonEmpty &&
filterParams != null && filterParams.nonEmpty
if (hasFilter) {
+ // SPARK-26255: Append user-provided filters (spark.ui.filters) to the yarn filter.
+ val allFilters = filterName + "," + conf.get("spark.ui.filters", "")
logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
- conf.set("spark.ui.filters", filterName)
+ conf.set("spark.ui.filters", allFilters)
filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
}
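For illustration, the resulting `spark.ui.filters` value after this change (the filter class name and the user filter below are assumptions, not taken from this patch):

```scala
// Hypothetical values: AmIpFilter stands in for the filter name YARN passes in,
// and MyAuthFilter for a filter the user configured via spark.ui.filters.
val filterName = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter"
val userFilters = "com.example.MyAuthFilter"
val allFilters = filterName + "," + userFilters
// Previously the yarn filter overwrote the user's setting; now both are applied.
```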
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index b3286e8fd824e..a6f57fcdb2461 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -100,7 +100,7 @@ class ClientSuite extends SparkFunSuite with Matchers {
val cp = env("CLASSPATH").split(":|;|<CPS>")
s"$SPARK,$USER,$ADDED".split(",").foreach({ entry =>
val uri = new URI(entry)
- if (LOCAL_SCHEME.equals(uri.getScheme())) {
+ if (Utils.LOCAL_SCHEME.equals(uri.getScheme())) {
cp should contain (uri.getPath())
} else {
cp should not contain (uri.getPath())
@@ -136,7 +136,7 @@ class ClientSuite extends SparkFunSuite with Matchers {
val expected = ADDED.split(",")
.map(p => {
val uri = new URI(p)
- if (LOCAL_SCHEME == uri.getScheme()) {
+ if (Utils.LOCAL_SCHEME == uri.getScheme()) {
p
} else {
Option(uri.getFragment()).getOrElse(new File(p).getName())
@@ -249,7 +249,7 @@ class ClientSuite extends SparkFunSuite with Matchers {
any(classOf[MutableHashMap[URI, Path]]), anyBoolean(), any())
classpath(client) should contain (buildPath(PWD, LOCALIZED_LIB_DIR, "*"))
- sparkConf.set(SPARK_ARCHIVE, LOCAL_SCHEME + ":" + archive.getPath())
+ sparkConf.set(SPARK_ARCHIVE, Utils.LOCAL_SCHEME + ":" + archive.getPath())
intercept[IllegalArgumentException] {
client.prepareLocalResources(new Path(temp.getAbsolutePath()), Nil)
}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
index 95263a0da95a8..7553ab8cf7000 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeWriter.java
@@ -198,11 +198,46 @@ protected final void writeLong(long offset, long value) {
Platform.putLong(getBuffer(), offset, value);
}
+ // We need to take care of NaN and -0.0 in several places:
+ // 1. When comparing values, different NaNs should be treated as the same, and `-0.0` and
+ // `0.0` should be treated as the same.
+ // 2. In GROUP BY, different NaNs should belong to the same group, and -0.0 and 0.0 should
+ // belong to the same group.
+ // 3. As join keys, different NaNs should be treated as the same, and `-0.0` and `0.0` should
+ // be treated as the same.
+ // 4. As window partition keys, different NaNs should be treated as the same, and `-0.0` and
+ // `0.0` should be treated as the same.
+ //
+ // Case 1 is fine, as we handle NaN and -0.0 well during comparison. For complex types, we
+ // recursively compare the fields/elements, so it's also fine.
+ //
+ // Case 2, 3 and 4 are problematic, as they compare `UnsafeRow` binary directly, and different
+ // NaNs have different binary representation, and the same thing happens for -0.0 and 0.0.
+ //
+ // Here we normalize NaN and -0.0, so that `UnsafeProjection` will normalize them when writing
+ // float/double columns and nested fields to `UnsafeRow`.
+ //
+ // Note that we must do this for all `UnsafeProjection`s, not only the ones that extract
+ // join/grouping/window partition keys. `UnsafeProjection` copies unsafe data directly for
+ // complex types, so nested float/double values may not be normalized. We need to make sure
+ // that all the unsafe data (`UnsafeRow`, `UnsafeArrayData`, `UnsafeMapData`) has float/double
+ // normalized during creation.
protected final void writeFloat(long offset, float value) {
+ if (Float.isNaN(value)) {
+ value = Float.NaN;
+ } else if (value == -0.0f) {
+ value = 0.0f;
+ }
Platform.putFloat(getBuffer(), offset, value);
}
+ // See comments for `writeFloat`.
protected final void writeDouble(long offset, double value) {
+ if (Double.isNaN(value)) {
+ value = Double.NaN;
+ } else if (value == -0.0d) {
+ value = 0.0d;
+ }
Platform.putDouble(getBuffer(), offset, value);
}
}
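A small Scala sketch of the normalization the two writers above now perform before putting a value into the unsafe buffer (the real logic lives in `UnsafeWriter.writeFloat`/`writeDouble`, in Java):

```scala
def normalizeDouble(value: Double): Double = {
  if (value.isNaN) Double.NaN        // collapse every NaN bit pattern to the canonical one
  else if (value == -0.0d) 0.0d      // fold negative zero into positive zero
  else value
}

// After normalization the binary representations match, so UnsafeRow-based grouping,
// join and window partition keys treat 0.0 and -0.0 (and all NaNs) as equal.
assert(java.lang.Double.doubleToRawLongBits(normalizeDouble(-0.0d)) ==
  java.lang.Double.doubleToRawLongBits(0.0d))
```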
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index e12bf9616e2de..4f5af9ac80b10 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -57,6 +57,7 @@ object Row {
/**
* Merge multiple rows into a single row, one after another.
*/
+ @deprecated("This method is deprecated and will be removed in future versions.", "3.0.0")
def merge(rows: Row*): Row = {
// TODO: Improve the performance of this if used in performance critical part.
new GenericRow(rows.flatMap(_.toSeq).toArray)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 9f6ccc441f06c..c28a97839fe49 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -135,11 +135,6 @@ trait CheckAnalysis extends PredicateHelper {
failAnalysis("An offset window function can only be evaluated in an ordered " +
s"row-based window frame with a single offset: $w")
- case _ @ WindowExpression(_: PythonUDF,
- WindowSpecDefinition(_, _, frame: SpecifiedWindowFrame))
- if !frame.isUnbounded =>
- failAnalysis("Only unbounded window frame is supported with Pandas UDFs.")
-
case w @ WindowExpression(e, s) =>
// Only allow window functions with an aggregate expression or an offset window
// function or a Pandas window UDF.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index d209bb0b7c8e7..b19aa50ba2156 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -879,6 +879,37 @@ object TypeCoercion {
}
}
e.withNewChildren(children)
+
+ case udf: ScalaUDF if udf.inputTypes.nonEmpty =>
+ val children = udf.children.zip(udf.inputTypes).map { case (in, expected) =>
+ implicitCast(in, udfInputToCastType(in.dataType, expected)).getOrElse(in)
+ }
+ udf.withNewChildren(children)
+ }
+
+ private def udfInputToCastType(input: DataType, expectedType: DataType): DataType = {
+ (input, expectedType) match {
+ // SPARK-26308: avoid casting to an arbitrary precision and scale for decimals. Please note
+ // that precision and scale cannot be inferred properly for a ScalaUDF because, when it is
+ // created, it is not bound to any column. So here the precision and scale of the input
+ // column is used.
+ case (in: DecimalType, _: DecimalType) => in
+ case (ArrayType(dtIn, _), ArrayType(dtExp, nullableExp)) =>
+ ArrayType(udfInputToCastType(dtIn, dtExp), nullableExp)
+ case (MapType(keyDtIn, valueDtIn, _), MapType(keyDtExp, valueDtExp, nullableExp)) =>
+ MapType(udfInputToCastType(keyDtIn, keyDtExp),
+ udfInputToCastType(valueDtIn, valueDtExp),
+ nullableExp)
+ case (StructType(fieldsIn), StructType(fieldsExp)) =>
+ val fieldTypes =
+ fieldsIn.map(_.dataType).zip(fieldsExp.map(_.dataType)).map { case (dtIn, dtExp) =>
+ udfInputToCastType(dtIn, dtExp)
+ }
+ StructType(fieldsExp.zip(fieldTypes).map { case (field, newDt) =>
+ field.copy(dataType = newDt)
+ })
+ case (_, other) => other
+ }
}
/**
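A sketch of the effect for the decimal case called out in the SPARK-26308 comment (the `DecimalType(38, 18)` expectation below is an assumption about the unbound UDF's default, not something stated in this patch):

```scala
import org.apache.spark.sql.types._

val inputType    = DecimalType(12, 2)   // type of the actual input column
val expectedType = DecimalType(38, 18)  // assumed default expectation of an unbound ScalaUDF

// udfInputToCastType keeps the input's precision and scale when both sides are decimals,
// so the implicit cast target stays DecimalType(12, 2) rather than an arbitrary (38, 18).
```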
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
index a8a7bbd9f9cd0..1cd7f412bb678 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/higherOrderFunctions.scala
@@ -150,13 +150,14 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
val lambdaMap = l.arguments.map(v => canonicalizer(v.name) -> v).toMap
l.mapChildren(resolve(_, parentLambdaMap ++ lambdaMap))
- case u @ UnresolvedAttribute(name +: nestedFields) =>
+ case u @ UnresolvedNamedLambdaVariable(name +: nestedFields) =>
parentLambdaMap.get(canonicalizer(name)) match {
case Some(lambda) =>
nestedFields.foldLeft(lambda: Expression) { (expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), conf.resolver)
}
- case None => u
+ case None =>
+ UnresolvedAttribute(u.nameParts)
}
case _ =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
index 345dc4d41993e..35ade136cc607 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/CSVInferSchema.scala
@@ -22,13 +22,13 @@ import scala.util.control.Exception.allCatch
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.ExprUtils
-import org.apache.spark.sql.catalyst.util.DateTimeFormatter
+import org.apache.spark.sql.catalyst.util.TimestampFormatter
import org.apache.spark.sql.types._
class CSVInferSchema(val options: CSVOptions) extends Serializable {
@transient
- private lazy val timeParser = DateTimeFormatter(
+ private lazy val timestampParser = TimestampFormatter(
options.timestampFormat,
options.timeZone,
options.locale)
@@ -160,7 +160,7 @@ class CSVInferSchema(val options: CSVOptions) extends Serializable {
private def tryParseTimestamp(field: String): DataType = {
// This case infers a custom `dataFormat` is set.
- if ((allCatch opt timeParser.parse(field)).isDefined) {
+ if ((allCatch opt timestampParser.parse(field)).isDefined) {
TimestampType
} else {
tryParseBoolean(field)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala
index af09cd6c8449b..f012d96138f37 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityGenerator.scala
@@ -22,7 +22,7 @@ import java.io.Writer
import com.univocity.parsers.csv.CsvWriter
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeFormatter}
+import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter}
import org.apache.spark.sql.types._
class UnivocityGenerator(
@@ -41,18 +41,18 @@ class UnivocityGenerator(
private val valueConverters: Array[ValueConverter] =
schema.map(_.dataType).map(makeConverter).toArray
- private val timeFormatter = DateTimeFormatter(
+ private val timestampFormatter = TimestampFormatter(
options.timestampFormat,
options.timeZone,
options.locale)
- private val dateFormatter = DateFormatter(options.dateFormat, options.timeZone, options.locale)
+ private val dateFormatter = DateFormatter(options.dateFormat, options.locale)
private def makeConverter(dataType: DataType): ValueConverter = dataType match {
case DateType =>
(row: InternalRow, ordinal: Int) => dateFormatter.format(row.getInt(ordinal))
case TimestampType =>
- (row: InternalRow, ordinal: Int) => timeFormatter.format(row.getLong(ordinal))
+ (row: InternalRow, ordinal: Int) => timestampFormatter.format(row.getLong(ordinal))
case udt: UserDefinedType[_] => makeConverter(udt.sqlType)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
index 0f375e036029c..82a5b3c302b18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/csv/UnivocityParser.scala
@@ -74,11 +74,11 @@ class UnivocityParser(
private val row = new GenericInternalRow(requiredSchema.length)
- private val timeFormatter = DateTimeFormatter(
+ private val timestampFormatter = TimestampFormatter(
options.timestampFormat,
options.timeZone,
options.locale)
- private val dateFormatter = DateFormatter(options.dateFormat, options.timeZone, options.locale)
+ private val dateFormatter = DateFormatter(options.dateFormat, options.locale)
// Retrieve the raw record string.
private def getCurrentInput: UTF8String = {
@@ -158,7 +158,7 @@ class UnivocityParser(
}
case _: TimestampType => (d: String) =>
- nullSafeDatum(d, name, nullable, options)(timeFormatter.parse)
+ nullSafeDatum(d, name, nullable, options)(timestampFormatter.parse)
case _: DateType => (d: String) =>
nullSafeDatum(d, name, nullable, options)(dateFormatter.parse)
@@ -239,6 +239,7 @@ class UnivocityParser(
} catch {
case NonFatal(e) =>
badRecordException = badRecordException.orElse(Some(e))
+ row.setNullAt(i)
}
i += 1
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 176ea823b1fcd..151481c80ee96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -136,7 +136,7 @@ package object dsl {
implicit def longToLiteral(l: Long): Literal = Literal(l)
implicit def floatToLiteral(f: Float): Literal = Literal(f)
implicit def doubleToLiteral(d: Double): Literal = Literal(d)
- implicit def stringToLiteral(s: String): Literal = Literal(s)
+ implicit def stringToLiteral(s: String): Literal = Literal.create(s, StringType)
implicit def dateToLiteral(d: Date): Literal = Literal(d)
implicit def bigDecimalToLiteral(d: BigDecimal): Literal = Literal(d.underlying())
implicit def bigDecimalToLiteral(d: java.math.BigDecimal): Literal = Literal(d)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index d905f8f9858e8..8ca3d356f3bdc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -106,12 +106,12 @@ object RowEncoder {
returnNullable = false)
case d: DecimalType =>
- StaticInvoke(
+ CheckOverflow(StaticInvoke(
Decimal.getClass,
d,
"fromDecimal",
inputObject :: Nil,
- returnNullable = false)
+ returnNullable = false), d)
case StringType =>
StaticInvoke(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index fae90caebf96c..a23aaa3a0b3ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -52,7 +52,7 @@ case class ScalaUDF(
udfName: Option[String] = None,
nullable: Boolean = true,
udfDeterministic: Boolean = true)
- extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression {
+ extends Expression with NonSQLExpression with UserDefinedExpression {
// The constructor for SPARK 2.1 and 2.2
def this(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index fa8e38acd522d..67f6739b1e18f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -554,13 +554,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
return null
}
- val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements())
- if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
- throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
- s"elements due to exceeding the map size limit " +
- s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
- }
-
for (map <- maps) {
mapBuilder.putAll(map.keyArray(), map.valueArray())
}
@@ -569,8 +562,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val mapCodes = children.map(_.genCode(ctx))
- val keyType = dataType.keyType
- val valueType = dataType.valueType
val argsName = ctx.freshName("args")
val hasNullName = ctx.freshName("hasNull")
val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder)
@@ -610,41 +601,12 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
)
val idxName = ctx.freshName("idx")
- val numElementsName = ctx.freshName("numElems")
- val finKeysName = ctx.freshName("finalKeys")
- val finValsName = ctx.freshName("finalValues")
-
- val keyConcat = genCodeForArrays(ctx, keyType, false)
-
- val valueConcat =
- if (valueType.sameType(keyType) &&
- !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
- keyConcat
- } else {
- genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
- }
-
- val keyArgsName = ctx.freshName("keyArgs")
- val valArgsName = ctx.freshName("valArgs")
-
val mapMerge =
s"""
- |ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}];
- |ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}];
- |long $numElementsName = 0;
|for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
- | $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
- | $valArgsName[$idxName] = $argsName[$idxName].valueArray();
- | $numElementsName += $argsName[$idxName].numElements();
+ | $builderTerm.putAll($argsName[$idxName].keyArray(), $argsName[$idxName].valueArray());
|}
- |if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
- | throw new RuntimeException("Unsuccessful attempt to concat maps with " +
- | $numElementsName + " elements due to exceeding the map size limit " +
- | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
- |}
- |ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName);
- |ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName);
- |${ev.value} = $builderTerm.from($finKeysName, $finValsName);
+ |${ev.value} = $builderTerm.build();
""".stripMargin
ev.copy(
@@ -660,41 +622,6 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
""".stripMargin)
}
- private def genCodeForArrays(
- ctx: CodegenContext,
- elementType: DataType,
- checkForNull: Boolean): String = {
- val counter = ctx.freshName("counter")
- val arrayData = ctx.freshName("arrayData")
- val argsName = ctx.freshName("args")
- val numElemName = ctx.freshName("numElements")
- val y = ctx.freshName("y")
- val z = ctx.freshName("z")
-
- val allocation = CodeGenerator.createArrayData(
- arrayData, elementType, numElemName, s" $prettyName failed.")
- val assignment = CodeGenerator.createArrayAssignment(
- arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull)
-
- val concat = ctx.freshName("concat")
- val concatDef =
- s"""
- |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
- | $allocation
- | int $counter = 0;
- | for (int $y = 0; $y < ${children.length}; $y++) {
- | for (int $z = 0; $z < $argsName[$y].numElements(); $z++) {
- | $assignment
- | $counter++;
- | }
- | }
- | return $arrayData;
- |}
- """.stripMargin
-
- ctx.addNewFunction(concat, concatDef)
- }
-
override def prettyName: String = "map_concat"
}
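
Net effect of the MapConcat hunk above: the expression no longer concatenates key and value arrays itself, it simply feeds every input map into the shared ArrayBasedMapBuilder, which deduplicates keys and enforces the size limit (see the ArrayBasedMapBuilder hunk further down). A minimal sketch of the simplified interpreted path, assuming integer keys and string values purely for illustration:

    import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, MapData}
    import org.apache.spark.sql.types.{IntegerType, StringType}

    // `maps` stands for the already evaluated, non-null child MapData values.
    def concatMaps(maps: Seq[MapData]): MapData = {
      val mapBuilder = new ArrayBasedMapBuilder(IntegerType, StringType)
      for (map <- maps) {
        // putAll copies each key/value pair; duplicate keys and the
        // MAX_ROUNDED_ARRAY_LENGTH limit are handled inside the builder.
        mapBuilder.putAll(map.keyArray(), map.valueArray())
      }
      mapBuilder.build()
    }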
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
index 75eea1223a854..e6cc11d1ad280 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala
@@ -22,12 +22,34 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedException}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
+/**
+ * A placeholder for lambda variables to prevent unexpected resolution of [[LambdaFunction]].
+ */
+case class UnresolvedNamedLambdaVariable(nameParts: Seq[String])
+ extends LeafExpression with NamedExpression with Unevaluable {
+
+ override def name: String =
+ nameParts.map(n => if (n.contains(".")) s"`$n`" else n).mkString(".")
+
+ override def exprId: ExprId = throw new UnresolvedException(this, "exprId")
+ override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+ override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+ override def qualifier: Seq[String] = throw new UnresolvedException(this, "qualifier")
+ override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute")
+ override def newInstance(): NamedExpression = throw new UnresolvedException(this, "newInstance")
+ override lazy val resolved = false
+
+ override def toString: String = s"lambda '$name"
+
+ override def sql: String = name
+}
+
/**
* A named lambda variable.
*/
@@ -81,7 +103,7 @@ case class LambdaFunction(
object LambdaFunction {
val identity: LambdaFunction = {
- val id = UnresolvedAttribute.quoted("id")
+ val id = UnresolvedNamedLambdaVariable(Seq("id"))
LambdaFunction(id, Seq(id))
}
}
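
For context: lambda parameters used to be plain UnresolvedAttributes, so the analyzer could resolve `x` in `transform(values, x -> x + 1)` against a column that happened to be named `x`. With the placeholder node they stay unresolved until ResolveLambdaVariables binds them. A hedged sketch of constructing such an expression programmatically (the `values` column is illustrative):

    import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
    import org.apache.spark.sql.catalyst.expressions.{ArrayTransform, LambdaFunction, UnresolvedNamedLambdaVariable}

    // The placeholder refuses to expose exprId/dataType/etc. until the analyzer
    // rewrites it into a NamedLambdaVariable.
    val x = UnresolvedNamedLambdaVariable(Seq("x"))
    val values = UnresolvedAttribute("values") // some array column of the child plan
    val identityTransform = ArrayTransform(values, LambdaFunction(x, Seq(x)))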
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 0083ee64653e9..bf18e8bcb52df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -101,7 +101,7 @@ package object expressions {
StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
}
- // It's possible that `attrs` is a linked list, which can lead to bad O(n^2) loops when
+ // It's possible that `attrs` is a linked list, which can lead to bad O(n) loops when
// accessing attributes by their ordinals. To avoid this performance penalty, convert the input
// to an array.
@transient private lazy val attrsArray = attrs.toArray
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
index e10b8a327c01a..eaff3fa7bec25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JSONOptions.scala
@@ -21,7 +21,6 @@ import java.nio.charset.{Charset, StandardCharsets}
import java.util.{Locale, TimeZone}
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}
-import org.apache.commons.lang3.time.FastDateFormat
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util._
@@ -82,13 +81,10 @@ private[sql] class JSONOptions(
val timeZone: TimeZone = DateTimeUtils.getTimeZone(
parameters.getOrElse(DateTimeUtils.TIMEZONE_OPTION, defaultTimeZoneId))
- // Uses `FastDateFormat` which can be direct replacement for `SimpleDateFormat` and thread-safe.
- val dateFormat: FastDateFormat =
- FastDateFormat.getInstance(parameters.getOrElse("dateFormat", "yyyy-MM-dd"), locale)
+ val dateFormat: String = parameters.getOrElse("dateFormat", "yyyy-MM-dd")
- val timestampFormat: FastDateFormat =
- FastDateFormat.getInstance(
- parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), timeZone, locale)
+ val timestampFormat: String =
+ parameters.getOrElse("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")
val multiLine = parameters.get("multiLine").map(_.toBoolean).getOrElse(false)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
index d02a2be8ddad6..951f5190cd504 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonGenerator.scala
@@ -23,7 +23,7 @@ import com.fasterxml.jackson.core._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
-import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData}
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
/**
@@ -77,6 +77,12 @@ private[sql] class JacksonGenerator(
private val lineSeparator: String = options.lineSeparatorInWrite
+ private val timestampFormatter = TimestampFormatter(
+ options.timestampFormat,
+ options.timeZone,
+ options.locale)
+ private val dateFormatter = DateFormatter(options.dateFormat, options.locale)
+
private def makeWriter(dataType: DataType): ValueWriter = dataType match {
case NullType =>
(row: SpecializedGetters, ordinal: Int) =>
@@ -116,14 +122,12 @@ private[sql] class JacksonGenerator(
case TimestampType =>
(row: SpecializedGetters, ordinal: Int) =>
- val timestampString =
- options.timestampFormat.format(DateTimeUtils.toJavaTimestamp(row.getLong(ordinal)))
+ val timestampString = timestampFormatter.format(row.getLong(ordinal))
gen.writeString(timestampString)
case DateType =>
(row: SpecializedGetters, ordinal: Int) =>
- val dateString =
- options.dateFormat.format(DateTimeUtils.toJavaDate(row.getInt(ordinal)))
+ val dateString = dateFormatter.format(row.getInt(ordinal))
gen.writeString(dateString)
case BinaryType =>
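
Both the generator above and the parser in the next hunk build their formatters once from the plain pattern strings now kept in JSONOptions. A hedged sketch of the write path, assuming the default patterns and the non-legacy (java.time based) implementations; 17867 days and 1543745472123456 microseconds both denote 2018-12-02 (10:11:12.123456 UTC for the timestamp):

    import java.util.{Locale, TimeZone}
    import org.apache.spark.sql.catalyst.util.{DateFormatter, TimestampFormatter}

    val tsFormatter = TimestampFormatter(
      "yyyy-MM-dd'T'HH:mm:ss.SSSXXX", TimeZone.getTimeZone("UTC"), Locale.US)
    val dateFormatter = DateFormatter("yyyy-MM-dd", Locale.US)

    // Internal values are microseconds since the epoch for timestamps and
    // days since the epoch for dates, exactly what the JSON writer passes in.
    tsFormatter.format(1543745472123456L) // "2018-12-02T10:11:12.123Z"
    dateFormatter.format(17867)           // "2018-12-02"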
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 2357595906b11..3f245e1400fa1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -22,6 +22,7 @@ import java.nio.charset.MalformedInputException
import scala.collection.mutable.ArrayBuffer
import scala.util.Try
+import scala.util.control.NonFatal
import com.fasterxml.jackson.core._
@@ -29,7 +30,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -55,6 +55,12 @@ class JacksonParser(
private val factory = new JsonFactory()
options.setJacksonOptions(factory)
+ private val timestampFormatter = TimestampFormatter(
+ options.timestampFormat,
+ options.timeZone,
+ options.locale)
+ private val dateFormatter = DateFormatter(options.dateFormat, options.locale)
+
/**
* Create a converter which converts the JSON documents held by the `JsonParser`
* to a value according to a desired schema. This is a wrapper for the method
@@ -218,17 +224,7 @@ class JacksonParser(
case TimestampType =>
(parser: JsonParser) => parseJsonToken[java.lang.Long](parser, dataType) {
case VALUE_STRING if parser.getTextLength >= 1 =>
- val stringValue = parser.getText
- // This one will lose microseconds parts.
- // See https://issues.apache.org/jira/browse/SPARK-10681.
- Long.box {
- Try(options.timestampFormat.parse(stringValue).getTime * 1000L)
- .getOrElse {
- // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
- // compatibility.
- DateTimeUtils.stringToTime(stringValue).getTime * 1000L
- }
- }
+ timestampFormatter.parse(parser.getText)
case VALUE_NUMBER_INT =>
parser.getLongValue * 1000000L
@@ -237,22 +233,7 @@ class JacksonParser(
case DateType =>
(parser: JsonParser) => parseJsonToken[java.lang.Integer](parser, dataType) {
case VALUE_STRING if parser.getTextLength >= 1 =>
- val stringValue = parser.getText
- // This one will lose microseconds parts.
- // See https://issues.apache.org/jira/browse/SPARK-10681.x
- Int.box {
- Try(DateTimeUtils.millisToDays(options.dateFormat.parse(stringValue).getTime))
- .orElse {
- // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
- // compatibility.
- Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(stringValue).getTime))
- }
- .getOrElse {
- // In Spark 1.5.0, we store the data as number of days since epoch in string.
- // So, we just convert it to Int.
- stringValue.toInt
- }
- }
+ dateFormatter.parse(parser.getText)
}
case BinaryType =>
@@ -347,17 +328,28 @@ class JacksonParser(
schema: StructType,
fieldConverters: Array[ValueConverter]): InternalRow = {
val row = new GenericInternalRow(schema.length)
+ var badRecordException: Option[Throwable] = None
+
while (nextUntil(parser, JsonToken.END_OBJECT)) {
schema.getFieldIndex(parser.getCurrentName) match {
case Some(index) =>
- row.update(index, fieldConverters(index).apply(parser))
-
+ try {
+ row.update(index, fieldConverters(index).apply(parser))
+ } catch {
+ case NonFatal(e) =>
+ badRecordException = badRecordException.orElse(Some(e))
+ parser.skipChildren()
+ }
case None =>
parser.skipChildren()
}
}
- row
+ if (badRecordException.isEmpty) {
+ row
+ } else {
+ throw PartialResultException(row, badRecordException.get)
+ }
}
/**
@@ -428,6 +420,11 @@ class JacksonParser(
val wrappedCharException = new CharConversionException(msg)
wrappedCharException.initCause(e)
throw BadRecordException(() => recordLiteral(record), () => None, wrappedCharException)
+ case PartialResultException(row, cause) =>
+ throw BadRecordException(
+ record = () => recordLiteral(record),
+ partialResult = () => Some(row),
+ cause)
}
}
}
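
The practical effect shows up in PERMISSIVE mode: instead of nulling out the whole row when a single field is malformed, the parser keeps the fields it could convert and surfaces them through BadRecordException.partialResult. A hedged end-to-end sketch using the public DataFrame API (session, column names and data are illustrative):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types._

    val spark = SparkSession.builder().master("local[1]").appName("json-partial").getOrCreate()
    import spark.implicits._

    val schema = new StructType()
      .add("a", IntegerType)
      .add("b", DoubleType)
      .add("_corrupt_record", StringType)

    // "b" is malformed in the second record. Previously the whole row was nulled
    // out in PERMISSIVE mode; with partial results, "a" should be kept as 2 while
    // the raw text still lands in the corrupt-record column.
    val records = Seq("""{"a": 1, "b": 2.0}""", """{"a": 2, "b": "oops"}""").toDS()
    spark.read.schema(schema).json(records).show(false)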
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
index 263e05de32075..d1bc00c08c1c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -28,7 +28,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.TypeCoercion
import org.apache.spark.sql.catalyst.expressions.ExprUtils
import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
-import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -37,6 +37,12 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
private val decimalParser = ExprUtils.getDecimalParser(options.locale)
+ @transient
+ private lazy val timestampFormatter = TimestampFormatter(
+ options.timestampFormat,
+ options.timeZone,
+ options.locale)
+
/**
* Infer the type of a collection of json records in three stages:
* 1. Infer the type of each record
@@ -115,13 +121,19 @@ private[sql] class JsonInferSchema(options: JSONOptions) extends Serializable {
// record fields' types have been combined.
NullType
- case VALUE_STRING if options.prefersDecimal =>
+ case VALUE_STRING =>
+ val field = parser.getText
val decimalTry = allCatch opt {
- val bigDecimal = decimalParser(parser.getText)
+ val bigDecimal = decimalParser(field)
DecimalType(bigDecimal.precision, bigDecimal.scale)
}
- decimalTry.getOrElse(StringType)
- case VALUE_STRING => StringType
+ if (options.prefersDecimal && decimalTry.isDefined) {
+ decimalTry.get
+ } else if ((allCatch opt timestampFormatter.parse(field)).isDefined) {
+ TimestampType
+ } else {
+ StringType
+ }
case START_OBJECT =>
val builder = Array.newBuilder[StructField]
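
With a TimestampFormatter available during inference, a string field that parses cleanly with the configured timestampFormat is now inferred as TimestampType instead of StringType; decimal inference still wins when prefersDecimal is set. A hedged sketch of the observable behaviour (data is illustrative, and the exact outcome depends on the configured pattern):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").appName("json-infer-ts").getOrCreate()
    import spark.implicits._

    val records = Seq("""{"ts": "2018-12-02T10:11:12.001Z"}""").toDS()
    // No explicit schema: the default timestampFormat matches the value, so the
    // inferred type of `ts` should now be timestamp rather than string.
    spark.read.json(records).printSchema()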
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
index 064ca68b7a628..01634a9d852c6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala
@@ -48,6 +48,7 @@ object CostBasedJoinReorder extends Rule[LogicalPlan] with PredicateHelper {
if projectList.forall(_.isInstanceOf[Attribute]) =>
reorder(p, p.output)
}
+
// After reordering is finished, convert OrderedJoin back to Join
result transformDown {
case OrderedJoin(left, right, jt, cond) => Join(left, right, jt, cond)
@@ -175,11 +176,20 @@ object JoinReorderDP extends PredicateHelper with Logging {
assert(topOutputSet == p.outputSet)
// Keep the same order of final output attributes.
p.copy(projectList = output)
+ case finalPlan if !sameOutput(finalPlan, output) =>
+ Project(output, finalPlan)
case finalPlan =>
finalPlan
}
}
+ private def sameOutput(plan: LogicalPlan, expectedOutput: Seq[Attribute]): Boolean = {
+ val thisOutput = plan.output
+ thisOutput.length == expectedOutput.length && thisOutput.zip(expectedOutput).forall {
+ case (a1, a2) => a1.semanticEquals(a2)
+ }
+ }
+
/** Find all possible plans at the next level, based on existing levels. */
private def searchLevel(
existingLevels: Seq[JoinPlanMap],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 8d251eeab8484..44d5543114902 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -73,6 +73,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
CombineLimits,
CombineUnions,
// Constant folding and strength reduction
+ TransposeWindow,
NullPropagation,
ConstantPropagation,
FoldablePropagation,
@@ -92,7 +93,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
RewriteCorrelatedScalarSubquery,
EliminateSerialization,
RemoveRedundantAliases,
- RemoveRedundantProject,
+ RemoveNoopOperators,
SimplifyExtractValueOps,
CombineConcats) ++
extendedOperatorOptimizationRules
@@ -131,11 +132,11 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
// since the other rules might make two separate Unions operators adjacent.
Batch("Union", Once,
CombineUnions) ::
- // run this once earlier. this might simplify the plan and reduce cost of optimizer.
- // for example, a query such as Filter(LocalRelation) would go through all the heavy
+ // Run this once earlier. This might simplify the plan and reduce cost of optimizer.
+ // For example, a query such as Filter(LocalRelation) would go through all the heavy
// optimizer rules that are triggered when there is a filter
- // (e.g. InferFiltersFromConstraints). if we run this batch earlier, the query becomes just
- // LocalRelation and does not trigger many rules
+ // (e.g. InferFiltersFromConstraints). If we run this batch earlier, the query becomes just
+ // LocalRelation and does not trigger many rules.
Batch("LocalRelation early", fixedPoint,
ConvertToLocalRelation,
PropagateEmptyRelation) ::
@@ -176,7 +177,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
RewritePredicateSubquery,
ColumnPruning,
CollapseProject,
- RemoveRedundantProject) :+
+ RemoveNoopOperators) :+
Batch("UpdateAttributeReferences", Once,
UpdateNullabilityInAttributeReferences)
}
@@ -402,11 +403,15 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
}
/**
- * Remove projections from the query plan that do not make any modifications.
+ * Remove no-op operators from the query plan that do not make any modifications.
*/
-object RemoveRedundantProject extends Rule[LogicalPlan] {
+object RemoveNoopOperators extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p @ Project(_, child) if p.output == child.output => child
+ // Eliminate no-op Projects
+ case p @ Project(_, child) if child.sameOutput(p) => child
+
+ // Eliminate no-op Window
+ case w: Window if w.windowExpressions.isEmpty => w.child
}
}
@@ -601,17 +606,12 @@ object ColumnPruning extends Rule[LogicalPlan] {
p.copy(child = w.copy(
windowExpressions = w.windowExpressions.filter(p.references.contains)))
- // Eliminate no-op Window
- case w: Window if w.windowExpressions.isEmpty => w.child
-
- // Eliminate no-op Projects
- case p @ Project(_, child) if child.sameOutput(p) => child
-
// Can't prune the columns on LeafNode
case p @ Project(_, _: LeafNode) => p
// for all other logical plans that inherits the output from it's children
- case p @ Project(_, child) =>
+ // Project over project is handled by the first case, skip it here.
+ case p @ Project(_, child) if !child.isInstanceOf[Project] =>
val required = child.references ++ p.references
if (!child.inputSet.subsetOf(required)) {
val newChildren = child.children.map(c => prunedChild(c, required))
@@ -1370,10 +1370,8 @@ object DecimalAggregates extends Rule[LogicalPlan] {
}
/**
- * Converts local operations (i.e. ones that don't require data exchange) on LocalRelation to
- * another LocalRelation.
- *
- * This is relatively simple as it currently handles only 2 single case: Project and Limit.
+ * Converts local operations (i.e. ones that don't require data exchange) on `LocalRelation` to
+ * another `LocalRelation`.
*/
object ConvertToLocalRelation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
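
RemoveNoopOperators generalizes the old RemoveRedundantProject: a Project whose child already produces the same output, and a Window with no window expressions, are both dropped, and the rule runs in the main operator-optimization batch as well as after RewritePredicateSubquery. A minimal sketch using the catalyst test DSL (relation and column names are illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.optimizer.RemoveNoopOperators
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val relation = LocalRelation('a.int, 'b.int)

    // Project(a, b) on top of a relation that already outputs (a, b) is a no-op,
    // so the rule is expected to collapse the plan back to the relation itself.
    val noopProject = relation.select('a, 'b).analyze
    val optimized = RemoveNoopOperators(noopProject)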
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
index efd3944eba7f5..4996d24dfd298 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceExceptWithFilter.scala
@@ -36,7 +36,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
* Note:
* Before flipping the filter condition of the right node, we should:
* 1. Combine all it's [[Filter]].
- * 2. Apply InferFiltersFromConstraints rule (to take into account of NULL values in the condition).
+ * 2. Update the attribute references to the left node;
+ * 3. Add a Coalesce(condition, False) (to account for NULL values in the condition).
*/
object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
@@ -47,23 +48,28 @@ object ReplaceExceptWithFilter extends Rule[LogicalPlan] {
plan.transform {
case e @ Except(left, right, false) if isEligible(left, right) =>
- val newCondition = transformCondition(left, skipProject(right))
- newCondition.map { c =>
- Distinct(Filter(Not(c), left))
- }.getOrElse {
+ val filterCondition = combineFilters(skipProject(right)).asInstanceOf[Filter].condition
+ if (filterCondition.deterministic) {
+ transformCondition(left, filterCondition).map { c =>
+ Distinct(Filter(Not(c), left))
+ }.getOrElse {
+ e
+ }
+ } else {
e
}
}
}
- private def transformCondition(left: LogicalPlan, right: LogicalPlan): Option[Expression] = {
- val filterCondition =
- InferFiltersFromConstraints(combineFilters(right)).asInstanceOf[Filter].condition
-
- val attributeNameMap: Map[String, Attribute] = left.output.map(x => (x.name, x)).toMap
-
- if (filterCondition.references.forall(r => attributeNameMap.contains(r.name))) {
- Some(filterCondition.transform { case a: AttributeReference => attributeNameMap(a.name) })
+ private def transformCondition(plan: LogicalPlan, condition: Expression): Option[Expression] = {
+ val attributeNameMap: Map[String, Attribute] = plan.output.map(x => (x.name, x)).toMap
+ if (condition.references.forall(r => attributeNameMap.contains(r.name))) {
+ val rewrittenCondition = condition.transform {
+ case a: AttributeReference => attributeNameMap(a.name)
+ }
+ // The condition must be treated as False when it evaluates to NULL; otherwise we would not
+ // return rows containing NULL, even though the Except's right plan filters them out as well.
+ Some(Coalesce(Seq(rewrittenCondition, Literal.FalseLiteral)))
} else {
None
}
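
A worked example of why the Coalesce wrapper is needed, expressed through the public SQL API (hedged sketch; table name and data are illustrative). The rewritten plan is Distinct(Filter(NOT coalesce(a > 1, false), t1)): for a = NULL the condition a > 1 is NULL, and only by coalescing it to false does the NULL row survive the NOT, which matches EXCEPT semantics because that row is filtered out on the right side as well:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").appName("except-nulls").getOrCreate()
    import spark.implicits._

    Seq[Option[Int]](Some(1), Some(2), None).toDF("a").createOrReplaceTempView("t1")

    // The right side only contains 2, so EXCEPT should keep 1 and NULL.
    spark.sql("SELECT a FROM t1 EXCEPT SELECT a FROM t1 WHERE a > 1").show()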
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
index 6ebb194d71c2e..0b6471289a471 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala
@@ -86,9 +86,9 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case ExtractFiltersAndInnerJoins(input, conditions)
+ case p @ ExtractFiltersAndInnerJoins(input, conditions)
if input.size > 2 && conditions.nonEmpty =>
- if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
+ val reordered = if (SQLConf.get.starSchemaDetection && !SQLConf.get.cboEnabled) {
val starJoinPlan = StarSchemaDetection.reorderStarJoins(input, conditions)
if (starJoinPlan.nonEmpty) {
val rest = input.filterNot(starJoinPlan.contains(_))
@@ -99,6 +99,14 @@ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper {
} else {
createOrderedJoin(input, conditions)
}
+
+ if (p.sameOutput(reordered)) {
+ reordered
+ } else {
+ // Reordering the joins has changed the order of the columns.
+ // Inject a projection to restore the expected ordering.
+ Project(p.output, reordered)
+ }
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index e9b7a8b76e683..34840c6c977a6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -43,31 +43,53 @@ import org.apache.spark.sql.types._
* condition.
*/
object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
- private def dedupJoin(joinPlan: LogicalPlan): LogicalPlan = joinPlan match {
+
+ private def buildJoin(
+ outerPlan: LogicalPlan,
+ subplan: LogicalPlan,
+ joinType: JoinType,
+ condition: Option[Expression]): Join = {
+ // Deduplicate conflicting attributes if any.
+ val dedupSubplan = dedupSubqueryOnSelfJoin(outerPlan, subplan, None, condition)
+ Join(outerPlan, dedupSubplan, joinType, condition)
+ }
+
+ private def dedupSubqueryOnSelfJoin(
+ outerPlan: LogicalPlan,
+ subplan: LogicalPlan,
+ valuesOpt: Option[Seq[Expression]],
+ condition: Option[Expression] = None): LogicalPlan = {
// SPARK-21835: It is possible that the two sides of the join have conflicting attributes,
// the produced join then becomes unresolved and breaks structural integrity. We should
- // de-duplicate conflicting attributes. We don't use transformation here because we only
- // care about the most top join converted from correlated predicate subquery.
- case j @ Join(left, right, joinType @ (LeftSemi | LeftAnti | ExistenceJoin(_)), joinCond) =>
- val duplicates = right.outputSet.intersect(left.outputSet)
- if (duplicates.nonEmpty) {
- val aliasMap = AttributeMap(duplicates.map { dup =>
- dup -> Alias(dup, dup.toString)()
- }.toSeq)
- val aliasedExpressions = right.output.map { ref =>
- aliasMap.getOrElse(ref, ref)
- }
- val newRight = Project(aliasedExpressions, right)
- val newJoinCond = joinCond.map { condExpr =>
- condExpr transform {
- case a: Attribute => aliasMap.getOrElse(a, a).toAttribute
+ // de-duplicate conflicting attributes.
+ // SPARK-26078: it may also happen that the subquery has conflicting attributes with the outer
+ // values. In this case, the resulting join would contain trivially true conditions (e.g.
+ // id#3 = id#3) which cannot be de-duplicated afterwards. In this method, if there are conflicting
+ // attributes in the join condition, the subquery's conflicting attributes are changed using
+ // a projection which aliases them and resolves the problem.
+ val outerReferences = valuesOpt.map(values =>
+ AttributeSet.fromAttributeSets(values.map(_.references))).getOrElse(AttributeSet.empty)
+ val outerRefs = outerPlan.outputSet ++ outerReferences
+ val duplicates = outerRefs.intersect(subplan.outputSet)
+ if (duplicates.nonEmpty) {
+ condition.foreach { e =>
+ val conflictingAttrs = e.references.intersect(duplicates)
+ if (conflictingAttrs.nonEmpty) {
+ throw new AnalysisException("Found conflicting attributes " +
+ s"${conflictingAttrs.mkString(",")} in the condition joining outer plan:\n " +
+ s"$outerPlan\nand subplan:\n $subplan")
}
- }
- Join(left, newRight, joinType, newJoinCond)
- } else {
- j
}
- case _ => joinPlan
+ val rewrites = AttributeMap(duplicates.map { dup =>
+ dup -> Alias(dup, dup.toString)()
+ }.toSeq)
+ val aliasedExpressions = subplan.output.map { ref =>
+ rewrites.getOrElse(ref, ref)
+ }
+ Project(aliasedExpressions, subplan)
+ } else {
+ subplan
+ }
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -85,17 +107,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
withSubquery.foldLeft(newFilter) {
case (p, Exists(sub, conditions, _)) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
- // Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+ buildJoin(outerPlan, sub, LeftSemi, joinCond)
case (p, Not(Exists(sub, conditions, _))) =>
val (joinCond, outerPlan) = rewriteExistentialExpr(conditions, p)
- // Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftAnti, joinCond))
+ buildJoin(outerPlan, sub, LeftAnti, joinCond)
case (p, InSubquery(values, ListQuery(sub, conditions, _, _))) =>
- val inConditions = values.zip(sub.output).map(EqualTo.tupled)
- val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
// Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftSemi, joinCond))
+ val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
+ val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
+ val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions ++ conditions, p)
+ Join(outerPlan, newSub, LeftSemi, joinCond)
case (p, Not(InSubquery(values, ListQuery(sub, conditions, _, _)))) =>
// This is a NULL-aware (left) anti join (NAAJ) e.g. col NOT IN expr
// Construct the condition. A NULL in one of the conditions is regarded as a positive
@@ -103,7 +124,10 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// Note that this will almost certainly be planned as a Broadcast Nested Loop join.
// Use EXISTS if performance matters to you.
- val inConditions = values.zip(sub.output).map(EqualTo.tupled)
+
+ // Deduplicate conflicting attributes if any.
+ val newSub = dedupSubqueryOnSelfJoin(p, sub, Some(values))
+ val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
val (joinCond, outerPlan) = rewriteExistentialExpr(inConditions, p)
// Expand the NOT IN expression with the NULL-aware semantic
// to its full form. That is from:
@@ -118,8 +142,7 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
// will have the final conditions in the LEFT ANTI as
// (A.A1 = B.B1 OR ISNULL(A.A1 = B.B1)) AND (B.B2 = A.A2) AND B.B3 > 1
val finalJoinCond = (nullAwareJoinConds ++ conditions).reduceLeft(And)
- // Deduplicate conflicting attributes if any.
- dedupJoin(Join(outerPlan, sub, LeftAnti, Option(finalJoinCond)))
+ Join(outerPlan, newSub, LeftAnti, Option(finalJoinCond))
case (p, predicate) =>
val (newCond, inputPlan) = rewriteExistentialExpr(Seq(predicate), p)
Project(p.output, Filter(newCond.get, inputPlan))
@@ -140,16 +163,16 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
e transformUp {
case Exists(sub, conditions, _) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
- // Deduplicate conflicting attributes if any.
- newPlan = dedupJoin(
- Join(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And)))
+ newPlan =
+ buildJoin(newPlan, sub, ExistenceJoin(exists), conditions.reduceLeftOption(And))
exists
case InSubquery(values, ListQuery(sub, conditions, _, _)) =>
val exists = AttributeReference("exists", BooleanType, nullable = false)()
- val inConditions = values.zip(sub.output).map(EqualTo.tupled)
- val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
// Deduplicate conflicting attributes if any.
- newPlan = dedupJoin(Join(newPlan, sub, ExistenceJoin(exists), newConditions))
+ val newSub = dedupSubqueryOnSelfJoin(newPlan, sub, Some(values))
+ val inConditions = values.zip(newSub.output).map(EqualTo.tupled)
+ val newConditions = (inConditions ++ conditions).reduceLeftOption(And)
+ newPlan = Join(newPlan, newSub, ExistenceJoin(exists), newConditions)
exists
}
}
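
SPARK-26078 is easiest to see with an IN subquery over the same table: the subquery shares attribute ids with the outer plan, so without the extra aliasing the rewritten LeftSemi join condition could degenerate into a comparison of an attribute with itself (id#3 = id#3) and become trivially true. A hedged sketch of a query shape that now goes through dedupSubqueryOnSelfJoin (table, columns and data are illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[1]").appName("self-in-subquery").getOrCreate()
    import spark.implicits._

    Seq((1, 10), (2, 20), (3, 30)).toDF("id", "v").createOrReplaceTempView("t")

    // The subquery reads the same table; its `id` gets aliased before the
    // LeftSemi join is built, so the filter keeps only ids 2 and 3.
    spark.sql("SELECT * FROM t WHERE id IN (SELECT id FROM t WHERE v > 15)").show()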
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 672bffcfc0cad..8959f78b656d2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -1338,9 +1338,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
*/
override def visitLambda(ctx: LambdaContext): Expression = withOrigin(ctx) {
val arguments = ctx.IDENTIFIER().asScala.map { name =>
- UnresolvedAttribute.quoted(name.getText)
+ UnresolvedNamedLambdaVariable(UnresolvedAttribute.quoted(name.getText).nameParts)
}
- LambdaFunction(expression(ctx.expression), arguments)
+ val function = expression(ctx.expression).transformUp {
+ case a: UnresolvedAttribute => UnresolvedNamedLambdaVariable(a.nameParts)
+ }
+ LambdaFunction(function, arguments)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 5e78aabc480bf..51e0f4b4c84dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -93,7 +93,7 @@ abstract class LogicalPlan
/**
* Optionally resolves the given strings to a [[NamedExpression]] using the input from all child
* nodes of this LogicalPlan. The attribute is expressed as
- * as string in the following form: `[scope].AttributeName.[nested].[fields]...`.
+ * string in the following form: `[scope].AttributeName.[nested].[fields]...`.
*/
def resolveChildren(
nameParts: Seq[String],
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index cc1a5e835d9cd..17e1cb416fc8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -22,13 +22,11 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
/**
* Specifies how tuples that share common expressions will be distributed when a query is executed
- * in parallel on many machines. Distribution can be used to refer to two distinct physical
- * properties:
- * - Inter-node partitioning of data: In this case the distribution describes how tuples are
- * partitioned across physical machines in a cluster. Knowing this property allows some
- * operators (e.g., Aggregate) to perform partition local operations instead of global ones.
- * - Intra-partition ordering of data: In this case the distribution describes guarantees made
- * about how tuples are distributed within a single partition.
+ * in parallel on many machines.
+ *
+ * Distribution here refers to inter-node partitioning of data. That is, it describes how tuples
+ * are partitioned across physical machines in a cluster. Knowing this property allows some
+ * operators (e.g., Aggregate) to perform partition local operations instead of global ones.
*/
sealed trait Distribution {
/**
@@ -70,9 +68,7 @@ case object AllTuples extends Distribution {
/**
* Represents data where tuples that share the same values for the `clustering`
- * [[Expression Expressions]] will be co-located. Based on the context, this
- * can mean such tuples are either co-located in the same partition or they will be contiguous
- * within a single partition.
+ * [[Expression Expressions]] will be co-located in the same partition.
*/
case class ClusteredDistribution(
clustering: Seq[Expression],
@@ -118,10 +114,12 @@ case class HashClusteredDistribution(
/**
* Represents data where tuples have been ordered according to the `ordering`
- * [[Expression Expressions]]. This is a strictly stronger guarantee than
- * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the
- * same value for the ordering expressions are contiguous and will never be split across
- * partitions.
+ * [[Expression Expressions]]. Its requirement is defined as follows:
+ * - Given any 2 adjacent partitions, all the rows of the second partition must be larger than or
+ * equal to any row in the first partition, according to the `ordering` expressions.
+ *
+ * In other words, this distribution requires the rows to be ordered across partitions, but not
+ * necessarily within a partition.
*/
case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
require(
@@ -241,12 +239,12 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
/**
* Represents a partitioning where rows are split across partitions based on some total ordering of
- * the expressions specified in `ordering`. When data is partitioned in this manner the following
- * two conditions are guaranteed to hold:
- * - All row where the expressions in `ordering` evaluate to the same values will be in the same
- * partition.
- * - Each partition will have a `min` and `max` row, relative to the given ordering. All rows
- * that are in between `min` and `max` in this `ordering` will reside in this partition.
+ * the expressions specified in `ordering`. When data is partitioned in this manner, it guarantees:
+ * Given any 2 adjacent partitions, all the rows of the second partition must be larger than any row
+ * in the first partition, according to the `ordering` expressions.
+ *
+ * This is a strictly stronger guarantee than what `OrderedDistribution(ordering)` requires, as
+ * there is no overlap between partitions.
*
* This class extends expression primarily so that transformations over expression will descend
* into its child.
@@ -262,6 +260,22 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
super.satisfies0(required) || {
required match {
case OrderedDistribution(requiredOrdering) =>
+ // If `ordering` is a prefix of `requiredOrdering`:
+ // Let's say `ordering` is [a, b] and `requiredOrdering` is [a, b, c]. According to the
+ // RangePartitioning definition, any [a, b] in a previous partition must be smaller
+ // than any [a, b] in the following partition. This also means any [a, b, c] in a
+ // previous partition must be smaller than any [a, b, c] in the following partition.
+ // Thus `RangePartitioning(a, b)` satisfies `OrderedDistribution(a, b, c)`.
+ //
+ // If `requiredOrdering` is a prefix of `ordering`:
+ // Let's say `ordering` is [a, b, c] and `requiredOrdering` is [a, b]. According to the
+ // RangePartitioning definition, any [a, b, c] in a previous partition must be smaller
+ // than any [a, b, c] in the following partition. If there is a [a1, b1] from a previous
+ // partition which is larger than a [a2, b2] from the following partition, then there
+ // must be a [a1, b1 c1] larger than [a2, b2, c2], which violates RangePartitioning
+ // definition. So it's guaranteed that, any [a, b] in a previous partition must not be
+ // greater(i.e. smaller or equal to) than any [a, b] in the following partition. Thus
+ // `RangePartitioning(a, b, c)` satisfies `OrderedDistribution(a, b)`.
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering, _) =>
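
The prefix rule can be exercised directly through satisfies: an ordering that is a prefix of the required ordering (or vice versa) preserves the cross-partition ordering guarantee, while a different leading column does not. A minimal sketch with catalyst expressions (attribute names are illustrative):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.physical.{OrderedDistribution, RangePartitioning}

    val a = 'a.int
    val b = 'b.int
    val c = 'c.int
    val partitioning = RangePartitioning(Seq(a.asc, b.asc), numPartitions = 10)

    partitioning.satisfies(OrderedDistribution(Seq(a.asc, b.asc, c.asc))) // true: [a, b] is a prefix
    partitioning.satisfies(OrderedDistribution(Seq(a.asc)))               // true: [a] is a prefix
    partitioning.satisfies(OrderedDistribution(Seq(b.asc)))               // false: different leading column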
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
index e7cd61655dc9a..98934368205ec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
@@ -21,6 +21,7 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.array.ByteArrayMethods
/**
* A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes
@@ -54,6 +55,10 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
val index = keyToIndex.getOrDefault(key, -1)
if (index == -1) {
+ if (size >= ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw new RuntimeException(s"Unsuccessful attempt to build maps with $size elements " +
+ s"due to exceeding the map size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
+ }
keyToIndex.put(key, values.length)
keys.append(key)
values.append(value)
@@ -117,4 +122,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
build()
}
}
+
+ /**
+ * Returns the current size of the map to be produced by this builder.
+ */
+ def size: Int = keys.size
}
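
Because the limit check now lives in the builder, every expression built on top of it (map_concat, map_from_entries, transform_keys, ...) fails consistently once the number of distinct keys would exceed MAX_ROUNDED_ARRAY_LENGTH. A hedged sketch of driving the builder directly, also showing the last-wins handling of duplicate keys:

    import org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder
    import org.apache.spark.sql.types.{IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    val builder = new ArrayBasedMapBuilder(IntegerType, StringType)
    builder.put(1, UTF8String.fromString("a"))
    builder.put(1, UTF8String.fromString("b")) // duplicate key: value replaced, size stays 1
    assert(builder.size == 1)

    // Every *new* key goes through the size check first; once `size` reaches
    // ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH, put throws the RuntimeException
    // shown in the hunk above.
    val map = builder.build()
    assert(map.numElements() == 1)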
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala
index 985f0dc1cd60e..d719a33929fcc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala
@@ -20,6 +20,16 @@ package org.apache.spark.sql.catalyst.util
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String
+/**
+ * Exception thrown when the underlying parser returns a partial result of parsing.
+ * @param partialResult the partial result of parsing a bad record.
+ * @param cause the actual exception about why the parser cannot return full result.
+ */
+case class PartialResultException(
+ partialResult: InternalRow,
+ cause: Throwable)
+ extends Exception(cause)
+
/**
* Exception thrown when the underlying parser meet a bad record and can't parse it.
* @param record a function to return the record that cause the parser to fail
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala
new file mode 100644
index 0000000000000..9e8d51cc65f03
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateFormatter.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.time.{Instant, ZoneId}
+import java.util.Locale
+
+import scala.util.Try
+
+import org.apache.commons.lang3.time.FastDateFormat
+
+import org.apache.spark.sql.internal.SQLConf
+
+sealed trait DateFormatter {
+ def parse(s: String): Int // returns days since epoch
+ def format(days: Int): String
+}
+
+class Iso8601DateFormatter(
+ pattern: String,
+ locale: Locale) extends DateFormatter with DateTimeFormatterHelper {
+
+ private val formatter = buildFormatter(pattern, locale)
+ private val UTC = ZoneId.of("UTC")
+
+ private def toInstant(s: String): Instant = {
+ val temporalAccessor = formatter.parse(s)
+ toInstantWithZoneId(temporalAccessor, UTC)
+ }
+
+ override def parse(s: String): Int = {
+ val seconds = toInstant(s).getEpochSecond
+ val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY)
+ days.toInt
+ }
+
+ override def format(days: Int): String = {
+ val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY)
+ formatter.withZone(UTC).format(instant)
+ }
+}
+
+class LegacyDateFormatter(pattern: String, locale: Locale) extends DateFormatter {
+ private val format = FastDateFormat.getInstance(pattern, locale)
+
+ override def parse(s: String): Int = {
+ val milliseconds = format.parse(s).getTime
+ DateTimeUtils.millisToDays(milliseconds)
+ }
+
+ override def format(days: Int): String = {
+ val date = DateTimeUtils.toJavaDate(days)
+ format.format(date)
+ }
+}
+
+class LegacyFallbackDateFormatter(
+ pattern: String,
+ locale: Locale) extends LegacyDateFormatter(pattern, locale) {
+ override def parse(s: String): Int = {
+ Try(super.parse(s)).orElse {
+ // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
+ // compatibility.
+ Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(s).getTime))
+ }.getOrElse {
+ // In Spark 1.5.0, we store the data as number of days since epoch in string.
+ // So, we just convert it to Int.
+ s.toInt
+ }
+ }
+}
+
+object DateFormatter {
+ def apply(format: String, locale: Locale): DateFormatter = {
+ if (SQLConf.get.legacyTimeParserEnabled) {
+ new LegacyFallbackDateFormatter(format, locale)
+ } else {
+ new Iso8601DateFormatter(format, locale)
+ }
+ }
+}
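
Which implementation the DateFormatter factory returns depends on SQLConf.get.legacyTimeParserEnabled. When the legacy flag is on, the fallback chain above mirrors the behaviour the JSON parser had before this change. A hedged sketch of that chain, instantiating the fallback class directly and assuming the JVM default time zone lines up with the dates shown:

    import java.util.Locale
    import org.apache.spark.sql.catalyst.util.LegacyFallbackDateFormatter

    val legacy = new LegacyFallbackDateFormatter("yyyy-MM-dd", Locale.US)

    legacy.parse("2018-12-02") // FastDateFormat succeeds -> 17867 days since the epoch
    legacy.parse("17867")      // both parsers fail -> Spark 1.5 fallback: the string
                               // is already a day count and is returned as an Int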
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala
deleted file mode 100644
index ad1f4131de2f6..0000000000000
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatter.scala
+++ /dev/null
@@ -1,179 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.util
-
-import java.time._
-import java.time.format.DateTimeFormatterBuilder
-import java.time.temporal.{ChronoField, TemporalQueries}
-import java.util.{Locale, TimeZone}
-
-import scala.util.Try
-
-import org.apache.commons.lang3.time.FastDateFormat
-
-import org.apache.spark.sql.internal.SQLConf
-
-sealed trait DateTimeFormatter {
- def parse(s: String): Long // returns microseconds since epoch
- def format(us: Long): String
-}
-
-class Iso8601DateTimeFormatter(
- pattern: String,
- timeZone: TimeZone,
- locale: Locale) extends DateTimeFormatter {
- val formatter = new DateTimeFormatterBuilder()
- .appendPattern(pattern)
- .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970)
- .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1)
- .parseDefaulting(ChronoField.DAY_OF_MONTH, 1)
- .parseDefaulting(ChronoField.HOUR_OF_DAY, 0)
- .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0)
- .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0)
- .toFormatter(locale)
-
- def toInstant(s: String): Instant = {
- val temporalAccessor = formatter.parse(s)
- if (temporalAccessor.query(TemporalQueries.offset()) == null) {
- val localDateTime = LocalDateTime.from(temporalAccessor)
- val zonedDateTime = ZonedDateTime.of(localDateTime, timeZone.toZoneId)
- Instant.from(zonedDateTime)
- } else {
- Instant.from(temporalAccessor)
- }
- }
-
- private def instantToMicros(instant: Instant): Long = {
- val sec = Math.multiplyExact(instant.getEpochSecond, DateTimeUtils.MICROS_PER_SECOND)
- val result = Math.addExact(sec, instant.getNano / DateTimeUtils.NANOS_PER_MICROS)
- result
- }
-
- def parse(s: String): Long = instantToMicros(toInstant(s))
-
- def format(us: Long): String = {
- val secs = Math.floorDiv(us, DateTimeUtils.MICROS_PER_SECOND)
- val mos = Math.floorMod(us, DateTimeUtils.MICROS_PER_SECOND)
- val instant = Instant.ofEpochSecond(secs, mos * DateTimeUtils.NANOS_PER_MICROS)
-
- formatter.withZone(timeZone.toZoneId).format(instant)
- }
-}
-
-class LegacyDateTimeFormatter(
- pattern: String,
- timeZone: TimeZone,
- locale: Locale) extends DateTimeFormatter {
- val format = FastDateFormat.getInstance(pattern, timeZone, locale)
-
- protected def toMillis(s: String): Long = format.parse(s).getTime
-
- def parse(s: String): Long = toMillis(s) * DateTimeUtils.MICROS_PER_MILLIS
-
- def format(us: Long): String = {
- format.format(DateTimeUtils.toJavaTimestamp(us))
- }
-}
-
-class LegacyFallbackDateTimeFormatter(
- pattern: String,
- timeZone: TimeZone,
- locale: Locale) extends LegacyDateTimeFormatter(pattern, timeZone, locale) {
- override def toMillis(s: String): Long = {
- Try {super.toMillis(s)}.getOrElse(DateTimeUtils.stringToTime(s).getTime)
- }
-}
-
-object DateTimeFormatter {
- def apply(format: String, timeZone: TimeZone, locale: Locale): DateTimeFormatter = {
- if (SQLConf.get.legacyTimeParserEnabled) {
- new LegacyFallbackDateTimeFormatter(format, timeZone, locale)
- } else {
- new Iso8601DateTimeFormatter(format, timeZone, locale)
- }
- }
-}
-
-sealed trait DateFormatter {
- def parse(s: String): Int // returns days since epoch
- def format(days: Int): String
-}
-
-class Iso8601DateFormatter(
- pattern: String,
- timeZone: TimeZone,
- locale: Locale) extends DateFormatter {
-
- val dateTimeFormatter = new Iso8601DateTimeFormatter(pattern, timeZone, locale)
-
- override def parse(s: String): Int = {
- val seconds = dateTimeFormatter.toInstant(s).getEpochSecond
- val days = Math.floorDiv(seconds, DateTimeUtils.SECONDS_PER_DAY)
-
- days.toInt
- }
-
- override def format(days: Int): String = {
- val instant = Instant.ofEpochSecond(days * DateTimeUtils.SECONDS_PER_DAY)
- dateTimeFormatter.formatter.withZone(timeZone.toZoneId).format(instant)
- }
-}
-
-class LegacyDateFormatter(
- pattern: String,
- timeZone: TimeZone,
- locale: Locale) extends DateFormatter {
- val format = FastDateFormat.getInstance(pattern, timeZone, locale)
-
- def parse(s: String): Int = {
- val milliseconds = format.parse(s).getTime
- DateTimeUtils.millisToDays(milliseconds)
- }
-
- def format(days: Int): String = {
- val date = DateTimeUtils.toJavaDate(days)
- format.format(date)
- }
-}
-
-class LegacyFallbackDateFormatter(
- pattern: String,
- timeZone: TimeZone,
- locale: Locale) extends LegacyDateFormatter(pattern, timeZone, locale) {
- override def parse(s: String): Int = {
- Try(super.parse(s)).orElse {
- // If it fails to parse, then tries the way used in 2.0 and 1.x for backwards
- // compatibility.
- Try(DateTimeUtils.millisToDays(DateTimeUtils.stringToTime(s).getTime))
- }.getOrElse {
- // In Spark 1.5.0, we store the data as number of days since epoch in string.
- // So, we just convert it to Int.
- s.toInt
- }
- }
-}
-
-object DateFormatter {
- def apply(format: String, timeZone: TimeZone, locale: Locale): DateFormatter = {
- if (SQLConf.get.legacyTimeParserEnabled) {
- new LegacyFallbackDateFormatter(format, timeZone, locale)
- } else {
- new Iso8601DateFormatter(format, timeZone, locale)
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
new file mode 100644
index 0000000000000..b85101d38d9e6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeFormatterHelper.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.time.{Instant, LocalDateTime, ZonedDateTime, ZoneId}
+import java.time.format.{DateTimeFormatter, DateTimeFormatterBuilder}
+import java.time.temporal.{ChronoField, TemporalAccessor}
+import java.util.Locale
+
+trait DateTimeFormatterHelper {
+
+ protected def buildFormatter(pattern: String, locale: Locale): DateTimeFormatter = {
+ new DateTimeFormatterBuilder()
+ .appendPattern(pattern)
+ .parseDefaulting(ChronoField.YEAR_OF_ERA, 1970)
+ .parseDefaulting(ChronoField.MONTH_OF_YEAR, 1)
+ .parseDefaulting(ChronoField.DAY_OF_MONTH, 1)
+ .parseDefaulting(ChronoField.HOUR_OF_DAY, 0)
+ .parseDefaulting(ChronoField.MINUTE_OF_HOUR, 0)
+ .parseDefaulting(ChronoField.SECOND_OF_MINUTE, 0)
+ .toFormatter(locale)
+ }
+
+ protected def toInstantWithZoneId(temporalAccessor: TemporalAccessor, zoneId: ZoneId): Instant = {
+ val localDateTime = LocalDateTime.from(temporalAccessor)
+ val zonedDateTime = ZonedDateTime.of(localDateTime, zoneId)
+ Instant.from(zonedDateTime)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
new file mode 100644
index 0000000000000..eb1303303463d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TimestampFormatter.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import java.time._
+import java.time.temporal.TemporalQueries
+import java.util.{Locale, TimeZone}
+
+import scala.util.Try
+
+import org.apache.commons.lang3.time.FastDateFormat
+
+import org.apache.spark.sql.internal.SQLConf
+
+sealed trait TimestampFormatter {
+ def parse(s: String): Long // returns microseconds since epoch
+ def format(us: Long): String
+}
+
+class Iso8601TimestampFormatter(
+ pattern: String,
+ timeZone: TimeZone,
+ locale: Locale) extends TimestampFormatter with DateTimeFormatterHelper {
+ private val formatter = buildFormatter(pattern, locale)
+
+ private def toInstant(s: String): Instant = {
+ val temporalAccessor = formatter.parse(s)
+ if (temporalAccessor.query(TemporalQueries.offset()) == null) {
+ toInstantWithZoneId(temporalAccessor, timeZone.toZoneId)
+ } else {
+ Instant.from(temporalAccessor)
+ }
+ }
+
+ private def instantToMicros(instant: Instant): Long = {
+ val sec = Math.multiplyExact(instant.getEpochSecond, DateTimeUtils.MICROS_PER_SECOND)
+ val result = Math.addExact(sec, instant.getNano / DateTimeUtils.NANOS_PER_MICROS)
+ result
+ }
+
+ override def parse(s: String): Long = instantToMicros(toInstant(s))
+
+ override def format(us: Long): String = {
+ val secs = Math.floorDiv(us, DateTimeUtils.MICROS_PER_SECOND)
+ val mos = Math.floorMod(us, DateTimeUtils.MICROS_PER_SECOND)
+ val instant = Instant.ofEpochSecond(secs, mos * DateTimeUtils.NANOS_PER_MICROS)
+
+ formatter.withZone(timeZone.toZoneId).format(instant)
+ }
+}
+
+class LegacyTimestampFormatter(
+ pattern: String,
+ timeZone: TimeZone,
+ locale: Locale) extends TimestampFormatter {
+ private val format = FastDateFormat.getInstance(pattern, timeZone, locale)
+
+ protected def toMillis(s: String): Long = format.parse(s).getTime
+
+ override def parse(s: String): Long = toMillis(s) * DateTimeUtils.MICROS_PER_MILLIS
+
+ override def format(us: Long): String = {
+ format.format(DateTimeUtils.toJavaTimestamp(us))
+ }
+}
+
+class LegacyFallbackTimestampFormatter(
+ pattern: String,
+ timeZone: TimeZone,
+ locale: Locale) extends LegacyTimestampFormatter(pattern, timeZone, locale) {
+ override def toMillis(s: String): Long = {
+ Try {super.toMillis(s)}.getOrElse(DateTimeUtils.stringToTime(s).getTime)
+ }
+}
+
+object TimestampFormatter {
+ def apply(format: String, timeZone: TimeZone, locale: Locale): TimestampFormatter = {
+ if (SQLConf.get.legacyTimeParserEnabled) {
+ new LegacyFallbackTimestampFormatter(format, timeZone, locale)
+ } else {
+ new Iso8601TimestampFormatter(format, timeZone, locale)
+ }
+ }
+}
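
The JacksonParser comment removed earlier warned that the old FastDateFormat path loses the microsecond part (SPARK-10681). The java.time-based formatter keeps it because the parsed Instant is converted with the arithmetic below; a small worked sketch of that conversion using only java.time:

    import java.time.Instant

    // Same arithmetic as Iso8601TimestampFormatter.instantToMicros: seconds are
    // scaled to microseconds and nanoseconds are truncated to microseconds.
    val instant = Instant.parse("2018-12-02T10:11:12.123456Z")
    val micros = Math.addExact(
      Math.multiplyExact(instant.getEpochSecond, 1000000L), // MICROS_PER_SECOND
      instant.getNano / 1000L)                              // NANOS_PER_MICROS
    // micros == 1543745472123456L, so the trailing 456 microseconds survive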
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 451b051f8407e..86e068bf632bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1396,6 +1396,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val VALIDATE_PARTITION_COLUMNS =
+ buildConf("spark.sql.sources.validatePartitionColumns")
+ .internal()
+ .doc("When this option is set to true, partition column values will be validated with " +
+ "user-specified schema. If the validation fails, a runtime exception is thrown." +
+ "When this option is set to false, the partition column value will be converted to null " +
+ "if it can not be casted to corresponding user-specified schema.")
+ .booleanConf
+ .createWithDefault(true)
+
val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE =
buildConf("spark.sql.streaming.continuous.executorQueueSize")
.internal()
@@ -1611,8 +1621,8 @@ object SQLConf {
.intConf
.createWithDefault(25)
- val SET_COMMAND_REJECTS_SPARK_CONFS =
- buildConf("spark.sql.legacy.execution.setCommandRejectsSparkConfs")
+ val SET_COMMAND_REJECTS_SPARK_CORE_CONFS =
+ buildConf("spark.sql.legacy.setCommandRejectsSparkCoreConfs")
.internal()
.doc("If it is set to true, SET command will fail when the key is registered as " +
"a SparkConf entry.")
@@ -2014,6 +2024,8 @@ class SQLConf extends Serializable with Logging {
def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)
+ def validatePartitionColumns: Boolean = getConf(VALIDATE_PARTITION_COLUMNS)
+
def partitionOverwriteMode: PartitionOverwriteMode.Value =
PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE))
@@ -2045,7 +2057,8 @@ class SQLConf extends Serializable with Logging {
def maxToStringFields: Int = getConf(SQLConf.MAX_TO_STRING_FIELDS)
- def setCommandRejectsSparkConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CONFS)
+ def setCommandRejectsSparkCoreConfs: Boolean =
+ getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS)
def legacyTimeParserEnabled: Boolean = getConf(SQLConf.LEGACY_TIME_PARSER_ENABLED)
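
A rough sketch of what the new `spark.sql.sources.validatePartitionColumns` flag controls when a read supplies an explicit schema (the path, column names, and partition value below are hypothetical):

    // Suppose a partition directory .../part=abc exists but the user schema declares part as INT.
    spark.conf.set("spark.sql.sources.validatePartitionColumns", "true")
    // spark.read.schema("id LONG, part INT").parquet("/tmp/tbl")   // would now fail at runtime
    spark.conf.set("spark.sql.sources.validatePartitionColumns", "false")
    // The same read succeeds, with part read back as null for that directory.
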
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
index c4171c75ecd03..a5847ba7c522d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveLambdaVariablesSuite.scala
@@ -49,19 +49,21 @@ class ResolveLambdaVariablesSuite extends PlanTest {
comparePlans(Analyzer.execute(plan(e1)), plan(e2))
}
+ private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
test("resolution - no op") {
checkExpression(key, key)
}
test("resolution - simple") {
- val in = ArrayTransform(values1, LambdaFunction('x.attr + 1, 'x.attr :: Nil))
+ val in = ArrayTransform(values1, LambdaFunction(lv('x) + 1, lv('x) :: Nil))
val out = ArrayTransform(values1, LambdaFunction(lvInt + 1, lvInt :: Nil))
checkExpression(in, out)
}
test("resolution - nested") {
val in = ArrayTransform(values2, LambdaFunction(
- ArrayTransform('x.attr, LambdaFunction('x.attr + 1, 'x.attr :: Nil)), 'x.attr :: Nil))
+ ArrayTransform(lv('x), LambdaFunction(lv('x) + 1, lv('x) :: Nil)), lv('x) :: Nil))
val out = ArrayTransform(values2, LambdaFunction(
ArrayTransform(lvArray, LambdaFunction(lvInt + 1, lvInt :: Nil)), lvArray :: Nil))
checkExpression(in, out)
@@ -75,14 +77,14 @@ class ResolveLambdaVariablesSuite extends PlanTest {
test("fail - name collisions") {
val p = plan(ArrayTransform(values1,
- LambdaFunction('x.attr + 'X.attr, 'x.attr :: 'X.attr :: Nil)))
+ LambdaFunction(lv('x) + lv('X), lv('x) :: lv('X) :: Nil)))
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
assert(msg.contains("arguments should not have names that are semantically the same"))
}
test("fail - lambda arguments") {
val p = plan(ArrayTransform(values1,
- LambdaFunction('x.attr + 'y.attr + 'z.attr, 'x.attr :: 'y.attr :: 'z.attr :: Nil)))
+ LambdaFunction(lv('x) + lv('y) + lv('z), lv('x) :: lv('y) :: lv('z) :: Nil)))
val msg = intercept[AnalysisException](Analyzer.execute(p)).getMessage
assert(msg.contains("does not match the number of arguments expected"))
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index b4fd170467d81..1c91adab71375 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.analysis.{ResolveTimeZone, SimpleAnalyzer}
+import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.PlanTestBase
@@ -70,7 +70,9 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
private def prepareEvaluation(expression: Expression): Expression = {
val serializer = new JavaSerializer(new SparkConf()).newInstance
val resolver = ResolveTimeZone(new SQLConf)
- resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression)))
+ val expr = resolver.resolveTimeZones(expression)
+ assert(expr.resolved)
+ serializer.deserialize(serializer.serialize(expr))
}
protected def checkEvaluation(
@@ -296,9 +298,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks with PlanTestBa
expected: Any,
inputRow: InternalRow = EmptyRow): Unit = {
val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation())
- // We should analyze the plan first, otherwise we possibly optimize an unresolved plan.
- val analyzedPlan = SimpleAnalyzer.execute(plan)
- val optimizedPlan = SimpleTestOptimizer.execute(analyzedPlan)
+ val optimizedPlan = SimpleTestOptimizer.execute(plan)
checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 5d60cefc13896..238e6e34b4ae5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.plans.PlanTestBase
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.internal.SQLConf
@@ -694,11 +694,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
val mapType2 = MapType(IntegerType, CalendarIntervalType)
val schema2 = StructType(StructField("a", mapType2) :: Nil)
val struct2 = Literal.create(null, schema2)
- intercept[TreeNodeException[_]] {
- checkEvaluation(
- StructsToJson(Map.empty, struct2, gmtId),
- null
- )
+ StructsToJson(Map.empty, struct2, gmtId).checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("Unable to convert column a of type calendarinterval to JSON"))
+ case _ => fail("from_json should not work on interval map value type.")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
index 8818d0135b297..b7ce367230810 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala
@@ -160,7 +160,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow)
}
- test("Coalesce should not throw 64kb exception") {
+ test("Coalesce should not throw 64KiB exception") {
val inputs = (1 to 2500).map(x => Literal(s"x_$x"))
checkEvaluation(Coalesce(inputs), "x_1")
}
@@ -171,7 +171,7 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(ctx.inlinedMutableStates.size == 1)
}
- test("AtLeastNNonNulls should not throw 64kb exception") {
+ test("AtLeastNNonNulls should not throw 64KiB exception") {
val inputs = (1 to 4000).map(x => Literal(s"x_$x"))
checkEvaluation(AtLeastNNonNulls(1, inputs), true)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
index d0604b8eb7675..94e251d90bcfa 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala
@@ -128,7 +128,7 @@ class OrderingSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
- test("SPARK-16845: GeneratedClass$SpecificOrdering grows beyond 64 KB") {
+ test("SPARK-16845: GeneratedClass$SpecificOrdering grows beyond 64 KiB") {
val sortOrder = Literal("abc").asc
// this is passing prior to SPARK-16845, and it should also be passing after SPARK-16845
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 0f63717f9daf2..3541afcd2144d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -24,6 +24,7 @@ import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
@@ -231,22 +232,12 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
testWithRandomDataGeneration(structType, nullable)
}
- // Map types: not supported
- for (
- keyType <- atomicTypes;
- valueType <- atomicTypes;
- nullable <- Seq(true, false)) {
- val mapType = MapType(keyType, valueType)
- val e = intercept[Exception] {
- testWithRandomDataGeneration(mapType, nullable)
- }
- if (e.getMessage.contains("Code generation of")) {
- // If the `value` expression is null, `eval` will be short-circuited.
- // Codegen version evaluation will be run then.
- assert(e.getMessage.contains("cannot generate equality code for un-comparable type"))
- } else {
- assert(e.getMessage.contains("Exception evaluating"))
- }
+ // In doesn't support map types and will fail analysis.
+ val map = Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType))
+ In(map, Seq(map)).checkInputDataTypes() match {
+ case TypeCheckResult.TypeCheckFailure(msg) =>
+ assert(msg.contains("function in does not support ordering on type map"))
+ case _ => fail("In should not work on map type")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index aa334e040d5fc..e95f2dff231b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -744,16 +744,14 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
test("ParseUrl") {
def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = {
- checkEvaluation(
- ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected)
+ checkEvaluation(ParseUrl(Seq(urlStr, partToExtract)), expected)
}
def checkParseUrlWithKey(
expected: String,
urlStr: String,
partToExtract: String,
key: String): Unit = {
- checkEvaluation(
- ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected)
+ checkEvaluation(ParseUrl(Seq(urlStr, partToExtract, key)), expected)
}
checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST")
@@ -798,7 +796,6 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Sentences(nullString, nullString, nullString), null)
checkEvaluation(Sentences(nullString, nullString), null)
checkEvaluation(Sentences(nullString), null)
- checkEvaluation(Sentences(Literal.create(null, NullType)), null)
checkEvaluation(Sentences("", nullString, nullString), Seq.empty)
checkEvaluation(Sentences("", nullString), Seq.empty)
checkEvaluation(Sentences(""), Seq.empty)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
index fb651b76fc16d..22e1fa6dfed4f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/UnsafeRowWriterSuite.scala
@@ -50,4 +50,24 @@ class UnsafeRowWriterSuite extends SparkFunSuite {
assert(res1 == res2)
}
+ test("SPARK-26021: normalize float/double NaN and -0.0") {
+ val unsafeRowWriter1 = new UnsafeRowWriter(4)
+ unsafeRowWriter1.resetRowWriter()
+ unsafeRowWriter1.write(0, Float.NaN)
+ unsafeRowWriter1.write(1, Double.NaN)
+ unsafeRowWriter1.write(2, 0.0f)
+ unsafeRowWriter1.write(3, 0.0)
+ val res1 = unsafeRowWriter1.getRow
+
+ val unsafeRowWriter2 = new UnsafeRowWriter(4)
+ unsafeRowWriter2.resetRowWriter()
+ unsafeRowWriter2.write(0, 0.0f/0.0f)
+ unsafeRowWriter2.write(1, 0.0/0.0)
+ unsafeRowWriter2.write(2, -0.0f)
+ unsafeRowWriter2.write(3, -0.0)
+ val res2 = unsafeRowWriter2.getRow
+
+ // The two rows should be equal
+ assert(res1 == res2)
+ }
}
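
The new test rests on the premise that NaN has multiple bit patterns and that -0.0 and 0.0 compare equal while differing at the bit level; a small standalone sketch of that premise (not part of the patch):

    import java.lang.{Double => JDouble, Float => JFloat}

    // -0.0f and 0.0f are == but their raw bits differ, and 0.0/0.0 is a NaN whose canonical
    // bits match Double.NaN's. UnsafeRowWriter normalizes both cases before writing, which is
    // why the two rows in the test above must end up binary-equal.
    assert(-0.0f == 0.0f)
    assert(JFloat.floatToRawIntBits(-0.0f) != JFloat.floatToRawIntBits(0.0f))
    assert(JDouble.doubleToLongBits(0.0 / 0.0) == JDouble.doubleToLongBits(Double.NaN))
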
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
new file mode 100644
index 0000000000000..9307f9b47b807
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JsonInferSchemaSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.json
+
+import com.fasterxml.jackson.core.JsonFactory
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+class JsonInferSchemaSuite extends SparkFunSuite with SQLHelper {
+
+ def checkType(options: Map[String, String], json: String, dt: DataType): Unit = {
+ val jsonOptions = new JSONOptions(options, "UTC", "")
+ val inferSchema = new JsonInferSchema(jsonOptions)
+ val factory = new JsonFactory()
+ jsonOptions.setJacksonOptions(factory)
+ val parser = CreateJacksonParser.string(factory, json)
+ parser.nextToken()
+ val expectedType = StructType(Seq(StructField("a", dt, true)))
+
+ assert(inferSchema.inferField(parser) === expectedType)
+ }
+
+ def checkTimestampType(pattern: String, json: String): Unit = {
+ checkType(Map("timestampFormat" -> pattern), json, TimestampType)
+ }
+
+ test("inferring timestamp type") {
+ Seq(true, false).foreach { legacyParser =>
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) {
+ checkTimestampType("yyyy", """{"a": "2018"}""")
+ checkTimestampType("yyyy=MM", """{"a": "2018=12"}""")
+ checkTimestampType("yyyy MM dd", """{"a": "2018 12 02"}""")
+ checkTimestampType(
+ "yyyy-MM-dd'T'HH:mm:ss.SSS",
+ """{"a": "2018-12-02T21:04:00.123"}""")
+ checkTimestampType(
+ "yyyy-MM-dd'T'HH:mm:ss.SSSSSSXXX",
+ """{"a": "2018-12-02T21:04:00.123567+01:00"}""")
+ }
+ }
+ }
+
+ test("prefer decimals over timestamps") {
+ Seq(true, false).foreach { legacyParser =>
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) {
+ checkType(
+ options = Map(
+ "prefersDecimal" -> "true",
+ "timestampFormat" -> "yyyyMMdd.HHmmssSSS"
+ ),
+ json = """{"a": "20181202.210400123"}""",
+ dt = DecimalType(17, 9)
+ )
+ }
+ }
+ }
+
+ test("skip decimal type inferring") {
+ Seq(true, false).foreach { legacyParser =>
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) {
+ checkType(
+ options = Map(
+ "prefersDecimal" -> "false",
+ "timestampFormat" -> "yyyyMMdd.HHmmssSSS"
+ ),
+ json = """{"a": "20181202.210400123"}""",
+ dt = TimestampType
+ )
+ }
+ }
+ }
+
+ test("fallback to string type") {
+ Seq(true, false).foreach { legacyParser =>
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) {
+ checkType(
+ options = Map("timestampFormat" -> "yyyy,MM,dd.HHmmssSSS"),
+ json = """{"a": "20181202.210400123"}""",
+ dt = StringType
+ )
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 8d7c9bf220bc2..57195d5fda7c5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -34,6 +34,7 @@ class ColumnPruningSuite extends PlanTest {
val batches = Batch("Column pruning", FixedPoint(100),
PushDownPredicate,
ColumnPruning,
+ RemoveNoopOperators,
CollapseProject) :: Nil
}
@@ -340,10 +341,8 @@ class ColumnPruningSuite extends PlanTest {
test("Column pruning on Union") {
val input1 = LocalRelation('a.int, 'b.string, 'c.double)
val input2 = LocalRelation('c.int, 'd.string, 'e.double)
- val query = Project('b :: Nil,
- Union(input1 :: input2 :: Nil)).analyze
- val expected = Project('b :: Nil,
- Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
+ val query = Project('b :: Nil, Union(input1 :: input2 :: Nil)).analyze
+ val expected = Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil).analyze
comparePlans(Optimize.execute(query), expected)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
index ef4b848924f06..b190dd5a7c220 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -27,8 +27,9 @@ class CombiningLimitsSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
- Batch("Filter Pushdown", FixedPoint(100),
- ColumnPruning) ::
+ Batch("Column Pruning", FixedPoint(100),
+ ColumnPruning,
+ RemoveNoopOperators) ::
Batch("Combine Limit", FixedPoint(10),
CombineLimits) ::
Batch("Constant Folding", FixedPoint(10),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index ccd9d8dd4d213..6fe5e619d03ad 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -39,6 +39,7 @@ class JoinOptimizationSuite extends PlanTest {
ReorderJoin,
PushPredicateThroughJoin,
ColumnPruning,
+ RemoveNoopOperators,
CollapseProject) :: Nil
}
@@ -102,16 +103,19 @@ class JoinOptimizationSuite extends PlanTest {
x.join(y).join(z).where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
x.join(z, condition = Some("x.b".attr === "z.b".attr))
.join(y, condition = Some("y.d".attr === "z.a".attr))
+ .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
),
(
x.join(y, Cross).join(z, Cross)
.where(("x.b".attr === "z.b".attr) && ("y.d".attr === "z.a".attr)),
x.join(z, Cross, Some("x.b".attr === "z.b".attr))
.join(y, Cross, Some("y.d".attr === "z.a".attr))
+ .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
),
(
x.join(y, Inner).join(z, Cross).where("x.b".attr === "z.a".attr),
x.join(z, Cross, Some("x.b".attr === "z.a".attr)).join(y, Inner)
+ .select(Seq("x.a", "x.b", "x.c", "y.d", "z.a", "z.b", "z.c").map(_.attr): _*)
)
)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
index 565b0a10154a8..c94a8b9e318f6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap}
-import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.{Cross, Inner, PlanTest}
+import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan}
import org.apache.spark.sql.internal.SQLConf.{CBO_ENABLED, JOIN_REORDER_ENABLED}
@@ -124,7 +124,8 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
// the original order (t1 J t2) J t3.
val bestPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
- .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .select(outputsOf(t1, t2, t3): _*)
assertEqualPlans(originalPlan, bestPlan)
}
@@ -139,7 +140,9 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
val bestPlan =
t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")))
.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
+ .select(outputsOf(t1, t2, t3): _*) // this is redundant but we'll take it for now
.join(t4)
+ .select(outputsOf(t1, t2, t4, t3): _*)
assertEqualPlans(originalPlan, bestPlan)
}
@@ -202,6 +205,7 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")))
.join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))),
Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")))
+ .select(outputsOf(t1, t4, t2, t3): _*)
assertEqualPlans(originalPlan, bestPlan)
}
@@ -219,6 +223,23 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
}
}
+ test("SPARK-26352: join reordering should not change the order of attributes") {
+ // This test case does not rely on CBO.
+ // It's similar to the test case above, but catches a reordering bug that the one above doesn't
+ val tab1 = LocalRelation('x.int, 'y.int)
+ val tab2 = LocalRelation('i.int, 'j.int)
+ val tab3 = LocalRelation('a.int, 'b.int)
+ val original =
+ tab1.join(tab2, Cross)
+ .join(tab3, Inner, Some('a === 'x && 'b === 'i))
+ val expected =
+ tab1.join(tab3, Inner, Some('a === 'x))
+ .join(tab2, Cross, Some('b === 'i))
+ .select(outputsOf(tab1, tab2, tab3): _*)
+
+ assertEqualPlans(original, expected)
+ }
+
test("reorder recursively") {
// Original order:
// Join
@@ -266,8 +287,17 @@ class JoinReorderSuite extends PlanTest with StatsEstimationTestBase {
private def assertEqualPlans(
originalPlan: LogicalPlan,
groundTruthBestPlan: LogicalPlan): Unit = {
- val optimized = Optimize.execute(originalPlan.analyze)
+ val analyzed = originalPlan.analyze
+ val optimized = Optimize.execute(analyzed)
val expected = groundTruthBestPlan.analyze
+
+ assert(analyzed.sameOutput(expected)) // if this fails, the expected plan itself is incorrect
+ assert(analyzed.sameOutput(optimized))
+
compareJoinOrder(optimized, expected)
}
+
+ private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
+ plans.map(_.output).reduce(_ ++ _)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
index 1973b5abb462d..3802dbf5d6e06 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantAliasAndProjectSuite.scala
@@ -33,7 +33,7 @@ class RemoveRedundantAliasAndProjectSuite extends PlanTest with PredicateHelper
FixedPoint(50),
PushProjectionThroughUnion,
RemoveRedundantAliases,
- RemoveRedundantProject) :: Nil
+ RemoveNoopOperators) :: Nil
}
test("all expressions in project list are aliased child output") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
index ee0d04da3e46c..748075bfd6a68 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or}
+import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable}
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
@@ -306,22 +306,24 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest {
testProjection(originalExpr = column, expectedExpr = column)
}
+ private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
test("replace nulls in lambda function of ArrayFilter") {
- testHigherOrderFunc('a, ArrayFilter, Seq('e))
+ testHigherOrderFunc('a, ArrayFilter, Seq(lv('e)))
}
test("replace nulls in lambda function of ArrayExists") {
- testHigherOrderFunc('a, ArrayExists, Seq('e))
+ testHigherOrderFunc('a, ArrayExists, Seq(lv('e)))
}
test("replace nulls in lambda function of MapFilter") {
- testHigherOrderFunc('m, MapFilter, Seq('k, 'v))
+ testHigherOrderFunc('m, MapFilter, Seq(lv('k), lv('v)))
}
test("inability to replace nulls in arbitrary higher-order function") {
val lambdaFunc = LambdaFunction(
- function = If('e > 0, Literal(null, BooleanType), TrueLiteral),
- arguments = Seq[NamedExpression]('e))
+ function = If(lv('e) > 0, Literal(null, BooleanType), TrueLiteral),
+ arguments = Seq[NamedExpression](lv('e)))
val column = ArrayTransform('a, lambdaFunc)
testProjection(originalExpr = column, expectedExpr = column)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
index 3b1b2d588ef67..c8e15c7da763e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceOperatorSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, Not}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Coalesce, If, Literal, Not}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
class ReplaceOperatorSuite extends PlanTest {
@@ -65,8 +66,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA >= 2 && attributeB < 1)),
+ Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
comparePlans(optimized, correctAnswer)
@@ -84,8 +84,8 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA >= 2 && attributeB < 1)), table1)).analyze
+ Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
+ table1)).analyze
comparePlans(optimized, correctAnswer)
}
@@ -104,8 +104,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA >= 2 && attributeB < 1)),
+ Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
Project(Seq(attributeA, attributeB), table1))).analyze
comparePlans(optimized, correctAnswer)
@@ -125,8 +124,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA >= 2 && attributeB < 1)),
+ Filter(Not(Coalesce(Seq(attributeA >= 2 && attributeB < 1, Literal.FalseLiteral))),
Filter(attributeB === 2, Filter(attributeA === 1, table1)))).analyze
comparePlans(optimized, correctAnswer)
@@ -146,8 +144,7 @@ class ReplaceOperatorSuite extends PlanTest {
val correctAnswer =
Aggregate(table1.output, table1.output,
- Filter(Not((attributeA.isNotNull && attributeB.isNotNull) &&
- (attributeA === 1 && attributeB === 2)),
+ Filter(Not(Coalesce(Seq(attributeA === 1 && attributeB === 2, Literal.FalseLiteral))),
Project(Seq(attributeA, attributeB),
Filter(attributeB < 1, Filter(attributeA >= 2, table1))))).analyze
@@ -229,4 +226,29 @@ class ReplaceOperatorSuite extends PlanTest {
comparePlans(optimized, query)
}
+
+ test("SPARK-26366: ReplaceExceptWithFilter should handle properly NULL") {
+ val basePlan = LocalRelation(Seq('a.int, 'b.int))
+ val otherPlan = basePlan.where('a.in(1, 2) || 'b.in())
+ val except = Except(basePlan, otherPlan, false)
+ val result = OptimizeIn(Optimize.execute(except.analyze))
+ val correctAnswer = Aggregate(basePlan.output, basePlan.output,
+ Filter(!Coalesce(Seq(
+ 'a.in(1, 2) || If('b.isNotNull, Literal.FalseLiteral, Literal(null, BooleanType)),
+ Literal.FalseLiteral)),
+ basePlan)).analyze
+ comparePlans(result, correctAnswer)
+ }
+
+ test("SPARK-26366: ReplaceExceptWithFilter should not transform non-detrministic") {
+ val basePlan = LocalRelation(Seq('a.int, 'b.int))
+ val otherPlan = basePlan.where('a > rand(1L))
+ val except = Except(basePlan, otherPlan, false)
+ val result = Optimize.execute(except.analyze)
+ val condition = basePlan.output.zip(otherPlan.output).map { case (a1, a2) =>
+ a1 <=> a2 }.reduce( _ && _)
+ val correctAnswer = Aggregate(basePlan.output, otherPlan.output,
+ Join(basePlan, otherPlan, LeftAnti, Option(condition))).analyze
+ comparePlans(result, correctAnswer)
+ }
}
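
A short DataFrame-level illustration (spark-shell style, made-up data) of the NULL case the first new test pins down:

    import spark.implicits._
    val base  = Seq((1, Some(1)), (2, None: Option[Int])).toDF("a", "b")
    val other = base.where($"b" === 1)   // the predicate evaluates to NULL where b is null
    // (2, null) is in `base` but not in `other`, so EXCEPT must return it. Rewriting the
    // EXCEPT as Filter(Not(b = 1)) alone would evaluate to NULL and drop the row; wrapping
    // the condition in Coalesce(cond, false), as the correctAnswer plans above assert, keeps it.
    base.except(other).show()
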
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
index 6b3739c372c3a..f00d22e6e96a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala
@@ -34,7 +34,7 @@ class RewriteSubquerySuite extends PlanTest {
RewritePredicateSubquery,
ColumnPruning,
CollapseProject,
- RemoveRedundantProject) :: Nil
+ RemoveNoopOperators) :: Nil
}
test("Column pruning after rewriting predicate subquery") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
index d4d23ad69b2c2..baae934e1e4fe 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinCostBasedReorderSuite.scala
@@ -218,6 +218,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk")))
.join(t2, Inner, Some(nameToAttr("f1_c2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("f1_c1") === nameToAttr("t1_c1")))
+ .select(outputsOf(f1, t1, t2, d1, d2): _*)
assertEqualPlans(query, expected)
}
@@ -256,6 +257,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
.join(t3.join(t2, Inner, Some(nameToAttr("t2_c2") === nameToAttr("t3_c1"))), Inner,
Some(nameToAttr("d1_c2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("t1_c1") === nameToAttr("f1_c1")))
+ .select(outputsOf(d1, t1, t2, f1, d2, t3): _*)
assertEqualPlans(query, expected)
}
@@ -297,6 +299,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
Some(nameToAttr("t3_c1") === nameToAttr("t4_c1")))
.join(t1.join(t2, Inner, Some(nameToAttr("t1_c1") === nameToAttr("t2_c1"))), Inner,
Some(nameToAttr("t1_c2") === nameToAttr("t4_c2")))
+ .select(outputsOf(d1, t1, t2, t3, t4, f1, d2): _*)
assertEqualPlans(query, expected)
}
@@ -347,6 +350,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
Some(nameToAttr("d3_c2") === nameToAttr("t1_c1")))
.join(t5.join(t6, Inner, Some(nameToAttr("t5_c2") === nameToAttr("t6_c2"))), Inner,
Some(nameToAttr("d2_c2") === nameToAttr("t5_c1")))
+ .select(outputsOf(d1, t3, t4, f1, d2, t5, t6, d3, t1, t2): _*)
assertEqualPlans(query, expected)
}
@@ -375,6 +379,7 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
f1.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk")))
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk")))
+ .select(outputsOf(d1, d2, f1, d3): _*)
assertEqualPlans(query, expected)
}
@@ -400,13 +405,27 @@ class StarJoinCostBasedReorderSuite extends PlanTest with StatsEstimationTestBas
f1.join(t3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("t3_c1")))
.join(t2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("t2_c1")))
.join(t1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("t1_c1")))
+ .select(outputsOf(t1, f1, t2, t3): _*)
assertEqualPlans(query, expected)
}
private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
- val optimized = Optimize.execute(plan1.analyze)
+ val analyzed = plan1.analyze
+ val optimized = Optimize.execute(analyzed)
val expected = plan2.analyze
+
+ assert(equivalentOutput(analyzed, expected)) // if this fails, the expected plan itself is incorrect
+ assert(equivalentOutput(analyzed, optimized))
+
compareJoinOrder(optimized, expected)
}
+
+ private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
+ plans.map(_.output).reduce(_ ++ _)
+ }
+
+ private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+ normalizeExprIds(plan1).output == normalizeExprIds(plan2).output
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
index 4e0883e91e84a..9dc653b9d6c44 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/StarJoinReorderSuite.scala
@@ -182,6 +182,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d1, Inner, Some(nameToAttr("f1_fk1") === nameToAttr("d1_pk1")))
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, d2, f1, d3, s3): _*)
assertEqualPlans(query, expected)
}
@@ -220,6 +221,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, f1, d2, s3, d3): _*)
assertEqualPlans(query, expected)
}
@@ -255,7 +257,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("s3_c2")))
-
+ .select(outputsOf(d1, f1, d2, s3, d3): _*)
assertEqualPlans(query, expected)
}
@@ -292,6 +294,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_c2")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, f1, d2, s3, d3): _*)
assertEqualPlans(query, expected)
}
@@ -395,6 +398,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("f11_fk1") === nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, f11, f1, d2, s3): _*)
assertEqualPlans(query, equivQuery)
}
@@ -430,6 +434,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_c4")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, d3, f1, d2, s3): _*)
assertEqualPlans(query, expected)
}
@@ -465,6 +470,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2), Inner,
Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, d3, f1, d2, s3): _*)
assertEqualPlans(query, expected)
}
@@ -499,6 +505,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d2.where(nameToAttr("d2_c2") === 2),
Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, d3, f1, d2, s3): _*)
assertEqualPlans(query, expected)
}
@@ -532,6 +539,7 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") < nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") < nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") < nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, d3, f1, d2, s3): _*)
assertEqualPlans(query, expected)
}
@@ -565,13 +573,27 @@ class StarJoinReorderSuite extends PlanTest with StatsEstimationTestBase {
.join(d3, Inner, Some(nameToAttr("f1_fk3") === nameToAttr("d3_pk1")))
.join(d2, Inner, Some(nameToAttr("f1_fk2") === nameToAttr("d2_pk1")))
.join(s3, Inner, Some(nameToAttr("d3_fk1") === nameToAttr("s3_pk1")))
+ .select(outputsOf(d1, d3, f1, d2, s3): _*)
assertEqualPlans(query, expected)
}
- private def assertEqualPlans( plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
- val optimized = Optimize.execute(plan1.analyze)
+ private def assertEqualPlans(plan1: LogicalPlan, plan2: LogicalPlan): Unit = {
+ val analyzed = plan1.analyze
+ val optimized = Optimize.execute(analyzed)
val expected = plan2.analyze
+
+ assert(equivalentOutput(analyzed, expected)) // if this fails, the expected plan itself is incorrect
+ assert(equivalentOutput(analyzed, optimized))
+
compareJoinOrder(optimized, expected)
}
+
+ private def outputsOf(plans: LogicalPlan*): Seq[Attribute] = {
+ plans.map(_.output).reduce(_ ++ _)
+ }
+
+ private def equivalentOutput(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = {
+ normalizeExprIds(plan1).output == normalizeExprIds(plan2).output
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
index 58b3d1c98f3cd..4acd57832d2f6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransposeWindowSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
class TransposeWindowSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
- Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveRedundantProject) ::
+ Batch("CollapseProject", FixedPoint(100), CollapseProject, RemoveNoopOperators) ::
Batch("FlipWindow", Once, CollapseWindow, TransposeWindow) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index b4df22c5b29fa..8bcc69d580d83 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -246,9 +246,11 @@ class ExpressionParserSuite extends PlanTest {
intercept("foo(a x)", "extraneous input 'x'")
}
+ private def lv(s: Symbol) = UnresolvedNamedLambdaVariable(Seq(s.name))
+
test("lambda functions") {
- assertEqual("x -> x + 1", LambdaFunction('x + 1, Seq('x.attr)))
- assertEqual("(x, y) -> x + y", LambdaFunction('x + 'y, Seq('x.attr, 'y.attr)))
+ assertEqual("x -> x + 1", LambdaFunction(lv('x) + 1, Seq(lv('x))))
+ assertEqual("(x, y) -> x + y", LambdaFunction(lv('x) + lv('y), Seq(lv('x), lv('y))))
}
test("window function expressions") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala
new file mode 100644
index 0000000000000..019615b81101c
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateFormatterSuite.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.util
+
+import java.util.Locale
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.internal.SQLConf
+
+class DateFormatterSuite extends SparkFunSuite with SQLHelper {
+ test("parsing dates") {
+ DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
+ val formatter = DateFormatter("yyyy-MM-dd", Locale.US)
+ val daysSinceEpoch = formatter.parse("2018-12-02")
+ assert(daysSinceEpoch === 17867)
+ }
+ }
+ }
+
+ test("format dates") {
+ DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
+ val formatter = DateFormatter("yyyy-MM-dd", Locale.US)
+ val date = formatter.format(17867)
+ assert(date === "2018-12-02")
+ }
+ }
+ }
+
+ test("roundtrip date -> days -> date") {
+ Seq(
+ "0050-01-01",
+ "0953-02-02",
+ "1423-03-08",
+ "1969-12-31",
+ "1972-08-25",
+ "1975-09-26",
+ "2018-12-12",
+ "2038-01-01",
+ "5010-11-17").foreach { date =>
+ DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
+ val formatter = DateFormatter("yyyy-MM-dd", Locale.US)
+ val days = formatter.parse(date)
+ val formatted = formatter.format(days)
+ assert(date === formatted)
+ }
+ }
+ }
+ }
+
+ test("roundtrip days -> date -> days") {
+ Seq(
+ -701265,
+ -371419,
+ -199722,
+ -1,
+ 0,
+ 967,
+ 2094,
+ 17877,
+ 24837,
+ 1110657).foreach { days =>
+ DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
+ withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> timeZone) {
+ val formatter = DateFormatter("yyyy-MM-dd", Locale.US)
+ val date = formatter.format(days)
+ val parsed = formatter.parse(date)
+ assert(days === parsed)
+ }
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
similarity index 60%
rename from sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala
rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
index 02d4ee0490604..c110ffa01f733 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/DateTimeFormatterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/TimestampFormatterSuite.scala
@@ -20,26 +20,10 @@ package org.apache.spark.sql.util
import java.util.{Locale, TimeZone}
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeFormatter, DateTimeTestUtils}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, TimestampFormatter}
-class DateTimeFormatterSuite extends SparkFunSuite {
- test("parsing dates using time zones") {
- val localDate = "2018-12-02"
- val expectedDays = Map(
- "UTC" -> 17867,
- "PST" -> 17867,
- "CET" -> 17866,
- "Africa/Dakar" -> 17867,
- "America/Los_Angeles" -> 17867,
- "Antarctica/Vostok" -> 17866,
- "Asia/Hong_Kong" -> 17866,
- "Europe/Amsterdam" -> 17866)
- DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
- val formatter = DateFormatter("yyyy-MM-dd", TimeZone.getTimeZone(timeZone), Locale.US)
- val daysSinceEpoch = formatter.parse(localDate)
- assert(daysSinceEpoch === expectedDays(timeZone))
- }
- }
+class TimestampFormatterSuite extends SparkFunSuite with SQLHelper {
test("parsing timestamps using time zones") {
val localDate = "2018-12-02T10:11:12.001234"
@@ -53,7 +37,7 @@ class DateTimeFormatterSuite extends SparkFunSuite {
"Asia/Hong_Kong" -> 1543716672001234L,
"Europe/Amsterdam" -> 1543741872001234L)
DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
- val formatter = DateTimeFormatter(
+ val formatter = TimestampFormatter(
"yyyy-MM-dd'T'HH:mm:ss.SSSSSS",
TimeZone.getTimeZone(timeZone),
Locale.US)
@@ -62,24 +46,6 @@ class DateTimeFormatterSuite extends SparkFunSuite {
}
}
- test("format dates using time zones") {
- val daysSinceEpoch = 17867
- val expectedDate = Map(
- "UTC" -> "2018-12-02",
- "PST" -> "2018-12-01",
- "CET" -> "2018-12-02",
- "Africa/Dakar" -> "2018-12-02",
- "America/Los_Angeles" -> "2018-12-01",
- "Antarctica/Vostok" -> "2018-12-02",
- "Asia/Hong_Kong" -> "2018-12-02",
- "Europe/Amsterdam" -> "2018-12-02")
- DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
- val formatter = DateFormatter("yyyy-MM-dd", TimeZone.getTimeZone(timeZone), Locale.US)
- val date = formatter.format(daysSinceEpoch)
- assert(date === expectedDate(timeZone))
- }
- }
-
test("format timestamps using time zones") {
val microsSinceEpoch = 1543745472001234L
val expectedTimestamp = Map(
@@ -92,7 +58,7 @@ class DateTimeFormatterSuite extends SparkFunSuite {
"Asia/Hong_Kong" -> "2018-12-02T18:11:12.001234",
"Europe/Amsterdam" -> "2018-12-02T11:11:12.001234")
DateTimeTestUtils.outstandingTimezonesIds.foreach { timeZone =>
- val formatter = DateTimeFormatter(
+ val formatter = TimestampFormatter(
"yyyy-MM-dd'T'HH:mm:ss.SSSSSS",
TimeZone.getTimeZone(timeZone),
Locale.US)
@@ -100,4 +66,44 @@ class DateTimeFormatterSuite extends SparkFunSuite {
assert(timestamp === expectedTimestamp(timeZone))
}
}
+
+ test("roundtrip micros -> timestamp -> micros using timezones") {
+ Seq(
+ -58710115316212000L,
+ -18926315945345679L,
+ -9463427405253013L,
+ -244000001L,
+ 0L,
+ 99628200102030L,
+ 1543749753123456L,
+ 2177456523456789L,
+ 11858049903010203L).foreach { micros =>
+ DateTimeTestUtils.outstandingTimezones.foreach { timeZone =>
+ val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone, Locale.US)
+ val timestamp = formatter.format(micros)
+ val parsed = formatter.parse(timestamp)
+ assert(micros === parsed)
+ }
+ }
+ }
+
+ test("roundtrip timestamp -> micros -> timestamp using timezones") {
+ Seq(
+ "0109-07-20T18:38:03.788000",
+ "1370-04-01T10:00:54.654321",
+ "1670-02-11T14:09:54.746987",
+ "1969-12-31T23:55:55.999999",
+ "1970-01-01T00:00:00.000000",
+ "1973-02-27T02:30:00.102030",
+ "2018-12-02T11:22:33.123456",
+ "2039-01-01T01:02:03.456789",
+ "2345-10-07T22:45:03.010203").foreach { timestamp =>
+ DateTimeTestUtils.outstandingTimezones.foreach { timeZone =>
+ val formatter = TimestampFormatter("yyyy-MM-dd'T'HH:mm:ss.SSSSSS", timeZone, Locale.US)
+ val micros = formatter.parse(timestamp)
+ val formatted = formatter.format(micros)
+ assert(timestamp === formatted)
+ }
+ }
+ }
}
diff --git a/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt b/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt
new file mode 100644
index 0000000000000..338244ad542f4
--- /dev/null
+++ b/sql/core/benchmarks/HashedRelationMetricsBenchmark-results.txt
@@ -0,0 +1,11 @@
+================================================================================================
+LongToUnsafeRowMap metrics
+================================================================================================
+
+Java HotSpot(TM) 64-Bit Server VM 1.8.0_181-b13 on Mac OS X 10.13.6
+Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz
+LongToUnsafeRowMap metrics: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------
+LongToUnsafeRowMap 234 / 315 2.1 467.3 1.0X
+
+
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index c8cf44b51df77..7e76a651ba2cb 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -98,7 +98,7 @@ public UnsafeFixedWidthAggregationMap(
this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema);
this.groupingKeySchema = groupingKeySchema;
this.map = new BytesToBytesMap(
- taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes, true);
+ taskContext.taskMemoryManager(), initialCapacity, pageSizeBytes);
// Initialize the buffer for aggregation value
final UnsafeProjection valueProjection = UnsafeProjection.create(aggregationBufferSchema);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java
index 0df89dbb608a4..6c5a95d2a75b7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java
@@ -24,10 +24,10 @@
/**
* An empty mix-in interface for {@link Table}, to indicate this table supports batch scan.
*
- * If a {@link Table} implements this interface, its {@link Table#newScanBuilder(DataSourceOptions)}
- * must return a {@link ScanBuilder} that builds {@link Scan} with {@link Scan#toBatch()}
- * implemented.
+ * If a {@link Table} implements this interface, the
+ * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that
+ * builds a {@link Scan} with {@link Scan#toBatch()} implemented.
*
*/
@Evolving
-public interface SupportsBatchRead extends Table { }
+public interface SupportsBatchRead extends SupportsRead { }
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java
similarity index 52%
rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala
rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java
index e342110763196..e22738d20d507 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/HadoopSparkUserExecutorFeatureStep.scala
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java
@@ -14,21 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.deploy.k8s.features
-import org.apache.spark.deploy.k8s.{KubernetesExecutorConf, SparkPod}
-import org.apache.spark.deploy.k8s.Constants._
-import org.apache.spark.deploy.k8s.features.hadooputils.HadoopBootstrapUtil
+package org.apache.spark.sql.sources.v2;
+
+import org.apache.spark.sql.sources.v2.reader.Scan;
+import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
/**
- * This step is responsible for setting ENV_SPARK_USER when HADOOP_FILES are detected
- * however, this step would not be run if Kerberos is enabled, as Kerberos sets SPARK_USER
+ * An internal base interface for the mix-in interfaces of readable {@link Table}s. It adds
+ * {@link #newScanBuilder(DataSourceOptions)}, which is used to create a scan for batch,
+ * micro-batch, or continuous processing.
*/
-private[spark] class HadoopSparkUserExecutorFeatureStep(conf: KubernetesExecutorConf)
- extends KubernetesFeatureConfigStep {
+interface SupportsRead extends Table {
- override def configurePod(pod: SparkPod): SparkPod = {
- val sparkUserName = conf.get(KERBEROS_SPARK_USER_NAME)
- HadoopBootstrapUtil.bootstrapSparkUserPod(sparkUserName, pod)
- }
+ /**
+ * Returns a {@link ScanBuilder} which can be used to build a {@link Scan}. Spark will call this
+ * method to configure each scan.
+ */
+ ScanBuilder newScanBuilder(DataSourceOptions options);
}
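
A minimal stand-alone sketch of how the mix-ins now compose (hypothetical simplified types, not the real `org.apache.spark.sql.sources.v2` classes): `SupportsRead` carries `newScanBuilder`, while `SupportsBatchRead` stays an empty capability marker.

```scala
trait Scan { def toBatch(): String }
trait ScanBuilder { def build(): Scan }
trait Table { def name: String }
trait SupportsRead extends Table { def newScanBuilder(options: Map[String, String]): ScanBuilder }
trait SupportsBatchRead extends SupportsRead   // empty marker: the built Scan supports toBatch()

// A batch-readable table implements name() plus newScanBuilder(); the marker adds no methods.
class CsvTable(path: String) extends SupportsBatchRead {
  override def name: String = s"csv:`$path`"
  override def newScanBuilder(options: Map[String, String]): ScanBuilder =
    new ScanBuilder {
      override def build(): Scan = new Scan {
        override def toBatch(): String = s"batch scan of $path"
      }
    }
}
```
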
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java
index 0c65fe0f9e76a..08664859b8de2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java
@@ -18,8 +18,6 @@
package org.apache.spark.sql.sources.v2;
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.sources.v2.reader.Scan;
-import org.apache.spark.sql.sources.v2.reader.ScanBuilder;
import org.apache.spark.sql.types.StructType;
/**
@@ -43,17 +41,8 @@ public interface Table {
String name();
/**
- * Returns the schema of this table.
+ * Returns the schema of this table. If the table is not readable and doesn't have a schema, an
+ * empty schema can be returned here.
*/
StructType schema();
-
- /**
- * Returns a {@link ScanBuilder} which can be used to build a {@link Scan} later. Spark will call
- * this method for each data scanning query.
- *
- * The builder can take some query specific information to do operators pushdown, and keep these
- * information in the created {@link Scan}.
- *
- */
- ScanBuilder newScanBuilder(DataSourceOptions options);
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 661fe98d8c901..9751528654ffb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -362,7 +362,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* during parsing.
*
* - `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a
- * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To
+ * field configured by `columnNameOfCorruptRecord`, and sets malformed fields to `null`. To
* keep corrupt records, an user can set a string type field named
* `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the
* field, it drops corrupt records during parsing. When inferring a schema, it implicitly
@@ -598,13 +598,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* during parsing. It supports the following case-insensitive modes.
*
* - `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a
- * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep
- * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
- * in an user-defined schema. If a schema does not have the field, it drops corrupt records
- * during parsing. A record with less/more tokens than schema is not a corrupted record to
- * CSV. When it meets a record having fewer tokens than the length of the schema, sets
- * `null` to extra fields. When the record has more tokens than the length of the schema,
- * it drops extra tokens.
+ * field configured by `columnNameOfCorruptRecord`, and sets malformed fields to `null`.
+ * To keep corrupt records, a user can set a string type field named
+ * `columnNameOfCorruptRecord` in a user-defined schema. If a schema does not have
+ * the field, it drops corrupt records during parsing. A record with fewer or more tokens
+ * than the schema is not a corrupted record to CSV. When it meets a record having fewer
+ * tokens than the length of the schema, it sets the remaining fields to `null`. When the
+ * record has more tokens than the length of the schema, it drops the extra tokens.
* - `DROPMALFORMED` : ignores the whole corrupted records.
* - `FAILFAST` : throws an exception when it meets corrupted records.
*
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b10d66dfb1aef..a664c7338badb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql
-import java.io.CharArrayWriter
+import java.io.{CharArrayWriter, DataOutputStream}
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.language.implicitConversions
import scala.util.control.NonFatal
@@ -3200,34 +3201,38 @@ class Dataset[T] private[sql](
val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
withAction("collectAsArrowToPython", queryExecution) { plan =>
- PythonRDD.serveToStream("serve-Arrow") { out =>
+ PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+ val out = new DataOutputStream(outputStream)
val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
val arrowBatchRdd = toArrowBatchRdd(plan)
val numPartitions = arrowBatchRdd.partitions.length
- // Store collection results for worst case of 1 to N-1 partitions
- val results = new Array[Array[Array[Byte]]](numPartitions - 1)
- var lastIndex = -1 // index of last partition written
+ // Batches ordered by (index of partition, batch index in that partition) tuple
+ val batchOrder = new ArrayBuffer[(Int, Int)]()
+ var partitionCount = 0
- // Handler to eagerly write partitions to Python in order
+ // Handler to eagerly write batches to Python as they arrive, out of order
def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
- // If result is from next partition in order
- if (index - 1 == lastIndex) {
+ if (arrowBatches.nonEmpty) {
+ // Write all batches (can be more than 1) in the partition, store the batch order tuple
batchWriter.writeBatches(arrowBatches.iterator)
- lastIndex += 1
- // Write stored partitions that come next in order
- while (lastIndex < results.length && results(lastIndex) != null) {
- batchWriter.writeBatches(results(lastIndex).iterator)
- results(lastIndex) = null
- lastIndex += 1
+ arrowBatches.indices.foreach {
+ partition_batch_index => batchOrder.append((index, partition_batch_index))
}
- // After last batch, end the stream
- if (lastIndex == results.length) {
- batchWriter.end()
+ }
+ partitionCount += 1
+
+ // After last batch, end the stream and write batch order indices
+ if (partitionCount == numPartitions) {
+ batchWriter.end()
+ out.writeInt(batchOrder.length)
+ // Sort by (index of partition, batch index in that partition) tuple to get the
+ // overall_batch_index from 0 to N-1 batches, which can be used to put the
+ // transferred batches in the correct order
+ batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
+ out.writeInt(overall_batch_index)
}
- } else {
- // Store partitions received out of order
- results(index - 1) = arrowBatches
+ out.flush()
}
}
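
A stand-alone sketch of the reordering trick in this hunk (hypothetical driver/receiver split, no Arrow involved): batches are streamed in arrival order, and only at the end does the driver send, for each batch in its logical (partition, batch-in-partition) order, the position at which it arrived.

```scala
import scala.collection.mutable.ArrayBuffer

// Record (partitionIndex, batchIndexInPartition) in the order batches are streamed out.
val batchOrder = ArrayBuffer[(Int, Int)]()
def onPartition(partitionIndex: Int, numBatches: Int): Unit =
  (0 until numBatches).foreach(b => batchOrder += ((partitionIndex, b)))

onPartition(2, 1); onPartition(0, 2); onPartition(1, 1)   // partitions arrive out of order

// For each batch in logical order, emit the position at which it arrived; the receiver uses
// these indices to permute its arrival-ordered buffer back into partition order.
val arrivalPositions = batchOrder.zipWithIndex.sortBy(_._1).map(_._2)
// arrivalPositions == ArrayBuffer(1, 2, 3, 0)
```
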
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
index d83a01ff9ea65..0f5aab7f47d0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RuntimeConfig.scala
@@ -153,7 +153,7 @@ class RuntimeConfig private[sql](sqlConf: SQLConf = new SQLConf) {
if (SQLConf.staticConfKeys.contains(key)) {
throw new AnalysisException(s"Cannot modify the value of a static config: $key")
}
- if (sqlConf.setCommandRejectsSparkConfs &&
+ if (sqlConf.setCommandRejectsSparkCoreConfs &&
ConfigEntry.findEntry(key) != null && !SQLConf.sqlConfEntries.containsKey(key)) {
throw new AnalysisException(s"Cannot modify the value of a Spark config: $key")
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index d642402c63310..1d7dd73706c48 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.fs.{BlockLocation, FileStatus, LocatedFileStatus, Path}
@@ -167,19 +167,26 @@ case class FileSourceScanExec(
partitionSchema = relation.partitionSchema,
relation.sparkSession.sessionState.conf)
+ val driverMetrics: HashMap[String, Long] = HashMap.empty
+
+ /**
+ * Send the driver-side metrics. Before calling this function, `selectedPartitions` must have
+ * been initialized. See SPARK-26327 for more details.
+ */
+ private def sendDriverMetrics(): Unit = {
+ driverMetrics.foreach(e => metrics(e._1).add(e._2))
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ SQLMetrics.postDriverMetricUpdates(sparkContext, executionId,
+ metrics.filter(e => driverMetrics.contains(e._1)).values.toSeq)
+ }
+
@transient private lazy val selectedPartitions: Seq[PartitionDirectory] = {
val optimizerMetadataTimeNs = relation.location.metadataOpsTimeNs.getOrElse(0L)
val startTime = System.nanoTime()
val ret = relation.location.listFiles(partitionFilters, dataFilters)
+ driverMetrics("numFiles") = ret.map(_.files.size.toLong).sum
val timeTakenMs = ((System.nanoTime() - startTime) + optimizerMetadataTimeNs) / 1000 / 1000
-
- metrics("numFiles").add(ret.map(_.files.size.toLong).sum)
- metrics("metadataTime").add(timeTakenMs)
-
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- SQLMetrics.postDriverMetricUpdates(sparkContext, executionId,
- metrics("numFiles") :: metrics("metadataTime") :: Nil)
-
+ driverMetrics("metadataTime") = timeTakenMs
ret
}
@@ -301,12 +308,14 @@ case class FileSourceScanExec(
options = relation.options,
hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))
- relation.bucketSpec match {
+ val readRDD = relation.bucketSpec match {
case Some(bucketing) if relation.sparkSession.sessionState.conf.bucketingEnabled =>
createBucketedReadRDD(bucketing, readFile, selectedPartitions, relation)
case _ =>
createNonBucketedReadRDD(readFile, selectedPartitions, relation)
}
+ sendDriverMetrics()
+ readRDD
}
override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -316,7 +325,7 @@ case class FileSourceScanExec(
override lazy val metrics =
Map("numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"numFiles" -> SQLMetrics.createMetric(sparkContext, "number of files"),
- "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time (ms)"),
+ "metadataTime" -> SQLMetrics.createMetric(sparkContext, "metadata time"),
"scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time"))
protected override def doExecute(): RDD[InternalRow] = {
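
The metadata-metrics change above follows a buffer-then-flush pattern. A hedged stand-alone sketch of the same idea (hypothetical `ScanNode` and `post` callback standing in for the SQLMetrics machinery):

```scala
import scala.collection.mutable

class ScanNode(listFiles: () => Seq[String], post: (String, Long) => Unit) {
  private val driverMetrics = mutable.HashMap.empty[String, Long]

  // Expensive driver-side work; metric values are parked in the local map while it runs.
  lazy val selectedFiles: Seq[String] = {
    val start = System.nanoTime()
    val files = listFiles()
    driverMetrics("numFiles") = files.size.toLong
    driverMetrics("metadataTime") = (System.nanoTime() - start) / 1000000
    files
  }

  // Only called after selectedFiles has been forced, so every buffered value is present.
  private def sendDriverMetrics(): Unit = driverMetrics.foreach { case (k, v) => post(k, v) }

  def inputRDD: Seq[String] = {
    val files = selectedFiles   // forces the lazy val, filling driverMetrics
    sendDriverMetrics()
    files
  }
}
```
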
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 6c493909645de..981ecae80a724 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -18,54 +18,14 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Encoder, Row, SparkSession}
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
+import org.apache.spark.sql.{Encoder, SparkSession}
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.DataType
-
-object RDDConversions {
- def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = {
- data.mapPartitions { iterator =>
- val numColumns = outputTypes.length
- val mutableRow = new GenericInternalRow(numColumns)
- val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter)
- iterator.map { r =>
- var i = 0
- while (i < numColumns) {
- mutableRow(i) = converters(i)(r.productElement(i))
- i += 1
- }
-
- mutableRow
- }
- }
- }
-
- /**
- * Convert the objects inside Row into the types Catalyst expected.
- */
- def rowToRowRdd(data: RDD[Row], outputTypes: Seq[DataType]): RDD[InternalRow] = {
- data.mapPartitions { iterator =>
- val numColumns = outputTypes.length
- val mutableRow = new GenericInternalRow(numColumns)
- val converters = outputTypes.map(CatalystTypeConverters.createToCatalystConverter)
- iterator.map { r =>
- var i = 0
- while (i < numColumns) {
- mutableRow(i) = converters(i)(r(i))
- i += 1
- }
-
- mutableRow
- }
- }
- }
-}
object ExternalRDD {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
index 9b05faaed0459..079ff25fcb67e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala
@@ -22,7 +22,7 @@ import java.util.Arrays
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleMetricsReporter}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter}
/**
* The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition
@@ -157,9 +157,9 @@ class ShuffledRowRDD(
override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
val shuffledRowPartition = split.asInstanceOf[ShuffledRowRDDPartition]
val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
- // `SQLShuffleMetricsReporter` will update its own metrics for SQL exchange operator,
+ // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
// as well as the `tempMetrics` for basic shuffle metrics.
- val sqlMetricsReporter = new SQLShuffleMetricsReporter(tempMetrics, metrics)
+ val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
// The range of pre-shuffle partitions that we are fetching at here is
// [startPreShufflePartitionIndex, endPreShufflePartitionIndex - 1].
val reader =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 9d9b020309d9f..a89ccca99d059 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -423,11 +423,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
}
}
-object SparkPlan {
- private[execution] val subqueryExecutionContext = ExecutionContext.fromExecutorService(
- ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
-}
-
trait LeafExecNode extends SparkPlan {
override final def children: Seq[SparkPlan] = Nil
override def producedAttributes: AttributeSet = outputSet
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index e1faecedd20ed..096481f68275d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -820,6 +820,14 @@ object DDLUtils {
table.provider.isDefined && table.provider.get.toLowerCase(Locale.ROOT) != HIVE_PROVIDER
}
+ def readHiveTable(table: CatalogTable): HiveTableRelation = {
+ HiveTableRelation(
+ table,
+ // Hive table columns are always nullable.
+ table.dataSchema.asNullable.toAttributes,
+ table.partitionSchema.asNullable.toAttributes)
+ }
+
/**
* Throws a standard error for actions that require partitionProvider = hive.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 795a6d0b6b040..fefff68c4ba8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -122,21 +122,14 @@ case class DataSource(
* be any further inference in any triggers.
*
* @param format the file format object for this DataSource
- * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list
+ * @param getFileIndex a function that returns an [[InMemoryFileIndex]], used for getting the
+ *                     partition schema and file list
* @return A pair of the data schema (excluding partition columns) and the schema of the partition
* columns.
*/
private def getOrInferFileFormatSchema(
format: FileFormat,
- fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = {
- // The operations below are expensive therefore try not to do them if we don't need to, e.g.,
- // in streaming mode, we have already inferred and registered partition columns, we will
- // never have to materialize the lazy val below
- lazy val tempFileIndex = fileIndex.getOrElse {
- val globbedPaths =
- checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false)
- createInMemoryFileIndex(globbedPaths)
- }
+ getFileIndex: () => InMemoryFileIndex): (StructType, StructType) = {
+ lazy val tempFileIndex = getFileIndex()
val partitionSchema = if (partitionColumns.isEmpty) {
// Try to infer partitioning, because no DataSource in the read path provides the partitioning
@@ -236,7 +229,15 @@ case class DataSource(
"you may be able to create a static DataFrame on that directory with " +
"'spark.read.load(directory)' and infer schema from it.")
}
- val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format)
+
+ val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, () => {
+ // The operations below are expensive therefore try not to do them if we don't need to,
+ // e.g., in streaming mode, we have already inferred and registered partition columns,
+ // we will never have to materialize the lazy val below
+ val globbedPaths =
+ checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false)
+ createInMemoryFileIndex(globbedPaths)
+ })
SourceInfo(
s"FileSource[$path]",
StructType(dataSchema ++ partitionSchema),
@@ -370,7 +371,7 @@ case class DataSource(
} else {
val index = createInMemoryFileIndex(globbedPaths)
val (resultDataSchema, resultPartitionSchema) =
- getOrInferFileFormatSchema(format, Some(index))
+ getOrInferFileFormatSchema(format, () => index)
(index, resultDataSchema, resultPartitionSchema)
}
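
A hypothetical, simplified version of what the thunk buys here: the call site decides how the index would be built, but the `lazy val` in the callee still defers (and possibly skips) the expensive construction, just like the old `Option` + `getOrElse` did.

```scala
def inferSchemas(
    userPartitionColumns: Seq[String],
    getFileIndex: () => Seq[String]): (Seq[String], Seq[String]) = {
  lazy val fileIndex = getFileIndex()   // never forced when the user already told us enough
  val partitionSchema =
    if (userPartitionColumns.nonEmpty) userPartitionColumns
    else fileIndex.filter(_.contains("="))   // pretend: infer from paths like "year=2018"
  (Seq("value"), partitionSchema)
}

// Streaming path: partition columns are already known, so the thunk is never evaluated.
inferSchemas(Seq("year"), () => { println("expensive glob + listing"); Seq("year=2018/part-0") })
```
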
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
index c6000442fae76..b5cf8c9515bfb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala
@@ -29,11 +29,11 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, Quali
import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTable, LogicalPlan, Project}
-import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan}
import org.apache.spark.sql.execution.command._
@@ -244,27 +244,19 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan]
})
}
- private def readHiveTable(table: CatalogTable): LogicalPlan = {
- HiveTableRelation(
- table,
- // Hive table columns are always nullable.
- table.dataSchema.asNullable.toAttributes,
- table.partitionSchema.asNullable.toAttributes)
- }
-
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _)
if DDLUtils.isDatasourceTable(tableMeta) =>
i.copy(table = readDataSourceTable(tableMeta))
case i @ InsertIntoTable(UnresolvedCatalogRelation(tableMeta), _, _, _, _) =>
- i.copy(table = readHiveTable(tableMeta))
+ i.copy(table = DDLUtils.readHiveTable(tableMeta))
case UnresolvedCatalogRelation(tableMeta) if DDLUtils.isDatasourceTable(tableMeta) =>
readDataSourceTable(tableMeta)
case UnresolvedCatalogRelation(tableMeta) =>
- readHiveTable(tableMeta)
+ DDLUtils.readHiveTable(tableMeta)
}
}
@@ -416,7 +408,10 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with
output: Seq[Attribute],
rdd: RDD[Row]): RDD[InternalRow] = {
if (relation.relation.needConversion) {
- execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType))
+ val converters = RowEncoder(StructType.fromAttributes(output))
+ rdd.mapPartitions { iterator =>
+ iterator.map(converters.toRow)
+ }
} else {
rdd.asInstanceOf[RDD[InternalRow]]
}
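
A small sketch of the conversion the change above now uses (same `RowEncoder`/`toRow` calls as in the hunk, at the version of Spark in this diff): an encoder built from the output schema turns external `Row`s into `InternalRow`s, replacing the hand-rolled converter loop that was removed from `ExistingRDD.scala`.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

val schema = StructType(Seq(StructField("id", IntegerType), StructField("name", StringType)))
val toInternal = RowEncoder(schema)              // ExpressionEncoder[Row] for this schema
val internalRow = toInternal.toRow(Row(1, "a"))  // external Row -> InternalRow (UTF8String "a")
```
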
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
index 7b0e4dbcc25f4..b2e4155e6f49e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala
@@ -127,13 +127,13 @@ abstract class PartitioningAwareFileIndex(
val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION)
.getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone)
- val caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis
PartitioningUtils.parsePartitions(
leafDirs,
typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled,
basePaths = basePaths,
userSpecifiedSchema = userSpecifiedSchema,
- caseSensitive = caseSensitive,
+ caseSensitive = sparkSession.sqlContext.conf.caseSensitiveAnalysis,
+ validatePartitionColumns = sparkSession.sqlContext.conf.validatePartitionColumns,
timeZoneId = timeZoneId)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index d66cb09bda0cc..6458b65466fb5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -26,12 +26,13 @@ import scala.util.Try
import org.apache.hadoop.fs.Path
-import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{Resolver, TypeCoercion}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal}
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
@@ -96,9 +97,10 @@ object PartitioningUtils {
basePaths: Set[Path],
userSpecifiedSchema: Option[StructType],
caseSensitive: Boolean,
+ validatePartitionColumns: Boolean,
timeZoneId: String): PartitionSpec = {
- parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema,
- caseSensitive, DateTimeUtils.getTimeZone(timeZoneId))
+ parsePartitions(paths, typeInference, basePaths, userSpecifiedSchema, caseSensitive,
+ validatePartitionColumns, DateTimeUtils.getTimeZone(timeZoneId))
}
private[datasources] def parsePartitions(
@@ -107,6 +109,7 @@ object PartitioningUtils {
basePaths: Set[Path],
userSpecifiedSchema: Option[StructType],
caseSensitive: Boolean,
+ validatePartitionColumns: Boolean,
timeZone: TimeZone): PartitionSpec = {
val userSpecifiedDataTypes = if (userSpecifiedSchema.isDefined) {
val nameToDataType = userSpecifiedSchema.get.fields.map(f => f.name -> f.dataType).toMap
@@ -121,7 +124,8 @@ object PartitioningUtils {
// First, we need to parse every partition's path and see if we can find partition values.
val (partitionValues, optDiscoveredBasePaths) = paths.map { path =>
- parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes, timeZone)
+ parsePartition(path, typeInference, basePaths, userSpecifiedDataTypes,
+ validatePartitionColumns, timeZone)
}.unzip
// We create pairs of (path -> path's partition value) here
@@ -203,6 +207,7 @@ object PartitioningUtils {
typeInference: Boolean,
basePaths: Set[Path],
userSpecifiedDataTypes: Map[String, DataType],
+ validatePartitionColumns: Boolean,
timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = {
val columns = ArrayBuffer.empty[(String, Literal)]
// Old Hadoop versions don't have `Path.isRoot`
@@ -224,7 +229,8 @@ object PartitioningUtils {
// Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1.
// Once we get the string, we try to parse it and find the partition column and value.
val maybeColumn =
- parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes, timeZone)
+ parsePartitionColumn(currentPath.getName, typeInference, userSpecifiedDataTypes,
+ validatePartitionColumns, timeZone)
maybeColumn.foreach(columns += _)
// Now, we determine if we should stop.
@@ -258,6 +264,7 @@ object PartitioningUtils {
columnSpec: String,
typeInference: Boolean,
userSpecifiedDataTypes: Map[String, DataType],
+ validatePartitionColumns: Boolean,
timeZone: TimeZone): Option[(String, Literal)] = {
val equalSignIndex = columnSpec.indexOf('=')
if (equalSignIndex == -1) {
@@ -272,10 +279,15 @@ object PartitioningUtils {
val literal = if (userSpecifiedDataTypes.contains(columnName)) {
// SPARK-26188: if user provides corresponding column schema, get the column value without
// inference, and then cast it as user specified data type.
- val columnValue = inferPartitionColumnValue(rawColumnValue, false, timeZone)
- val castedValue =
- Cast(columnValue, userSpecifiedDataTypes(columnName), Option(timeZone.getID)).eval()
- Literal.create(castedValue, userSpecifiedDataTypes(columnName))
+ val dataType = userSpecifiedDataTypes(columnName)
+ val columnValueLiteral = inferPartitionColumnValue(rawColumnValue, false, timeZone)
+ val columnValue = columnValueLiteral.eval()
+ val castedValue = Cast(columnValueLiteral, dataType, Option(timeZone.getID)).eval()
+ if (validatePartitionColumns && columnValue != null && castedValue == null) {
+ throw new RuntimeException(s"Failed to cast value `$columnValue` to `$dataType` " +
+ s"for partition column `$columnName`")
+ }
+ Literal.create(castedValue, dataType)
} else {
inferPartitionColumnValue(rawColumnValue, typeInference, timeZone)
}
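
A worked stand-alone illustration of the new check (hypothetical helper, not the Spark internals): the raw partition value is first inferred, then cast to the user-specified type; a non-null value that casts to null means the cast failed, which becomes an error when the new validatePartitionColumns flag is enabled.

```scala
import scala.util.Try

def castPartitionValue(raw: String, toInt: Boolean, validate: Boolean): Any = {
  val inferred: Any = raw
  val casted: Any = if (toInt) Try(raw.toInt).toOption.map(Int.box).orNull else raw
  if (validate && inferred != null && casted == null) {
    throw new RuntimeException(s"Failed to cast value `$raw` for a partition column")
  }
  casted
}

castPartitionValue("15", toInt = true, validate = true)    // 15
castPartitionValue("foo", toInt = true, validate = false)  // null: old, silent behaviour
// castPartitionValue("foo", toInt = true, validate = true) throws RuntimeException
```
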
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
index b46dfb94c133e..375cec597166c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala
@@ -35,6 +35,7 @@ import org.apache.spark.rdd.{BinaryFileRDD, RDD}
import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVInferSchema, CSVOptions, UnivocityParser}
+import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.text.TextFileFormat
import org.apache.spark.sql.types.StructType
@@ -135,7 +136,9 @@ object TextInputCSVDataSource extends CSVDataSource {
val parser = new CsvParser(parsedOptions.asParserSettings)
linesWithoutHeader.map(parser.parseLine)
}
- new CSVInferSchema(parsedOptions).infer(tokenRDD, header)
+ SQLExecution.withSQLConfPropagated(csv.sparkSession) {
+ new CSVInferSchema(parsedOptions).infer(tokenRDD, header)
+ }
case _ =>
// If the first line could not be read, just return the empty schema.
StructType(Nil)
@@ -208,7 +211,9 @@ object MultiLineCSVDataSource extends CSVDataSource {
encoding = parsedOptions.charset)
}
val sampled = CSVUtils.sample(tokenRDD, parsedOptions)
- new CSVInferSchema(parsedOptions).infer(sampled, header)
+ SQLExecution.withSQLConfPropagated(sparkSession) {
+ new CSVInferSchema(parsedOptions).infer(sampled, header)
+ }
case None =>
// If the first row could not be read, just return the empty schema.
StructType(Nil)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
index f7d8a9e1042d5..f4f139d180058 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVFileFormat.scala
@@ -189,5 +189,5 @@ private[csv] class CsvOutputWriter(
gen.write(row)
}
- override def close(): Unit = univocityGenerator.map(_.close())
+ override def close(): Unit = univocityGenerator.foreach(_.close())
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 3042133ee43aa..40f55e7068010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -190,5 +190,5 @@ private[json] class JsonOutputWriter(
gen.writeLineEnding()
}
- override def close(): Unit = jacksonGenerator.map(_.close())
+ override def close(): Unit = jacksonGenerator.foreach(_.close())
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 01948ab25d63c..0607f7b3c0d4a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -153,7 +153,7 @@ class TextOutputWriter(
private var outputStream: Option[OutputStream] = None
override def write(row: InternalRow): Unit = {
- val os = outputStream.getOrElse{
+ val os = outputStream.getOrElse {
val newStream = CodecStreams.createOutputStream(context, new Path(path))
outputStream = Some(newStream)
newStream
@@ -167,6 +167,6 @@ class TextOutputWriter(
}
override def close(): Unit = {
- outputStream.map(_.close())
+ outputStream.foreach(_.close())
}
}
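
The `map`-to-`foreach` changes in the three output writers above are purely about intent: `close()` is called only for its side effect, so `foreach` states that directly, while `map` would build a throwaway `Option[Unit]` and suggests a result that is never used.

```scala
val generator: Option[java.io.Closeable] = None
generator.foreach(_.close())   // side effect only, nothing allocated or returned
```
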
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 9a1fe1e0a328b..d7e20eed4cbc0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.sources.v2.writer._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{LongAccumulator, Utils}
/**
* Deprecated logical plan for writing data into data source v2. This is being replaced by more
@@ -47,6 +47,8 @@ case class WriteToDataSourceV2(writeSupport: BatchWriteSupport, query: LogicalPl
case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: SparkPlan)
extends UnaryExecNode {
+ var commitProgress: Option[StreamWriterCommitProgress] = None
+
override def child: SparkPlan = query
override def output: Seq[Attribute] = Nil
@@ -55,6 +57,7 @@ case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: Spark
val useCommitCoordinator = writeSupport.useCommitCoordinator
val rdd = query.execute()
val messages = new Array[WriterCommitMessage](rdd.partitions.length)
+ val totalNumRowsAccumulator = new LongAccumulator()
logInfo(s"Start processing data source write support: $writeSupport. " +
s"The input RDD has ${messages.length} partitions.")
@@ -65,15 +68,18 @@ case class WriteToDataSourceV2Exec(writeSupport: BatchWriteSupport, query: Spark
(context: TaskContext, iter: Iterator[InternalRow]) =>
DataWritingSparkTask.run(writerFactory, context, iter, useCommitCoordinator),
rdd.partitions.indices,
- (index, message: WriterCommitMessage) => {
- messages(index) = message
- writeSupport.onDataWriterCommit(message)
+ (index, result: DataWritingSparkTaskResult) => {
+ val commitMessage = result.writerCommitMessage
+ messages(index) = commitMessage
+ totalNumRowsAccumulator.add(result.numRows)
+ writeSupport.onDataWriterCommit(commitMessage)
}
)
logInfo(s"Data source write support $writeSupport is committing.")
writeSupport.commit(messages)
logInfo(s"Data source write support $writeSupport committed.")
+ commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value))
} catch {
case cause: Throwable =>
logError(s"Data source write support $writeSupport is aborting.")
@@ -102,7 +108,7 @@ object DataWritingSparkTask extends Logging {
writerFactory: DataWriterFactory,
context: TaskContext,
iter: Iterator[InternalRow],
- useCommitCoordinator: Boolean): WriterCommitMessage = {
+ useCommitCoordinator: Boolean): DataWritingSparkTaskResult = {
val stageId = context.stageId()
val stageAttempt = context.stageAttemptNumber()
val partId = context.partitionId()
@@ -110,9 +116,12 @@ object DataWritingSparkTask extends Logging {
val attemptId = context.attemptNumber()
val dataWriter = writerFactory.createWriter(partId, taskId)
+ var count = 0L
// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
while (iter.hasNext) {
+ // Count the rows as they are written; the driver sums these counts for commit progress.
+ count += 1
dataWriter.write(iter.next())
}
@@ -139,7 +148,7 @@ object DataWritingSparkTask extends Logging {
logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId" +
s"stage $stageId.$stageAttempt)")
- msg
+ DataWritingSparkTaskResult(count, msg)
})(catchBlock = {
// If there is an error, abort this writer
@@ -151,3 +160,12 @@ object DataWritingSparkTask extends Logging {
})
}
}
+
+private[v2] case class DataWritingSparkTaskResult(
+ numRows: Long,
+ writerCommitMessage: WriterCommitMessage)
+
+/**
+ * Sink progress information collected after commit.
+ */
+private[sql] case class StreamWriterCommitProgress(numOutputRows: Long)
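
A hedged sketch of the row-count plumbing added above, with hypothetical names standing in for `DataWritingSparkTaskResult`, `LongAccumulator`, and `StreamWriterCommitProgress`: each task returns its commit message plus a row count, the driver aggregates the counts, and the total is surfaced as commit progress once all partitions have committed.

```scala
final case class TaskResult(numRows: Long, commitMessage: String)
final case class CommitProgress(numOutputRows: Long)

def writeJob(partitions: Seq[Seq[String]]): (Seq[String], CommitProgress) = {
  var total = 0L
  val messages = partitions.zipWithIndex.map { case (rows, partId) =>
    var count = 0L
    rows.foreach { _ => count += 1 }           // the per-task write loop also counts rows
    val result = TaskResult(count, s"commit-$partId")
    total += result.numRows                    // driver-side aggregation of per-task counts
    result.commitMessage
  }
  (messages, CommitProgress(total))            // exposed after commit, like commitProgress
}
```
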
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index f5d93ee5fa914..e4ec76f0b9a1f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -73,14 +73,14 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan}
* greater than the target size.
*
* For example, we have two stages with the following pre-shuffle partition size statistics:
- * stage 1: [100 MB, 20 MB, 100 MB, 10MB, 30 MB]
- * stage 2: [10 MB, 10 MB, 70 MB, 5 MB, 5 MB]
- * assuming the target input size is 128 MB, we will have four post-shuffle partitions,
+ * stage 1: [100 MiB, 20 MiB, 100 MiB, 10 MiB, 30 MiB]
+ * stage 2: [10 MiB, 10 MiB, 70 MiB, 5 MiB, 5 MiB]
+ * assuming the target input size is 128 MiB, we will have four post-shuffle partitions,
* which are:
- * - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MB)
- * - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MB)
- * - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MB)
- * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MB)
+ * - post-shuffle partition 0: pre-shuffle partition 0 (size 110 MiB)
+ * - post-shuffle partition 1: pre-shuffle partition 1 (size 30 MiB)
+ * - post-shuffle partition 2: pre-shuffle partition 2 (size 170 MiB)
+ * - post-shuffle partition 3: pre-shuffle partition 3 and 4 (size 50 MiB)
*/
class ExchangeCoordinator(
advisoryTargetPostShuffleInputSize: Long,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index c9ca395bceaa4..da7b0c6f43fbc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -23,6 +23,7 @@ import java.util.function.Supplier
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
@@ -30,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.metric.{SQLMetrics, SQLShuffleMetricsReporter}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.MutablePair
@@ -46,10 +47,13 @@ case class ShuffleExchangeExec(
// NOTE: coordinator can be null after serialization/deserialization,
// e.g. it can be null on the Executor side
-
+ private lazy val writeMetrics =
+ SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+ private lazy val readMetrics =
+ SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
override lazy val metrics = Map(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size")
- ) ++ SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
+ ) ++ readMetrics ++ writeMetrics
override def nodeName: String = {
val extraInfo = coordinator match {
@@ -90,7 +94,11 @@ case class ShuffleExchangeExec(
private[exchange] def prepareShuffleDependency()
: ShuffleDependency[Int, InternalRow, InternalRow] = {
ShuffleExchangeExec.prepareShuffleDependency(
- child.execute(), child.output, newPartitioning, serializer)
+ child.execute(),
+ child.output,
+ newPartitioning,
+ serializer,
+ writeMetrics)
}
/**
@@ -109,7 +117,7 @@ case class ShuffleExchangeExec(
assert(newPartitioning.isInstanceOf[HashPartitioning])
newPartitioning = UnknownPartitioning(indices.length)
}
- new ShuffledRowRDD(shuffleDependency, metrics, specifiedPartitionStartIndices)
+ new ShuffledRowRDD(shuffleDependency, readMetrics, specifiedPartitionStartIndices)
}
/**
@@ -204,7 +212,9 @@ object ShuffleExchangeExec {
rdd: RDD[InternalRow],
outputAttributes: Seq[Attribute],
newPartitioning: Partitioning,
- serializer: Serializer): ShuffleDependency[Int, InternalRow, InternalRow] = {
+ serializer: Serializer,
+ writeMetrics: Map[String, SQLMetric])
+ : ShuffleDependency[Int, InternalRow, InternalRow] = {
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
case HashPartitioning(_, n) =>
@@ -333,8 +343,22 @@ object ShuffleExchangeExec {
new ShuffleDependency[Int, InternalRow, InternalRow](
rddWithPartitionIds,
new PartitionIdPassthrough(part.numPartitions),
- serializer)
+ serializer,
+ shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics))
dependency
}
+
+ /**
+ * Create a customized [[ShuffleWriteProcessor]] for SQL which wraps the default metrics reporter
+ * with a [[SQLShuffleWriteMetricsReporter]] as the new reporter for the [[ShuffleWriteProcessor]].
+ */
+ def createShuffleWriteProcessor(metrics: Map[String, SQLMetric]): ShuffleWriteProcessor = {
+ new ShuffleWriteProcessor {
+ override protected def createMetricsReporter(
+ context: TaskContext): ShuffleWriteMetricsReporter = {
+ new SQLShuffleWriteMetricsReporter(context.taskMetrics().shuffleWriteMetrics, metrics)
+ }
+ }
+ }
}
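
A stand-alone sketch of the wrapping pattern used by both the read- and write-side SQL reporters in this change (hypothetical trait and metric names): every update is forwarded to the engine-level reporter and mirrored into the plan's own metric map so it shows up in the SQL UI.

```scala
trait WriteMetricsReporter { def incBytesWritten(v: Long): Unit }

class SqlWriteMetricsReporter(
    underlying: WriteMetricsReporter,
    metrics: scala.collection.mutable.Map[String, Long]) extends WriteMetricsReporter {
  override def incBytesWritten(v: Long): Unit = {
    metrics("shuffleBytesWritten") = metrics.getOrElse("shuffleBytesWritten", 0L) + v
    underlying.incBytesWritten(v)   // keep the engine-level task metrics in sync too
  }
}
```
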
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index a6f3ea47c8492..fd4a7897c7ad1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -29,7 +29,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Dist
import org.apache.spark.sql.execution.{BinaryExecNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.{BooleanType, LongType}
-import org.apache.spark.util.TaskCompletionListener
/**
* Performs an inner hash join of two child relations. When the output RDD of this operator is
@@ -48,8 +47,7 @@ case class BroadcastHashJoinExec(
extends BinaryExecNode with HashJoin with CodegenSupport {
override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
- "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
+ "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
override def requiredChildDistribution: Seq[Distribution] = {
val mode = HashedRelationBroadcastMode(buildKeys)
@@ -63,13 +61,12 @@ case class BroadcastHashJoinExec(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val avgHashProbe = longMetric("avgHashProbe")
val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
val hashed = broadcastRelation.value.asReadOnlyCopy()
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashed.estimatedSize)
- join(streamedIter, hashed, numOutputRows, avgHashProbe)
+ join(streamedIter, hashed, numOutputRows)
}
}
@@ -111,23 +108,6 @@ case class BroadcastHashJoinExec(
}
}
- /**
- * Returns the codes used to add a task completion listener to update avg hash probe
- * at the end of the task.
- */
- private def genTaskListener(avgHashProbe: String, relationTerm: String): String = {
- val listenerClass = classOf[TaskCompletionListener].getName
- val taskContextClass = classOf[TaskContext].getName
- s"""
- | $taskContextClass$$.MODULE$$.get().addTaskCompletionListener(new $listenerClass() {
- | @Override
- | public void onTaskCompletion($taskContextClass context) {
- | $avgHashProbe.set($relationTerm.getAverageProbesPerLookup());
- | }
- | });
- """.stripMargin
- }
-
/**
* Returns a tuple of Broadcast of HashedRelation and the variable name for it.
*/
@@ -137,15 +117,11 @@ case class BroadcastHashJoinExec(
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
val clsName = broadcastRelation.value.getClass.getName
- // At the end of the task, we update the avg hash probe.
- val avgHashProbe = metricTerm(ctx, "avgHashProbe")
-
// Inline mutable state since not many join operations in a task
val relationTerm = ctx.addMutableState(clsName, "relation",
v => s"""
| $v = (($clsName) $broadcast.value()).asReadOnlyCopy();
| incPeakExecutionMemory($v.estimatedSize());
- | ${genTaskListener(avgHashProbe, v)}
""".stripMargin, forceInline = true)
(broadcastRelation, relationTerm)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index dab873bf9b9a0..1aef5f6864263 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
@@ -194,8 +193,7 @@ trait HashJoin {
protected def join(
streamedIter: Iterator[InternalRow],
hashed: HashedRelation,
- numOutputRows: SQLMetric,
- avgHashProbe: SQLMetric): Iterator[InternalRow] = {
+ numOutputRows: SQLMetric): Iterator[InternalRow] = {
val joinedIter = joinType match {
case _: InnerLike =>
@@ -213,10 +211,6 @@ trait HashJoin {
s"BroadcastHashJoin should not take $x as the JoinType")
}
- // At the end of the task, we update the avg hash probe.
- TaskContext.get().addTaskCompletionListener[Unit](_ =>
- avgHashProbe.set(hashed.getAverageProbesPerLookup))
-
val resultProj = createResultProjection
joinedIter.map { r =>
numOutputRows += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index e8c01d46a84c0..7c21062c4cec3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -80,11 +80,6 @@ private[execution] sealed trait HashedRelation extends KnownSizeEstimation {
* Release any used resources.
*/
def close(): Unit
-
- /**
- * Returns the average number of probes per key lookup.
- */
- def getAverageProbesPerLookup: Double
}
private[execution] object HashedRelation {
@@ -248,8 +243,7 @@ private[joins] class UnsafeHashedRelation(
binaryMap = new BytesToBytesMap(
taskMemoryManager,
(nKeys * 1.5 + 1).toInt, // reduce hash collision
- pageSizeBytes,
- true)
+ pageSizeBytes)
var i = 0
var keyBuffer = new Array[Byte](1024)
@@ -280,8 +274,6 @@ private[joins] class UnsafeHashedRelation(
override def read(kryo: Kryo, in: Input): Unit = Utils.tryOrIOException {
read(() => in.readInt(), () => in.readLong(), in.readBytes)
}
-
- override def getAverageProbesPerLookup: Double = binaryMap.getAverageProbesPerLookup
}
private[joins] object UnsafeHashedRelation {
@@ -299,8 +291,7 @@ private[joins] object UnsafeHashedRelation {
taskMemoryManager,
// Only 70% of the slots can be used before growing, more capacity help to reduce collision
(sizeEstimate * 1.5 + 1).toInt,
- pageSizeBytes,
- true)
+ pageSizeBytes)
// Create a mapping of buildKeys -> rows
val keyGenerator = UnsafeProjection.create(key)
@@ -397,10 +388,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
// The number of unique keys.
private var numKeys = 0L
- // Tracking average number of probes per key lookup.
- private var numKeyLookups = 0L
- private var numProbes = 0L
-
// needed by serializer
def this() = {
this(
@@ -485,8 +472,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
if (isDense) {
- numKeyLookups += 1
- numProbes += 1
if (key >= minKey && key <= maxKey) {
val value = array((key - minKey).toInt)
if (value > 0) {
@@ -495,14 +480,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
var pos = firstSlot(key)
- numKeyLookups += 1
- numProbes += 1
while (array(pos + 1) != 0) {
if (array(pos) == key) {
return getRow(array(pos + 1), resultRow)
}
pos = nextSlot(pos)
- numProbes += 1
}
}
null
@@ -530,8 +512,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
if (isDense) {
- numKeyLookups += 1
- numProbes += 1
if (key >= minKey && key <= maxKey) {
val value = array((key - minKey).toInt)
if (value > 0) {
@@ -540,14 +520,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
var pos = firstSlot(key)
- numKeyLookups += 1
- numProbes += 1
while (array(pos + 1) != 0) {
if (array(pos) == key) {
return valueIter(array(pos + 1), resultRow)
}
pos = nextSlot(pos)
- numProbes += 1
}
}
null
@@ -587,11 +564,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
private def updateIndex(key: Long, address: Long): Unit = {
var pos = firstSlot(key)
assert(numKeys < array.length / 2)
- numKeyLookups += 1
- numProbes += 1
while (array(pos) != key && array(pos + 1) != 0) {
pos = nextSlot(pos)
- numProbes += 1
}
if (array(pos + 1) == 0) {
// this is the first value for this key, put the address in array.
@@ -723,8 +697,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
writeLong(maxKey)
writeLong(numKeys)
writeLong(numValues)
- writeLong(numKeyLookups)
- writeLong(numProbes)
writeLong(array.length)
writeLongArray(writeBuffer, array, array.length)
@@ -766,8 +738,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
maxKey = readLong()
numKeys = readLong()
numValues = readLong()
- numKeyLookups = readLong()
- numProbes = readLong()
val length = readLong().toInt
mask = length - 2
@@ -785,11 +755,6 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
override def read(kryo: Kryo, in: Input): Unit = {
read(() => in.readBoolean(), () => in.readLong(), in.readBytes)
}
-
- /**
- * Returns the average number of probes per key lookup.
- */
- def getAverageProbesPerLookup: Double = numProbes.toDouble / numKeyLookups
}
private[joins] class LongHashedRelation(
@@ -841,8 +806,6 @@ private[joins] class LongHashedRelation(
resultRow = new UnsafeRow(nFields)
map = in.readObject().asInstanceOf[LongToUnsafeRowMap]
}
-
- override def getAverageProbesPerLookup: Double = map.getAverageProbesPerLookup
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 2b59ed6e4d16b..524804d61e599 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -42,8 +42,7 @@ case class ShuffledHashJoinExec(
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"buildDataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size of build side"),
- "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"),
- "avgHashProbe" -> SQLMetrics.createAverageMetric(sparkContext, "avg hash probe"))
+ "buildTime" -> SQLMetrics.createTimingMetric(sparkContext, "time to build hash map"))
override def requiredChildDistribution: Seq[Distribution] =
HashClusteredDistribution(leftKeys) :: HashClusteredDistribution(rightKeys) :: Nil
@@ -63,10 +62,9 @@ case class ShuffledHashJoinExec(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val avgHashProbe = longMetric("avgHashProbe")
streamedPlan.execute().zipPartitions(buildPlan.execute()) { (streamIter, buildIter) =>
val hashed = buildHashedRelation(buildIter)
- join(streamIter, hashed, numOutputRows, avgHashProbe)
+ join(streamIter, hashed, numOutputRows)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index d558d1fcaf06f..56973af8fd648 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGe
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
-import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter
+import org.apache.spark.sql.execution.metric.{SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
/**
* Take the first `limit` elements and collect them to a single partition.
@@ -38,13 +38,21 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode
override def outputPartitioning: Partitioning = SinglePartition
override def executeCollect(): Array[InternalRow] = child.executeTake(limit)
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
- override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
+ private lazy val writeMetrics =
+ SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+ private lazy val readMetrics =
+ SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+ override lazy val metrics = readMetrics ++ writeMetrics
protected override def doExecute(): RDD[InternalRow] = {
val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit))
val shuffled = new ShuffledRowRDD(
ShuffleExchangeExec.prepareShuffleDependency(
- locallyLimited, child.output, SinglePartition, serializer),
- metrics)
+ locallyLimited,
+ child.output,
+ SinglePartition,
+ serializer,
+ writeMetrics),
+ readMetrics)
shuffled.mapPartitionsInternal(_.take(limit))
}
}
@@ -154,7 +162,11 @@ case class TakeOrderedAndProjectExec(
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
- override lazy val metrics = SQLShuffleMetricsReporter.createShuffleReadMetrics(sparkContext)
+ private lazy val writeMetrics =
+ SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+ private lazy val readMetrics =
+ SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+ override lazy val metrics = readMetrics ++ writeMetrics
protected override def doExecute(): RDD[InternalRow] = {
val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
@@ -165,8 +177,12 @@ case class TakeOrderedAndProjectExec(
}
val shuffled = new ShuffledRowRDD(
ShuffleExchangeExec.prepareShuffleDependency(
- localTopK, child.output, SinglePartition, serializer),
- metrics)
+ localTopK,
+ child.output,
+ SinglePartition,
+ serializer,
+ writeMetrics),
+ readMetrics)
shuffled.mapPartitions { iter =>
val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
if (projectList != child.output) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
index cbf707f4a9cfd..19809b07508d9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.metric
import java.text.NumberFormat
import java.util.Locale
+import scala.concurrent.duration._
+
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
@@ -78,6 +80,7 @@ object SQLMetrics {
private val SUM_METRIC = "sum"
private val SIZE_METRIC = "size"
private val TIMING_METRIC = "timing"
+ private val NS_TIMING_METRIC = "nsTiming"
private val AVERAGE_METRIC = "average"
private val baseForAvgMetric: Int = 10
@@ -121,6 +124,13 @@ object SQLMetrics {
acc
}
+ def createNanoTimingMetric(sc: SparkContext, name: String): SQLMetric = {
+ // Same as createTimingMetric, but the value is accumulated in nanoseconds and normalized to milliseconds for display.
+ val acc = new SQLMetric(NS_TIMING_METRIC, -1)
+ acc.register(sc, name = Some(s"$name total (min, med, max)"), countFailedValues = false)
+ acc
+ }
+
/**
* Create a metric to report the average information (including min, med, max) like
* avg hash probe. As average metrics are double values, this kind of metrics should be
@@ -163,6 +173,8 @@ object SQLMetrics {
Utils.bytesToString
} else if (metricsType == TIMING_METRIC) {
Utils.msDurationToString
+ } else if (metricsType == NS_TIMING_METRIC) {
+ duration => Utils.msDurationToString(duration.nanos.toMillis)
} else {
throw new IllegalStateException("unexpected metrics type: " + metricsType)
}
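
A minimal, Spark-free sketch of the normalization that createNanoTimingMetric relies on above: the metric accumulates nanoseconds, and only the display step converts to milliseconds before reusing the existing millisecond formatter. The `msFormatter` parameter below is a hypothetical stand-in for Utils.msDurationToString.

    import scala.concurrent.duration._

    // Accumulate in nanoseconds; convert to milliseconds only when rendering.
    def formatNanoDuration(valueNs: Long, msFormatter: Long => String): String =
      msFormatter(valueNs.nanos.toMillis)

    // Example: 2,500,000 ns rendered through a millisecond formatter yields "2 ms".
    val rendered = formatNanoDuration(2500000L, ms => s"$ms ms")
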
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala
index 780f0d7622294..2c0ea80495abb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLShuffleMetricsReporter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.metric
import org.apache.spark.SparkContext
import org.apache.spark.executor.TempShuffleReadMetrics
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
/**
* A shuffle metrics reporter for SQL exchange operators.
@@ -26,23 +27,23 @@ import org.apache.spark.executor.TempShuffleReadMetrics
* @param metrics All metrics in current SparkPlan. This param should not be empty and
* should contain all shuffle metrics defined in createShuffleReadMetrics.
*/
-private[spark] class SQLShuffleMetricsReporter(
+class SQLShuffleReadMetricsReporter(
tempMetrics: TempShuffleReadMetrics,
metrics: Map[String, SQLMetric]) extends TempShuffleReadMetrics {
private[this] val _remoteBlocksFetched =
- metrics(SQLShuffleMetricsReporter.REMOTE_BLOCKS_FETCHED)
+ metrics(SQLShuffleReadMetricsReporter.REMOTE_BLOCKS_FETCHED)
private[this] val _localBlocksFetched =
- metrics(SQLShuffleMetricsReporter.LOCAL_BLOCKS_FETCHED)
+ metrics(SQLShuffleReadMetricsReporter.LOCAL_BLOCKS_FETCHED)
private[this] val _remoteBytesRead =
- metrics(SQLShuffleMetricsReporter.REMOTE_BYTES_READ)
+ metrics(SQLShuffleReadMetricsReporter.REMOTE_BYTES_READ)
private[this] val _remoteBytesReadToDisk =
- metrics(SQLShuffleMetricsReporter.REMOTE_BYTES_READ_TO_DISK)
+ metrics(SQLShuffleReadMetricsReporter.REMOTE_BYTES_READ_TO_DISK)
private[this] val _localBytesRead =
- metrics(SQLShuffleMetricsReporter.LOCAL_BYTES_READ)
+ metrics(SQLShuffleReadMetricsReporter.LOCAL_BYTES_READ)
private[this] val _fetchWaitTime =
- metrics(SQLShuffleMetricsReporter.FETCH_WAIT_TIME)
+ metrics(SQLShuffleReadMetricsReporter.FETCH_WAIT_TIME)
private[this] val _recordsRead =
- metrics(SQLShuffleMetricsReporter.RECORDS_READ)
+ metrics(SQLShuffleReadMetricsReporter.RECORDS_READ)
override def incRemoteBlocksFetched(v: Long): Unit = {
_remoteBlocksFetched.add(v)
@@ -74,7 +75,7 @@ private[spark] class SQLShuffleMetricsReporter(
}
}
-private[spark] object SQLShuffleMetricsReporter {
+object SQLShuffleReadMetricsReporter {
val REMOTE_BLOCKS_FETCHED = "remoteBlocksFetched"
val LOCAL_BLOCKS_FETCHED = "localBlocksFetched"
val REMOTE_BYTES_READ = "remoteBytesRead"
@@ -87,11 +88,65 @@ private[spark] object SQLShuffleMetricsReporter {
* Create all shuffle read related metrics and return the Map.
*/
def createShuffleReadMetrics(sc: SparkContext): Map[String, SQLMetric] = Map(
- REMOTE_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "remote blocks fetched"),
- LOCAL_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "local blocks fetched"),
+ REMOTE_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "remote blocks read"),
+ LOCAL_BLOCKS_FETCHED -> SQLMetrics.createMetric(sc, "local blocks read"),
REMOTE_BYTES_READ -> SQLMetrics.createSizeMetric(sc, "remote bytes read"),
REMOTE_BYTES_READ_TO_DISK -> SQLMetrics.createSizeMetric(sc, "remote bytes read to disk"),
LOCAL_BYTES_READ -> SQLMetrics.createSizeMetric(sc, "local bytes read"),
FETCH_WAIT_TIME -> SQLMetrics.createTimingMetric(sc, "fetch wait time"),
RECORDS_READ -> SQLMetrics.createMetric(sc, "records read"))
}
+
+/**
+ * A shuffle write metrics reporter for SQL exchange operators.
+ * @param metricsReporter Other reporter that needs to be updated by this SQLShuffleWriteMetricsReporter.
+ * @param metrics Shuffle write metrics in current SparkPlan.
+ */
+class SQLShuffleWriteMetricsReporter(
+ metricsReporter: ShuffleWriteMetricsReporter,
+ metrics: Map[String, SQLMetric]) extends ShuffleWriteMetricsReporter {
+ private[this] val _bytesWritten =
+ metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_BYTES_WRITTEN)
+ private[this] val _recordsWritten =
+ metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_RECORDS_WRITTEN)
+ private[this] val _writeTime =
+ metrics(SQLShuffleWriteMetricsReporter.SHUFFLE_WRITE_TIME)
+
+ override def incBytesWritten(v: Long): Unit = {
+ metricsReporter.incBytesWritten(v)
+ _bytesWritten.add(v)
+ }
+ override def decRecordsWritten(v: Long): Unit = {
+ metricsReporter.decRecordsWritten(v)
+ _recordsWritten.set(_recordsWritten.value - v)
+ }
+ override def incRecordsWritten(v: Long): Unit = {
+ metricsReporter.incRecordsWritten(v)
+ _recordsWritten.add(v)
+ }
+ override def incWriteTime(v: Long): Unit = {
+ metricsReporter.incWriteTime(v)
+ _writeTime.add(v)
+ }
+ override def decBytesWritten(v: Long): Unit = {
+ metricsReporter.decBytesWritten(v)
+ _bytesWritten.set(_bytesWritten.value - v)
+ }
+}
+
+object SQLShuffleWriteMetricsReporter {
+ val SHUFFLE_BYTES_WRITTEN = "shuffleBytesWritten"
+ val SHUFFLE_RECORDS_WRITTEN = "shuffleRecordsWritten"
+ val SHUFFLE_WRITE_TIME = "shuffleWriteTime"
+
+ /**
+ * Create all shuffle write related metrics and return the Map.
+ */
+ def createShuffleWriteMetrics(sc: SparkContext): Map[String, SQLMetric] = Map(
+ SHUFFLE_BYTES_WRITTEN ->
+ SQLMetrics.createSizeMetric(sc, "shuffle bytes written"),
+ SHUFFLE_RECORDS_WRITTEN ->
+ SQLMetrics.createMetric(sc, "shuffle records written"),
+ SHUFFLE_WRITE_TIME ->
+ SQLMetrics.createNanoTimingMetric(sc, "shuffle write time"))
+}
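
As a hedged illustration of the wrapper pattern SQLShuffleWriteMetricsReporter follows above (plain Scala, not the Spark classes): every update is forwarded to the wrapped reporter and mirrored into the operator's own metric, so executor task metrics and the SQL UI metrics stay in sync. The trait and counter names below are toy stand-ins.

    import scala.collection.mutable

    trait WriteReporter {
      def incBytesWritten(v: Long): Unit
      def incRecordsWritten(v: Long): Unit
    }

    // Decorator: forward to the delegate, then mirror into local counters.
    final class ForwardingReporter(
        delegate: WriteReporter,
        counters: mutable.Map[String, Long]) extends WriteReporter {

      private def bump(key: String, v: Long): Unit =
        counters(key) = counters.getOrElse(key, 0L) + v

      override def incBytesWritten(v: Long): Unit = {
        delegate.incBytesWritten(v)
        bump("bytesWritten", v)
      }

      override def incRecordsWritten(v: Long): Unit = {
        delegate.incRecordsWritten(v)
        bump("recordsWritten", v)
      }
    }
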
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
index 2b87796dc6833..a5203daea9cd0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala
@@ -60,8 +60,12 @@ private class BatchIterator[T](iter: Iterator[T], batchSize: Int)
/**
* A logical plan that evaluates a [[PythonUDF]].
*/
-case class ArrowEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
- extends UnaryNode
+case class ArrowEvalPython(
+ udfs: Seq[PythonUDF],
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}
/**
* A physical plan that evaluates a [[PythonUDF]].
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index b08b7e60e130b..d3736d24e5019 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -32,8 +32,12 @@ import org.apache.spark.sql.types.{StructField, StructType}
/**
* A logical plan that evaluates a [[PythonUDF]]
*/
-case class BatchEvalPython(udfs: Seq[PythonUDF], output: Seq[Attribute], child: LogicalPlan)
- extends UnaryNode
+case class BatchEvalPython(
+ udfs: Seq[PythonUDF],
+ output: Seq[Attribute],
+ child: LogicalPlan) extends UnaryNode {
+ override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+}
/**
* A physical plan that evaluates a [[PythonUDF]]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 90b5325919e96..380c31baa6213 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -24,7 +24,7 @@ import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -131,8 +131,20 @@ object ExtractPythonUDFs extends Rule[LogicalPlan] with PredicateHelper {
expressions.flatMap(collectEvaluableUDFs)
}
- def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case plan: LogicalPlan => extract(plan)
+ def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ // SPARK-26293: A subquery will be rewritten into join later, and will go through this rule
+ // eventually. Here we skip subquery, as Python UDF only needs to be extracted once.
+ case _: Subquery => plan
+
+ case _ => plan transformUp {
+ // A safe guard. `ExtractPythonUDFs` only runs once, so we will not hit `BatchEvalPython` and
+ // `ArrowEvalPython` in the input plan. However, if we hit them, we must skip them, as we can't
+ // extract Python UDFs from them.
+ case p: BatchEvalPython => p
+ case p: ArrowEvalPython => p
+
+ case plan: LogicalPlan => extract(plan)
+ }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
index 27bed1137e5b3..1ce1215bfdd62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala
@@ -27,24 +27,71 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
+import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan}
import org.apache.spark.sql.execution.arrow.ArrowUtils
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.execution.window._
+import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
+/**
+ * This class calculates and outputs windowed aggregates over the rows in a single partition.
+ *
+ * This is similar to [[WindowExec]]. The main difference is that this node does not compute
+ * any window aggregation values. Instead, it computes the lower and upper bound for each window
+ * (i.e. window bounds) and passes the data and indices to the Python worker to do the actual window
+ * aggregation.
+ *
+ * It currently materializes all data associated with the same partition key and passes it to
+ * the Python worker. This is not strictly necessary for sliding windows and can be improved (by
+ * possibly slicing data into overlapping chunks and stitching them together).
+ *
+ * This class groups window expressions by their window boundaries so that window expressions
+ * with the same window boundaries can share the same window bounds. The window bounds are
+ * prepended to the data passed to the python worker.
+ *
+ * For example, if we have:
+ * avg(v) over specifiedwindowframe(RowFrame, -5, 5),
+ * avg(v) over specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing),
+ * avg(v) over specifiedwindowframe(RowFrame, -3, 3),
+ * max(v) over specifiedwindowframe(RowFrame, -3, 3)
+ *
+ * The python input will look like:
+ * (lower_bound_w1, upper_bound_w1, lower_bound_w3, upper_bound_w3, v)
+ *
+ * where w1 is specifiedwindowframe(RowFrame, -5, 5)
+ * w2 is specifiedwindowframe(RowFrame, UnboundedPreceding, UnboundedFollowing)
+ * w3 is specifiedwindowframe(RowFrame, -3, 3)
+ *
+ * Note that w2 doesn't have bound indices in the Python input because it's an unbounded window,
+ * so its bound indices will always be the same.
+ *
+ * Bounded window and Unbounded window are evaluated differently in Python worker:
+ * (1) Bounded window takes the window bound indices in addition to the input columns.
+ * Unbounded window takes only input columns.
+ * (2) Bounded window evaluates the udf once per input row.
+ * Unbounded window evaluates the udf once per window partition.
+ * This is controlled by the Python runner conf "pandas_window_bound_types".
+ *
+ * The logic to compute window bounds is delegated to [[WindowFunctionFrame]] and shared with
+ * [[WindowExec]]
+ *
+ * Note this doesn't support partial aggregation and all aggregation is computed from the entire
+ * window.
+ */
case class WindowInPandasExec(
windowExpression: Seq[NamedExpression],
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
- child: SparkPlan) extends UnaryExecNode {
+ child: SparkPlan)
+ extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) {
override def output: Seq[Attribute] =
child.output ++ windowExpression.map(_.toAttribute)
override def requiredChildDistribution: Seq[Distribution] = {
if (partitionSpec.isEmpty) {
- // Only show warning when the number of bytes is larger than 100 MB?
+ // Only show warning when the number of bytes is larger than 100 MiB?
logWarning("No Partition Defined for Window operation! Moving all data to a single "
+ "partition, this can cause serious performance degradation.")
AllTuples :: Nil
@@ -60,6 +107,26 @@ case class WindowInPandasExec(
override def outputPartitioning: Partitioning = child.outputPartitioning
+ /**
+ * Helper functions and data structures for window bounds
+ *
+ * It contains:
+ * (1) Total number of window bound indices in the python input row
+ * (2) Function from frame index to its lower bound column index in the python input row
+ * (3) Function from frame index to its upper bound column index in the python input row
+ * (4) Seq from frame index to its window bound type
+ */
+ private type WindowBoundHelpers = (Int, Int => Int, Int => Int, Seq[WindowBoundType])
+
+ /**
+ * Enum for window bound types. Used only inside this class.
+ */
+ private sealed case class WindowBoundType(value: String)
+ private object UnboundedWindow extends WindowBoundType("unbounded")
+ private object BoundedWindow extends WindowBoundType("bounded")
+
+ private val windowBoundTypeConf = "pandas_window_bound_types"
+
private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
udf.children match {
case Seq(u: PythonUDF) =>
@@ -73,68 +140,150 @@ case class WindowInPandasExec(
}
/**
- * Create the resulting projection.
- *
- * This method uses Code Generation. It can only be used on the executor side.
- *
- * @param expressions unbound ordered function expressions.
- * @return the final resulting projection.
+ * See [[WindowBoundHelpers]] for details.
*/
- private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
- val references = expressions.zipWithIndex.map { case (e, i) =>
- // Results of window expressions will be on the right side of child's output
- BoundReference(child.output.size + i, e.dataType, e.nullable)
+ private def computeWindowBoundHelpers(
+ factories: Seq[InternalRow => WindowFunctionFrame]
+ ): WindowBoundHelpers = {
+ val functionFrames = factories.map(_(EmptyRow))
+
+ val windowBoundTypes = functionFrames.map {
+ case _: UnboundedWindowFunctionFrame => UnboundedWindow
+ case _: UnboundedFollowingWindowFunctionFrame |
+ _: SlidingWindowFunctionFrame |
+ _: UnboundedPrecedingWindowFunctionFrame => BoundedWindow
+ // It should be impossible to get other types of window function frame here
+ case frame => throw new RuntimeException(s"Unexpected window function frame $frame.")
}
- val unboundToRefMap = expressions.zip(references).toMap
- val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
- UnsafeProjection.create(
- child.output ++ patchedWindowExpression,
- child.output)
+
+ val requiredIndices = functionFrames.map {
+ case _: UnboundedWindowFunctionFrame => 0
+ case _ => 2
+ }
+
+ val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail
+
+ val boundIndices = requiredIndices.zip(upperBoundIndices).map { case (num, upperBoundIndex) =>
+ if (num == 0) {
+ // Sentinel values for unbounded window
+ (-1, -1)
+ } else {
+ (upperBoundIndex - 2, upperBoundIndex - 1)
+ }
+ }
+
+ def lowerBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._1
+ def upperBoundIndex(frameIndex: Int) = boundIndices(frameIndex)._2
+
+ (requiredIndices.sum, lowerBoundIndex, upperBoundIndex, windowBoundTypes)
}
protected override def doExecute(): RDD[InternalRow] = {
- val inputRDD = child.execute()
+ // Unwrap the expressions and factories from the map.
+ val expressionsWithFrameIndex =
+ windowFrameExpressionFactoryPairs.map(_._1).zipWithIndex.flatMap {
+ case (buffer, frameIndex) => buffer.map(expr => (expr, frameIndex))
+ }
+
+ val expressions = expressionsWithFrameIndex.map(_._1)
+ val expressionIndexToFrameIndex =
+ expressionsWithFrameIndex.map(_._2).zipWithIndex.map(_.swap).toMap
+
+ val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray
+ // Helper functions
+ val (numBoundIndices, lowerBoundIndex, upperBoundIndex, frameWindowBoundTypes) =
+ computeWindowBoundHelpers(factories)
+ val isBounded = { frameIndex: Int => lowerBoundIndex(frameIndex) >= 0 }
+ val numFrames = factories.length
+
+ val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold
+ val spillThreshold = conf.windowExecBufferSpillThreshold
val sessionLocalTimeZone = conf.sessionLocalTimeZone
- val pythonRunnerConf = ArrowUtils.getPythonRunnerConfMap(conf)
// Extract window expressions and window functions
- val expressions = windowExpression.flatMap(_.collect { case e: WindowExpression => e })
-
- val udfExpressions = expressions.map(_.windowFunction.asInstanceOf[PythonUDF])
+ val windowExpressions = expressions.flatMap(_.collect { case e: WindowExpression => e })
+ val udfExpressions = windowExpressions.map(_.windowFunction.asInstanceOf[PythonUDF])
+ // We shouldn't be chaining anything here.
+ // All chained python functions should only contain one function.
val (pyFuncs, inputs) = udfExpressions.map(collectFunctions).unzip
+ require(pyFuncs.length == expressions.length)
+
+ val udfWindowBoundTypes = pyFuncs.indices.map(i =>
+ frameWindowBoundTypes(expressionIndexToFrameIndex(i)))
+ val pythonRunnerConf: Map[String, String] = (ArrowUtils.getPythonRunnerConfMap(conf)
+ + (windowBoundTypeConf -> udfWindowBoundTypes.map(_.value).mkString(",")))
// Filter child output attributes down to only those that are UDF inputs.
- // Also eliminate duplicate UDF inputs.
- val allInputs = new ArrayBuffer[Expression]
- val dataTypes = new ArrayBuffer[DataType]
+ // Also eliminate duplicate UDF inputs. This is similar to how other Python UDF nodes
+ // handle UDF inputs.
+ val dataInputs = new ArrayBuffer[Expression]
+ val dataInputTypes = new ArrayBuffer[DataType]
val argOffsets = inputs.map { input =>
input.map { e =>
- if (allInputs.exists(_.semanticEquals(e))) {
- allInputs.indexWhere(_.semanticEquals(e))
+ if (dataInputs.exists(_.semanticEquals(e))) {
+ dataInputs.indexWhere(_.semanticEquals(e))
} else {
- allInputs += e
- dataTypes += e.dataType
- allInputs.length - 1
+ dataInputs += e
+ dataInputTypes += e.dataType
+ dataInputs.length - 1
}
}.toArray
}.toArray
- // Schema of input rows to the python runner
- val windowInputSchema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
- StructField(s"_$i", dt)
- })
+ // In addition to UDF inputs, we will prepend window bounds for each UDF.
+ // For bounded windows, we prepend the lower bound and upper bound. For unbounded windows,
+ // we do not add window bounds. (Strictly speaking, we only need the lower or upper bound
+ // if the window is bounded only on one side; this can be improved in the future.)
- inputRDD.mapPartitionsInternal { iter =>
- val context = TaskContext.get()
+ // Setting window bounds for each window frame. Each window frame has different bounds so
+ // each has its own window bound columns.
+ val windowBoundsInput = factories.indices.flatMap { frameIndex =>
+ if (isBounded(frameIndex)) {
+ Seq(
+ BoundReference(lowerBoundIndex(frameIndex), IntegerType, nullable = false),
+ BoundReference(upperBoundIndex(frameIndex), IntegerType, nullable = false)
+ )
+ } else {
+ Seq.empty
+ }
+ }
- val grouped = if (partitionSpec.isEmpty) {
- // Use an empty unsafe row as a place holder for the grouping key
- Iterator((new UnsafeRow(), iter))
+ // Setting the window bounds argOffset for each UDF. For UDFs with a bounded window, argOffset
+ // for the UDF is (lowerBoundOffset, upperBoundOffset, inputOffset1, inputOffset2, ...).
+ // For UDFs with an unbounded window, argOffset is (inputOffset1, inputOffset2, ...).
+ pyFuncs.indices.foreach { exprIndex =>
+ val frameIndex = expressionIndexToFrameIndex(exprIndex)
+ if (isBounded(frameIndex)) {
+ argOffsets(exprIndex) =
+ Array(lowerBoundIndex(frameIndex), upperBoundIndex(frameIndex)) ++
+ argOffsets(exprIndex).map(_ + windowBoundsInput.length)
} else {
- GroupedIterator(iter, partitionSpec, child.output)
+ argOffsets(exprIndex) = argOffsets(exprIndex).map(_ + windowBoundsInput.length)
}
+ }
+
+ val allInputs = windowBoundsInput ++ dataInputs
+ val allInputTypes = allInputs.map(_.dataType)
+
+ // Start processing.
+ child.execute().mapPartitions { iter =>
+ val context = TaskContext.get()
+
+ // Get all relevant projections.
+ val resultProj = createResultProjection(expressions)
+ val pythonInputProj = UnsafeProjection.create(
+ allInputs,
+ windowBoundsInput.map(ref =>
+ AttributeReference(s"i_${ref.ordinal}", ref.dataType)()) ++ child.output
+ )
+ val pythonInputSchema = StructType(
+ allInputTypes.zipWithIndex.map { case (dt, i) =>
+ StructField(s"_$i", dt)
+ }
+ )
+ val grouping = UnsafeProjection.create(partitionSpec, child.output)
// The queue used to buffer input rows so we can drain it to
// combine input with output from Python.
@@ -144,11 +293,94 @@ case class WindowInPandasExec(
queue.close()
}
- val inputProj = UnsafeProjection.create(allInputs, child.output)
- val pythonInput = grouped.map { case (_, rows) =>
- rows.map { row =>
- queue.add(row.asInstanceOf[UnsafeRow])
- inputProj(row)
+ val stream = iter.map { row =>
+ queue.add(row.asInstanceOf[UnsafeRow])
+ row
+ }
+
+ val pythonInput = new Iterator[Iterator[UnsafeRow]] {
+
+ // Manage the stream and the grouping.
+ var nextRow: UnsafeRow = null
+ var nextGroup: UnsafeRow = null
+ var nextRowAvailable: Boolean = false
+ private[this] def fetchNextRow() {
+ nextRowAvailable = stream.hasNext
+ if (nextRowAvailable) {
+ nextRow = stream.next().asInstanceOf[UnsafeRow]
+ nextGroup = grouping(nextRow)
+ } else {
+ nextRow = null
+ nextGroup = null
+ }
+ }
+ fetchNextRow()
+
+ // Manage the current partition.
+ val buffer: ExternalAppendOnlyUnsafeRowArray =
+ new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold)
+ var bufferIterator: Iterator[UnsafeRow] = _
+
+ val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType))
+
+ val frames = factories.map(_(indexRow))
+
+ private[this] def fetchNextPartition() {
+ // Collect all the rows in the current partition.
+ // Before we start to fetch new input rows, make a copy of nextGroup.
+ val currentGroup = nextGroup.copy()
+
+ // clear last partition
+ buffer.clear()
+
+ while (nextRowAvailable && nextGroup == currentGroup) {
+ buffer.add(nextRow)
+ fetchNextRow()
+ }
+
+ // Setup the frames.
+ var i = 0
+ while (i < numFrames) {
+ frames(i).prepare(buffer)
+ i += 1
+ }
+
+ // Setup iteration
+ rowIndex = 0
+ bufferIterator = buffer.generateIterator()
+ }
+
+ // Iteration
+ var rowIndex = 0
+
+ override final def hasNext: Boolean =
+ (bufferIterator != null && bufferIterator.hasNext) || nextRowAvailable
+
+ override final def next(): Iterator[UnsafeRow] = {
+ // Load the next partition if we need to.
+ if ((bufferIterator == null || !bufferIterator.hasNext) && nextRowAvailable) {
+ fetchNextPartition()
+ }
+
+ val join = new JoinedRow
+
+ bufferIterator.zipWithIndex.map {
+ case (current, index) =>
+ var frameIndex = 0
+ while (frameIndex < numFrames) {
+ frames(frameIndex).write(index, current)
+ // If the window is unbounded we don't need to write out window bounds.
+ if (isBounded(frameIndex)) {
+ indexRow.setInt(
+ lowerBoundIndex(frameIndex), frames(frameIndex).currentLowerBound())
+ indexRow.setInt(
+ upperBoundIndex(frameIndex), frames(frameIndex).currentUpperBound())
+ }
+ frameIndex += 1
+ }
+
+ pythonInputProj(join(indexRow, current))
+ }
}
}
@@ -156,12 +388,11 @@ case class WindowInPandasExec(
pyFuncs,
PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF,
argOffsets,
- windowInputSchema,
+ pythonInputSchema,
sessionLocalTimeZone,
pythonRunnerConf).compute(pythonInput, context.partitionId(), context)
val joined = new JoinedRow
- val resultProj = createResultProjection(expressions)
windowFunctionResult.flatMap(_.rowIterator.asScala).map { windowOutput =>
val leftRow = queue.remove()
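
To make the bound-index bookkeeping in computeWindowBoundHelpers above concrete, here is a plain-Scala walk-through using the frames from the class doc comment (w1 and w3 bounded, w2 unbounded); the values are illustrative only:

    // w1, w2, w3: a bounded frame contributes 2 bound columns, an unbounded one contributes 0.
    val requiredIndices = Seq(2, 0, 2)
    val upperBoundIndices = requiredIndices.scan(0)(_ + _).tail          // Seq(2, 2, 4)
    val boundIndices = requiredIndices.zip(upperBoundIndices).map {
      case (0, _)  => (-1, -1)                                           // sentinel for unbounded
      case (_, hi) => (hi - 2, hi - 1)
    }
    // boundIndices == Seq((0, 1), (-1, -1), (2, 3)): the Python input row starts with
    // (lower_w1, upper_w1, lower_w3, upper_w3) followed by the data columns.
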
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
index 76794ed4e9766..38ecb0dd12daa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentBatchTimestamp,
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, WriteToDataSourceV2}
+import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec}
import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWritSupport, RateControlMicroBatchReadSupport}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.v2._
@@ -148,6 +148,12 @@ class MicroBatchExecution(
logInfo(s"Query $prettyIdString was stopped")
}
+ /** Begins recording statistics about query progress for a given trigger. */
+ override protected def startTrigger(): Unit = {
+ super.startTrigger()
+ currentStatus = currentStatus.copy(isTriggerActive = true)
+ }
+
/**
* Repeatedly attempts to run batches as data arrives.
*/
@@ -241,6 +247,7 @@ class MicroBatchExecution(
* DONE
*/
private def populateStartOffsets(sparkSessionToRunBatches: SparkSession): Unit = {
+ sinkCommitProgress = None
offsetLog.getLatest() match {
case Some((latestBatchId, nextOffsets)) =>
/* First assume that we are re-executing the latest known batch
@@ -533,7 +540,8 @@ class MicroBatchExecution(
val nextBatch =
new Dataset(sparkSessionToRunBatch, lastExecution, RowEncoder(lastExecution.analyzed.schema))
- reportTimeTaken("addBatch") {
+ val batchSinkProgress: Option[StreamWriterCommitProgress] =
+ reportTimeTaken("addBatch") {
SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) {
sink match {
case s: Sink => s.addBatch(currentBatchId, nextBatch)
@@ -541,10 +549,15 @@ class MicroBatchExecution(
// This doesn't accumulate any data - it just forces execution of the microbatch writer.
nextBatch.collect()
}
+ lastExecution.executedPlan match {
+ case w: WriteToDataSourceV2Exec => w.commitProgress
+ case _ => None
+ }
}
}
withProgressLocked {
+ sinkCommitProgress = batchSinkProgress
watermarkTracker.updateWatermark(lastExecution.executedPlan)
commitLog.add(currentBatchId, CommitMetadata(watermarkTracker.currentWatermark))
committedOffsets ++= availableOffsets
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 6a22f0cc8431a..d1f3f74c5e731 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.QueryExecution
-import org.apache.spark.sql.execution.datasources.v2.DataSourceV2StreamingScanExec
+import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2StreamingScanExec, StreamWriterCommitProgress}
import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReadSupport
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent
@@ -56,6 +56,7 @@ trait ProgressReporter extends Logging {
protected def logicalPlan: LogicalPlan
protected def lastExecution: QueryExecution
protected def newData: Map[BaseStreamingSource, LogicalPlan]
+ protected def sinkCommitProgress: Option[StreamWriterCommitProgress]
protected def sources: Seq[BaseStreamingSource]
protected def sink: BaseStreamingSink
protected def offsetSeqMetadata: OffsetSeqMetadata
@@ -114,7 +115,6 @@ trait ProgressReporter extends Logging {
logDebug("Starting Trigger Calculation")
lastTriggerStartTimestamp = currentTriggerStartTimestamp
currentTriggerStartTimestamp = triggerClock.getTimeMillis()
- currentStatus = currentStatus.copy(isTriggerActive = true)
currentTriggerStartOffsets = null
currentTriggerEndOffsets = null
currentDurationsMs.clear()
@@ -168,7 +168,9 @@ trait ProgressReporter extends Logging {
)
}
- val sinkProgress = new SinkProgress(sink.toString)
+ val sinkProgress = SinkProgress(
+ sink.toString,
+ sinkCommitProgress.map(_.numOutputRows))
val newProgress = new StreamingQueryProgress(
id = id,
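
A small hedged sketch of the progress plumbing added above (toy types, not the Spark classes): the executed plan is pattern matched once per batch, only the V2 write node yields commit progress, and the sink's row count therefore surfaces as an Option that stays None for sinks that do not report it.

    sealed trait Plan
    case class V2WriteNode(numOutputRows: Option[Long]) extends Plan  // hypothetical stand-in
    case object OtherNode extends Plan

    def commitProgress(plan: Plan): Option[Long] = plan match {
      case V2WriteNode(rows) => rows
      case _                 => None
    }

    // In the progress report the count is carried through as-is:
    val numOutputRows: Option[Long] = commitProgress(V2WriteNode(Some(42L)))
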
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
index 89b4f40c9c0b9..83824f40ab90b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.command.StreamingExplainCommand
+import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
import org.apache.spark.util.{Clock, UninterruptibleThread, Utils}
@@ -114,6 +115,9 @@ abstract class StreamExecution(
@volatile
var availableOffsets = new StreamProgress
+ @volatile
+ var sinkCommitProgress: Option[StreamWriterCommitProgress] = None
+
/** The current batchId or -1 if execution has not yet been initialized. */
protected var currentBatchId: Long = -1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
index a1d2ac426f855..89033b70f1431 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala
@@ -119,6 +119,8 @@ class ContinuousExecution(
// For at least once, we can just ignore those reports and risk duplicates.
commitLog.getLatest() match {
case Some((latestEpochId, _)) =>
+ updateStatusMessage("Starting new streaming query " +
+ s"and getting offsets from latest epoch $latestEpochId")
val nextOffsets = offsetLog.get(latestEpochId).getOrElse {
throw new IllegalStateException(
s"Batch $latestEpochId was committed without end epoch offsets!")
@@ -130,6 +132,7 @@ class ContinuousExecution(
nextOffsets
case None =>
// We are starting this stream for the first time. Offsets are all None.
+ updateStatusMessage("Starting new streaming query")
logInfo(s"Starting new streaming query.")
currentBatchId = 0
OffsetSeq.fill(continuousSources.map(_ => null): _*)
@@ -264,6 +267,7 @@ class ContinuousExecution(
epochUpdateThread.setDaemon(true)
epochUpdateThread.start()
+ updateStatusMessage("Running")
reportTimeTaken("runContinuous") {
SQLExecution.withNewExecutionId(
sparkSessionForQuery, lastExecution) {
@@ -323,6 +327,8 @@ class ContinuousExecution(
* before this is called.
*/
def commit(epoch: Long): Unit = {
+ updateStatusMessage(s"Committing epoch $epoch")
+
assert(continuousSources.length == 1, "only one continuous source supported currently")
assert(offsetLog.get(epoch).isDefined, s"offset for epoch $epoch not reported before commit")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
index fede0f3e92d67..89f6edda2ef57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala
@@ -83,14 +83,14 @@ case class WindowExec(
partitionSpec: Seq[Expression],
orderSpec: Seq[SortOrder],
child: SparkPlan)
- extends UnaryExecNode {
+ extends WindowExecBase(windowExpression, partitionSpec, orderSpec, child) {
override def output: Seq[Attribute] =
child.output ++ windowExpression.map(_.toAttribute)
override def requiredChildDistribution: Seq[Distribution] = {
if (partitionSpec.isEmpty) {
- // Only show warning when the number of bytes is larger than 100 MB?
+ // Only show warning when the number of bytes is larger than 100 MiB?
logWarning("No Partition Defined for Window operation! Moving all data to a single "
+ "partition, this can cause serious performance degradation.")
AllTuples :: Nil
@@ -104,193 +104,6 @@ case class WindowExec(
override def outputPartitioning: Partitioning = child.outputPartitioning
- /**
- * Create a bound ordering object for a given frame type and offset. A bound ordering object is
- * used to determine which input row lies within the frame boundaries of an output row.
- *
- * This method uses Code Generation. It can only be used on the executor side.
- *
- * @param frame to evaluate. This can either be a Row or Range frame.
- * @param bound with respect to the row.
- * @param timeZone the session local timezone for time related calculations.
- * @return a bound ordering object.
- */
- private[this] def createBoundOrdering(
- frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = {
- (frame, bound) match {
- case (RowFrame, CurrentRow) =>
- RowBoundOrdering(0)
-
- case (RowFrame, IntegerLiteral(offset)) =>
- RowBoundOrdering(offset)
-
- case (RangeFrame, CurrentRow) =>
- val ordering = newOrdering(orderSpec, child.output)
- RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
-
- case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
- // Use only the first order expression when the offset is non-null.
- val sortExpr = orderSpec.head
- val expr = sortExpr.child
-
- // Create the projection which returns the current 'value'.
- val current = newMutableProjection(expr :: Nil, child.output)
-
- // Flip the sign of the offset when processing the order is descending
- val boundOffset = sortExpr.direction match {
- case Descending => UnaryMinus(offset)
- case Ascending => offset
- }
-
- // Create the projection which returns the current 'value' modified by adding the offset.
- val boundExpr = (expr.dataType, boundOffset.dataType) match {
- case (DateType, IntegerType) => DateAdd(expr, boundOffset)
- case (TimestampType, CalendarIntervalType) =>
- TimeAdd(expr, boundOffset, Some(timeZone))
- case (a, b) if a== b => Add(expr, boundOffset)
- }
- val bound = newMutableProjection(boundExpr :: Nil, child.output)
-
- // Construct the ordering. This is used to compare the result of current value projection
- // to the result of bound value projection. This is done manually because we want to use
- // Code Generation (if it is enabled).
- val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
- val ordering = newOrdering(boundSortExprs, Nil)
- RangeBoundOrdering(ordering, current, bound)
-
- case (RangeFrame, _) =>
- sys.error("Non-Zero range offsets are not supported for windows " +
- "with multiple order expressions.")
- }
- }
-
- /**
- * Collection containing an entry for each window frame to process. Each entry contains a frame's
- * [[WindowExpression]]s and factory function for the WindowFrameFunction.
- */
- private[this] lazy val windowFrameExpressionFactoryPairs = {
- type FrameKey = (String, FrameType, Expression, Expression)
- type ExpressionBuffer = mutable.Buffer[Expression]
- val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
-
- // Add a function and its function to the map for a given frame.
- def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
- val key = (tpe, fr.frameType, fr.lower, fr.upper)
- val (es, fns) = framedFunctions.getOrElseUpdate(
- key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
- es += e
- fns += fn
- }
-
- // Collect all valid window functions and group them by their frame.
- windowExpression.foreach { x =>
- x.foreach {
- case e @ WindowExpression(function, spec) =>
- val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
- function match {
- case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
- case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
- case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
- case f => sys.error(s"Unsupported window function: $f")
- }
- case _ =>
- }
- }
-
- // Map the groups to a (unbound) expression and frame factory pair.
- var numExpressions = 0
- val timeZone = conf.sessionLocalTimeZone
- framedFunctions.toSeq.map {
- case (key, (expressions, functionSeq)) =>
- val ordinal = numExpressions
- val functions = functionSeq.toArray
-
- // Construct an aggregate processor if we need one.
- def processor = AggregateProcessor(
- functions,
- ordinal,
- child.output,
- (expressions, schema) =>
- newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
-
- // Create the factory
- val factory = key match {
- // Offset Frame
- case ("OFFSET", _, IntegerLiteral(offset), _) =>
- target: InternalRow =>
- new OffsetWindowFunctionFrame(
- target,
- ordinal,
- // OFFSET frame functions are guaranteed be OffsetWindowFunctions.
- functions.map(_.asInstanceOf[OffsetWindowFunction]),
- child.output,
- (expressions, schema) =>
- newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
- offset)
-
- // Entire Partition Frame.
- case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
- target: InternalRow => {
- new UnboundedWindowFunctionFrame(target, processor)
- }
-
- // Growing Frame.
- case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
- target: InternalRow => {
- new UnboundedPrecedingWindowFunctionFrame(
- target,
- processor,
- createBoundOrdering(frameType, upper, timeZone))
- }
-
- // Shrinking Frame.
- case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
- target: InternalRow => {
- new UnboundedFollowingWindowFunctionFrame(
- target,
- processor,
- createBoundOrdering(frameType, lower, timeZone))
- }
-
- // Moving Frame.
- case ("AGGREGATE", frameType, lower, upper) =>
- target: InternalRow => {
- new SlidingWindowFunctionFrame(
- target,
- processor,
- createBoundOrdering(frameType, lower, timeZone),
- createBoundOrdering(frameType, upper, timeZone))
- }
- }
-
- // Keep track of the number of expressions. This is a side-effect in a map...
- numExpressions += expressions.size
-
- // Create the Frame Expression - Factory pair.
- (expressions, factory)
- }
- }
-
- /**
- * Create the resulting projection.
- *
- * This method uses Code Generation. It can only be used on the executor side.
- *
- * @param expressions unbound ordered function expressions.
- * @return the final resulting projection.
- */
- private[this] def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
- val references = expressions.zipWithIndex.map{ case (e, i) =>
- // Results of window expressions will be on the right side of child's output
- BoundReference(child.output.size + i, e.dataType, e.nullable)
- }
- val unboundToRefMap = expressions.zip(references).toMap
- val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
- UnsafeProjection.create(
- child.output ++ patchedWindowExpression,
- child.output)
- }
-
protected override def doExecute(): RDD[InternalRow] = {
// Unwrap the expressions and factories from the map.
val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
new file mode 100644
index 0000000000000..dcb86f48bdf32
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExecBase.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.window
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
+import org.apache.spark.sql.types.{CalendarIntervalType, DateType, IntegerType, TimestampType}
+
+abstract class WindowExecBase(
+ windowExpression: Seq[NamedExpression],
+ partitionSpec: Seq[Expression],
+ orderSpec: Seq[SortOrder],
+ child: SparkPlan) extends UnaryExecNode {
+
+ /**
+ * Create the resulting projection.
+ *
+ * This method uses Code Generation. It can only be used on the executor side.
+ *
+ * @param expressions unbound ordered function expressions.
+ * @return the final resulting projection.
+ */
+ protected def createResultProjection(expressions: Seq[Expression]): UnsafeProjection = {
+ val references = expressions.zipWithIndex.map { case (e, i) =>
+ // Results of window expressions will be on the right side of child's output
+ BoundReference(child.output.size + i, e.dataType, e.nullable)
+ }
+ val unboundToRefMap = expressions.zip(references).toMap
+ val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap))
+ UnsafeProjection.create(
+ child.output ++ patchedWindowExpression,
+ child.output)
+ }
+
+ /**
+ * Create a bound ordering object for a given frame type and offset. A bound ordering object is
+ * used to determine which input row lies within the frame boundaries of an output row.
+ *
+ * This method uses Code Generation. It can only be used on the executor side.
+ *
+ * @param frame to evaluate. This can either be a Row or Range frame.
+ * @param bound with respect to the row.
+ * @param timeZone the session local timezone for time related calculations.
+ * @return a bound ordering object.
+ */
+ private def createBoundOrdering(
+ frame: FrameType, bound: Expression, timeZone: String): BoundOrdering = {
+ (frame, bound) match {
+ case (RowFrame, CurrentRow) =>
+ RowBoundOrdering(0)
+
+ case (RowFrame, IntegerLiteral(offset)) =>
+ RowBoundOrdering(offset)
+
+ case (RangeFrame, CurrentRow) =>
+ val ordering = newOrdering(orderSpec, child.output)
+ RangeBoundOrdering(ordering, IdentityProjection, IdentityProjection)
+
+ case (RangeFrame, offset: Expression) if orderSpec.size == 1 =>
+ // Use only the first order expression when the offset is non-null.
+ val sortExpr = orderSpec.head
+ val expr = sortExpr.child
+
+ // Create the projection which returns the current 'value'.
+ val current = newMutableProjection(expr :: Nil, child.output)
+
+ // Flip the sign of the offset when processing the order is descending
+ val boundOffset = sortExpr.direction match {
+ case Descending => UnaryMinus(offset)
+ case Ascending => offset
+ }
+
+ // Create the projection which returns the current 'value' modified by adding the offset.
+ val boundExpr = (expr.dataType, boundOffset.dataType) match {
+ case (DateType, IntegerType) => DateAdd(expr, boundOffset)
+ case (TimestampType, CalendarIntervalType) =>
+ TimeAdd(expr, boundOffset, Some(timeZone))
+ case (a, b) if a == b => Add(expr, boundOffset)
+ }
+ val bound = newMutableProjection(boundExpr :: Nil, child.output)
+
+ // Construct the ordering. This is used to compare the result of current value projection
+ // to the result of bound value projection. This is done manually because we want to use
+ // Code Generation (if it is enabled).
+ val boundSortExprs = sortExpr.copy(BoundReference(0, expr.dataType, expr.nullable)) :: Nil
+ val ordering = newOrdering(boundSortExprs, Nil)
+ RangeBoundOrdering(ordering, current, bound)
+
+ case (RangeFrame, _) =>
+ sys.error("Non-Zero range offsets are not supported for windows " +
+ "with multiple order expressions.")
+ }
+ }
+
+ /**
+ * Collection containing an entry for each window frame to process. Each entry contains a frame's
+ * [[WindowExpression]]s and factory function for the WindowFrameFunction.
+ */
+ protected lazy val windowFrameExpressionFactoryPairs = {
+ type FrameKey = (String, FrameType, Expression, Expression)
+ type ExpressionBuffer = mutable.Buffer[Expression]
+ val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)]
+
+ // Add a window expression and its window function to the map for a given frame.
+ def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = {
+ val key = (tpe, fr.frameType, fr.lower, fr.upper)
+ val (es, fns) = framedFunctions.getOrElseUpdate(
+ key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression]))
+ es += e
+ fns += fn
+ }
+
+ // Collect all valid window functions and group them by their frame.
+ windowExpression.foreach { x =>
+ x.foreach {
+ case e @ WindowExpression(function, spec) =>
+ val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
+ function match {
+ case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
+ case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
+ case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
+ case f: PythonUDF => collect("AGGREGATE", frame, e, f)
+ case f => sys.error(s"Unsupported window function: $f")
+ }
+ case _ =>
+ }
+ }
+
+ // Map the groups to a (unbound) expression and frame factory pair.
+ var numExpressions = 0
+ val timeZone = conf.sessionLocalTimeZone
+ framedFunctions.toSeq.map {
+ case (key, (expressions, functionSeq)) =>
+ val ordinal = numExpressions
+ val functions = functionSeq.toArray
+
+ // Construct an aggregate processor if we need one.
+ // Currently we don't allow mixing of Pandas UDF and SQL aggregation functions
+ // in a single Window physical node. Therefore, we can assume no SQL aggregation
+ // functions if Pandas UDF exists. In the future, we might mix Pandas UDF and SQL
+ // aggregation functions in a single physical node.
+ def processor = if (functions.exists(_.isInstanceOf[PythonUDF])) {
+ null
+ } else {
+ AggregateProcessor(
+ functions,
+ ordinal,
+ child.output,
+ (expressions, schema) =>
+ newMutableProjection(expressions, schema, subexpressionEliminationEnabled))
+ }
+
+ // Create the factory
+ val factory = key match {
+ // Offset Frame
+ case ("OFFSET", _, IntegerLiteral(offset), _) =>
+ target: InternalRow =>
+ new OffsetWindowFunctionFrame(
+ target,
+ ordinal,
+ // OFFSET frame functions are guaranteed to be OffsetWindowFunctions.
+ functions.map(_.asInstanceOf[OffsetWindowFunction]),
+ child.output,
+ (expressions, schema) =>
+ newMutableProjection(expressions, schema, subexpressionEliminationEnabled),
+ offset)
+
+ // Entire Partition Frame.
+ case ("AGGREGATE", _, UnboundedPreceding, UnboundedFollowing) =>
+ target: InternalRow => {
+ new UnboundedWindowFunctionFrame(target, processor)
+ }
+
+ // Growing Frame.
+ case ("AGGREGATE", frameType, UnboundedPreceding, upper) =>
+ target: InternalRow => {
+ new UnboundedPrecedingWindowFunctionFrame(
+ target,
+ processor,
+ createBoundOrdering(frameType, upper, timeZone))
+ }
+
+ // Shrinking Frame.
+ case ("AGGREGATE", frameType, lower, UnboundedFollowing) =>
+ target: InternalRow => {
+ new UnboundedFollowingWindowFunctionFrame(
+ target,
+ processor,
+ createBoundOrdering(frameType, lower, timeZone))
+ }
+
+ // Moving Frame.
+ case ("AGGREGATE", frameType, lower, upper) =>
+ target: InternalRow => {
+ new SlidingWindowFunctionFrame(
+ target,
+ processor,
+ createBoundOrdering(frameType, lower, timeZone),
+ createBoundOrdering(frameType, upper, timeZone))
+ }
+ }
+
+ // Keep track of the number of expressions. This is a side-effect in a map...
+ numExpressions += expressions.size
+
+ // Create the Frame Expression - Factory pair.
+ (expressions, factory)
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
index 156002ef58fbe..a5601899ea2de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowFunctionFrame.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray
* Before use a frame must be prepared by passing it all the rows in the current partition. After
* preparation the update method can be called to fill the output rows.
*/
-private[window] abstract class WindowFunctionFrame {
+abstract class WindowFunctionFrame {
/**
* Prepare the frame for calculating the results for a partition.
*
@@ -42,6 +42,20 @@ private[window] abstract class WindowFunctionFrame {
* Write the current results to the target row.
*/
def write(index: Int, current: InternalRow): Unit
+
+ /**
+ * The current lower window bound in the row array (inclusive).
+ *
+ * This should be called after the current row is updated via [[write]]
+ */
+ def currentLowerBound(): Int
+
+ /**
+ * The current row index of the upper window bound in the row array (exclusive)
+ *
+ * This should be called after the current row is updated via [[write]]
+ */
+ def currentUpperBound(): Int
}
object WindowFunctionFrame {
@@ -62,7 +76,7 @@ object WindowFunctionFrame {
* @param newMutableProjection function used to create the projection.
* @param offset by which rows get moved within a partition.
*/
-private[window] final class OffsetWindowFunctionFrame(
+final class OffsetWindowFunctionFrame(
target: InternalRow,
ordinal: Int,
expressions: Array[OffsetWindowFunction],
@@ -137,6 +151,10 @@ private[window] final class OffsetWindowFunctionFrame(
}
inputIndex += 1
}
+
+ override def currentLowerBound(): Int = throw new UnsupportedOperationException()
+
+ override def currentUpperBound(): Int = throw new UnsupportedOperationException()
}
/**
@@ -148,7 +166,7 @@ private[window] final class OffsetWindowFunctionFrame(
* @param lbound comparator used to identify the lower bound of an output row.
* @param ubound comparator used to identify the upper bound of an output row.
*/
-private[window] final class SlidingWindowFunctionFrame(
+final class SlidingWindowFunctionFrame(
target: InternalRow,
processor: AggregateProcessor,
lbound: BoundOrdering,
@@ -170,24 +188,24 @@ private[window] final class SlidingWindowFunctionFrame(
private[this] val buffer = new util.ArrayDeque[InternalRow]()
/**
- * Index of the first input row with a value greater than the upper bound of the current
- * output row.
+ * Index of the first input row with a value equal to or greater than the lower bound of the
+ * current output row.
*/
- private[this] var inputHighIndex = 0
+ private[this] var lowerBound = 0
/**
- * Index of the first input row with a value equal to or greater than the lower bound of the
- * current output row.
+ * Index of the first input row with a value greater than the upper bound of the current
+ * output row.
*/
- private[this] var inputLowIndex = 0
+ private[this] var upperBound = 0
/** Prepare the frame for calculating a new partition. Reset all variables. */
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
input = rows
inputIterator = input.generateIterator()
nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
- inputHighIndex = 0
- inputLowIndex = 0
+ lowerBound = 0
+ upperBound = 0
buffer.clear()
}
@@ -197,27 +215,27 @@ private[window] final class SlidingWindowFunctionFrame(
// Drop all rows from the buffer for which the input row value is smaller than
// the output row lower bound.
- while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) {
+ while (!buffer.isEmpty && lbound.compare(buffer.peek(), lowerBound, current, index) < 0) {
buffer.remove()
- inputLowIndex += 1
+ lowerBound += 1
bufferUpdated = true
}
// Add all rows to the buffer for which the input row value is equal to or less than
// the output row upper bound.
- while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) {
- if (lbound.compare(nextRow, inputLowIndex, current, index) < 0) {
- inputLowIndex += 1
+ while (nextRow != null && ubound.compare(nextRow, upperBound, current, index) <= 0) {
+ if (lbound.compare(nextRow, lowerBound, current, index) < 0) {
+ lowerBound += 1
} else {
buffer.add(nextRow.copy())
bufferUpdated = true
}
nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
- inputHighIndex += 1
+ upperBound += 1
}
// Only recalculate and update when the buffer changes.
- if (bufferUpdated) {
+ if (processor != null && bufferUpdated) {
processor.initialize(input.length)
val iter = buffer.iterator()
while (iter.hasNext) {
@@ -226,6 +244,10 @@ private[window] final class SlidingWindowFunctionFrame(
processor.evaluate(target)
}
}
+
+ override def currentLowerBound(): Int = lowerBound
+
+ override def currentUpperBound(): Int = upperBound
}
/**
@@ -239,27 +261,39 @@ private[window] final class SlidingWindowFunctionFrame(
* @param target to write results to.
* @param processor to calculate the row values with.
*/
-private[window] final class UnboundedWindowFunctionFrame(
+final class UnboundedWindowFunctionFrame(
target: InternalRow,
processor: AggregateProcessor)
extends WindowFunctionFrame {
+ val lowerBound: Int = 0
+ var upperBound: Int = 0
+
/** Prepare the frame for calculating a new partition. Process all rows eagerly. */
override def prepare(rows: ExternalAppendOnlyUnsafeRowArray): Unit = {
- processor.initialize(rows.length)
-
- val iterator = rows.generateIterator()
- while (iterator.hasNext) {
- processor.update(iterator.next())
+ if (processor != null) {
+ processor.initialize(rows.length)
+ val iterator = rows.generateIterator()
+ while (iterator.hasNext) {
+ processor.update(iterator.next())
+ }
}
+
+ upperBound = rows.length
}
/** Write the frame columns for the current row to the given target row. */
override def write(index: Int, current: InternalRow): Unit = {
// Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate
// for each row.
- processor.evaluate(target)
+ if (processor != null) {
+ processor.evaluate(target)
+ }
}
+
+ override def currentLowerBound(): Int = lowerBound
+
+ override def currentUpperBound(): Int = upperBound
}
/**
@@ -276,7 +310,7 @@ private[window] final class UnboundedWindowFunctionFrame(
* @param processor to calculate the row values with.
* @param ubound comparator used to identify the upper bound of an output row.
*/
-private[window] final class UnboundedPrecedingWindowFunctionFrame(
+final class UnboundedPrecedingWindowFunctionFrame(
target: InternalRow,
processor: AggregateProcessor,
ubound: BoundOrdering)
@@ -308,7 +342,9 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
nextRow = inputIterator.next()
}
- processor.initialize(input.length)
+ if (processor != null) {
+ processor.initialize(input.length)
+ }
}
/** Write the frame columns for the current row to the given target row. */
@@ -318,17 +354,23 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
// Add all rows to the aggregates for which the input row value is equal to or less than
// the output row upper bound.
while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) {
- processor.update(nextRow)
+ if (processor != null) {
+ processor.update(nextRow)
+ }
nextRow = WindowFunctionFrame.getNextOrNull(inputIterator)
inputIndex += 1
bufferUpdated = true
}
// Only recalculate and update when the buffer changes.
- if (bufferUpdated) {
+ if (processor != null && bufferUpdated) {
processor.evaluate(target)
}
}
+
+ override def currentLowerBound(): Int = 0
+
+ override def currentUpperBound(): Int = inputIndex
}
/**
@@ -347,7 +389,7 @@ private[window] final class UnboundedPrecedingWindowFunctionFrame(
* @param processor to calculate the row values with.
* @param lbound comparator used to identify the lower bound of an output row.
*/
-private[window] final class UnboundedFollowingWindowFunctionFrame(
+final class UnboundedFollowingWindowFunctionFrame(
target: InternalRow,
processor: AggregateProcessor,
lbound: BoundOrdering)
@@ -384,7 +426,7 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
}
// Only recalculate and update when the buffer changes.
- if (bufferUpdated) {
+ if (processor != null && bufferUpdated) {
processor.initialize(input.length)
if (nextRow != null) {
processor.update(nextRow)
@@ -395,4 +437,8 @@ private[window] final class UnboundedFollowingWindowFunctionFrame(
processor.evaluate(target)
}
}
+
+ override def currentLowerBound(): Int = inputIndex
+
+ override def currentUpperBound(): Int = input.length
}
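
With the renamed fields and the new currentLowerBound()/currentUpperBound() accessors, each frame exposes the half-open range [lower, upper) of input-row indices it currently covers. Below is a minimal, self-contained sketch of that bookkeeping for a sliding range frame over a sorted partition; it is plain Scala with no Spark internals, and the [-1, +1] range frame is an assumption made purely for illustration.

    // Sketch only: simulates the lowerBound/upperBound bookkeeping of a sliding
    // frame over a sorted partition, using an assumed range frame of [-1, +1]
    // around the current row's value. Rows are plain Ints sorted ascending.
    object SlidingBoundsSketch {
      def main(args: Array[String]): Unit = {
        val rows = Array(1, 2, 2, 5, 7, 8)
        var lowerBound = 0 // first index with value >= current - 1 (inclusive edge)
        var upperBound = 0 // first index with value >  current + 1 (exclusive edge)

        for ((current, i) <- rows.zipWithIndex) {
          // Advance the lower edge past rows that fall below the frame.
          while (lowerBound < rows.length && rows(lowerBound) < current - 1) lowerBound += 1
          // Advance the upper edge over rows that still fall inside the frame.
          while (upperBound < rows.length && rows(upperBound) <= current + 1) upperBound += 1
          val frame = rows.slice(lowerBound, upperBound).mkString("[", ", ", "]")
          println(s"row $i (value $current): bounds [$lowerBound, $upperBound) -> $frame")
        }
      }
    }

Because the partition is sorted, both indices only ever move forward, which is the same property the frames above rely on.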
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index ac07e1f6bb4f8..319c2649592fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -309,13 +309,14 @@ private[sql] trait WithTestConf { self: BaseSessionStateBuilder =>
def overrideConfs: Map[String, String]
override protected lazy val conf: SQLConf = {
+ val overrideConfigurations = overrideConfs
val conf = parentState.map(_.conf.clone()).getOrElse {
new SQLConf {
clear()
override def clear(): Unit = {
super.clear()
// Make sure we start with the default test configs even after clear
- overrideConfs.foreach { case (key, value) => setConfString(key, value) }
+ overrideConfigurations.foreach { case (key, value) => setConfString(key, value) }
}
}
}
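
The only functional change here is binding `overrideConfs` to a local `overrideConfigurations` before the anonymous `SQLConf` is constructed, so the overridden `clear()` closes over a plain map rather than over the enclosing builder; that is what allows the test conf to be Java-serialized (see the SPARK-26409 test further down). A hedged sketch of the general Scala capture behaviour, using made-up `Builder`/`defaults` names:

    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    // Sketch only: `Builder` and `defaults` are made-up names. Referencing a member
    // from inside a function literal (or an anonymous class) drags the enclosing,
    // non-serializable instance into the serialized graph; copying the value into a
    // local val first avoids that.
    class Builder { // deliberately NOT Serializable
      def defaults: Map[String, String] = Map("spark.some.key" -> "value")

      def capturing: () => Map[String, String] =
        () => defaults // reads `this.defaults`, so it captures the Builder

      def notCapturing: () => Map[String, String] = {
        val localDefaults = defaults // evaluated once; an immutable Map is Serializable
        () => localDefaults          // closes over the local val only
      }
    }

    object CaptureSketch {
      private def roundTrip(o: AnyRef): Unit =
        new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(o)

      def main(args: Array[String]): Unit = {
        val b = new Builder
        roundTrip(b.notCapturing)  // fine: only the Map is written
        try roundTrip(b.capturing) // fails: the captured Builder is not serializable
        catch { case e: NotSerializableException => println(s"as expected: $e") }
      }
    }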
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index c8e3e1c191044..914fa90ae7e14 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -273,7 +273,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* during parsing.
*
* - `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a
- * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To
+ * field configured by `columnNameOfCorruptRecord`, and sets malformed fields to `null`. To
* keep corrupt records, an user can set a string type field named
* `columnNameOfCorruptRecord` in an user-defined schema. If a schema does not have the
* field, it drops corrupt records during parsing. When inferring a schema, it implicitly
@@ -360,13 +360,13 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* during parsing. It supports the following case-insensitive modes.
*
* - `PERMISSIVE` : when it meets a corrupted record, puts the malformed string into a
- * field configured by `columnNameOfCorruptRecord`, and sets other fields to `null`. To keep
- * corrupt records, an user can set a string type field named `columnNameOfCorruptRecord`
- * in an user-defined schema. If a schema does not have the field, it drops corrupt records
- * during parsing. A record with less/more tokens than schema is not a corrupted record to
- * CSV. When it meets a record having fewer tokens than the length of the schema, sets
- * `null` to extra fields. When the record has more tokens than the length of the schema,
- * it drops extra tokens.
+ * field configured by `columnNameOfCorruptRecord`, and sets malformed fields to `null`.
+ * To keep corrupt records, a user can set a string type field named
+ * `columnNameOfCorruptRecord` in a user-defined schema. If a schema does not have
+ * the field, it drops corrupt records during parsing. A record with fewer or more
+ * tokens than the schema is not treated as corrupted by CSV. When a record has fewer
+ * tokens than the length of the schema, the missing fields are set to `null`; when it
+ * has more tokens than the length of the schema, the extra tokens are dropped.
* - `DROPMALFORMED` : ignores the whole corrupted records.
* - `FAILFAST` : throws an exception when it meets corrupted records.
*
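
A sketch of what the clarified PERMISSIVE wording means in practice for a CSV stream; the schema, directory path, and column names below are illustrative only:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.types._

    // Sketch only: hypothetical path and schema. PERMISSIVE mode keeps corrupt CSV
    // records in a user-declared string column instead of dropping them.
    object PermissiveCsvStreamSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("permissive-csv").getOrCreate()

        val schema = new StructType()
          .add("name", StringType)
          .add("date", DateType)
          .add("_corrupt_record", StringType) // must be declared to retain bad rows

        val df = spark.readStream
          .schema(schema)
          .option("mode", "PERMISSIVE")
          .option("columnNameOfCorruptRecord", "_corrupt_record")
          .csv("/tmp/csv-input") // hypothetical input directory

        // Malformed fields come back as null; the raw line lands in _corrupt_record.
        val query = df.writeStream.format("console").start()
        query.awaitTermination()
      }
    }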
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
index d9fe1a992a093..881cd96cc9dc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala
@@ -246,9 +246,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
val analyzedPlan = df.queryExecution.analyzed
df.queryExecution.assertAnalyzed()
- if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
- UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode)
- }
+ val operationCheckEnabled = sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled
if (sparkSession.sessionState.conf.adaptiveExecutionEnabled) {
logWarning(s"${SQLConf.ADAPTIVE_EXECUTION_ENABLED.key} " +
@@ -257,7 +255,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
(sink, trigger) match {
case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) =>
- if (sparkSession.sessionState.conf.isUnsupportedOperationCheckEnabled) {
+ if (operationCheckEnabled) {
UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode)
}
new StreamingQueryWrapper(new ContinuousExecution(
@@ -272,6 +270,9 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo
extraOptions,
deleteCheckpointOnStop))
case _ =>
+ if (operationCheckEnabled) {
+ UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode)
+ }
new StreamingQueryWrapper(new MicroBatchExecution(
sparkSession,
userSpecifiedName.orNull,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
index 9dc62b7aac891..6ca9aacab7247 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryStatus.scala
@@ -28,9 +28,11 @@ import org.apache.spark.annotation.Evolving
* Reports information about the instantaneous status of a streaming query.
*
* @param message A human readable description of what the stream is currently doing.
- * @param isDataAvailable True when there is new data to be processed.
+ * @param isDataAvailable True when there is new data to be processed. Doesn't apply
+ * to ContinuousExecution, where it is always false.
* @param isTriggerActive True when the trigger is actively firing, false when waiting for the
- * next trigger time.
+ * next trigger time. Doesn't apply to ContinuousExecution, where it is
+ * always false.
*
* @since 2.1.0
*/
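
For reference, a small sketch of where these flags surface at the API level; `describeStatus` is a hypothetical helper, not a Spark API:

    import org.apache.spark.sql.streaming.StreamingQuery

    // Sketch only: polls the instantaneous status of an already started query.
    object StatusSketch {
      def describeStatus(query: StreamingQuery): String = {
        val s = query.status
        // For continuous-mode queries both flags stay false, per the doc change above.
        s"${s.message} (dataAvailable=${s.isDataAvailable}, triggerActive=${s.isTriggerActive})"
      }
    }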
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
index 3cd6700efef5f..0b3945cbd1323 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/progress.scala
@@ -30,6 +30,7 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.Evolving
+import org.apache.spark.sql.streaming.SinkProgress.DEFAULT_NUM_OUTPUT_ROWS
/**
* Information about updates made to stateful operators in a [[StreamingQuery]] during a trigger.
@@ -207,11 +208,19 @@ class SourceProgress protected[sql](
* during a trigger. See [[StreamingQueryProgress]] for more information.
*
* @param description Description of the source corresponding to this status.
+ * @param numOutputRows Number of rows written to the sink, or -1 for Continuous Mode
+ * (temporarily) or Sink V1 (until decommissioned).
* @since 2.1.0
*/
@Evolving
class SinkProgress protected[sql](
- val description: String) extends Serializable {
+ val description: String,
+ val numOutputRows: Long) extends Serializable {
+
+ /** SinkProgress without custom metrics. */
+ protected[sql] def this(description: String) {
+ this(description, DEFAULT_NUM_OUTPUT_ROWS)
+ }
/** The compact JSON representation of this progress. */
def json: String = compact(render(jsonValue))
@@ -222,6 +231,14 @@ class SinkProgress protected[sql](
override def toString: String = prettyJson
private[sql] def jsonValue: JValue = {
- ("description" -> JString(description))
+ ("description" -> JString(description)) ~
+ ("numOutputRows" -> JInt(numOutputRows))
}
}
+
+private[sql] object SinkProgress {
+ val DEFAULT_NUM_OUTPUT_ROWS: Long = -1L
+
+ def apply(description: String, numOutputRows: Option[Long]): SinkProgress =
+ new SinkProgress(description, numOutputRows.getOrElse(DEFAULT_NUM_OUTPUT_ROWS))
+}
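
The new field is visible to callers through `StreamingQueryProgress.sink`. A sketch of consuming it, with the hypothetical helper name `lastSinkRowCount`:

    import org.apache.spark.sql.streaming.StreamingQuery

    object SinkProgressSketch {
      // Sketch only: return how many rows the sink reported for the most recent
      // trigger of an already started query. A value of -1 (DEFAULT_NUM_OUTPUT_ROWS)
      // means the sink did not report the metric (continuous mode or a V1 sink),
      // so it is mapped to None here.
      def lastSinkRowCount(query: StreamingQuery): Option[Long] =
        Option(query.lastProgress).map(_.sink.numOutputRows).filter(_ >= 0)
    }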
diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
index ec263ea70bd4a..7e81ff1aba37b 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql
@@ -141,8 +141,3 @@ SELECT every("true");
SELECT k, v, every(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
SELECT k, v, some(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
SELECT k, v, any(v) OVER (PARTITION BY k ORDER BY v) FROM test_agg;
-
--- simple explain of queries having every/some/any agregates. Optimized
--- plan should show the rewritten aggregate expression.
-EXPLAIN EXTENDED SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k;
-
diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
index 41d316444ed6b..b3ec956cd178e 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
@@ -49,6 +49,3 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b);
-- string to timestamp
select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b);
-
--- cross-join inline tables
-EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
index 37f9cd44da7f2..ba14789d48db6 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
@@ -29,27 +29,6 @@ select 2 * 5;
select 5 % 3;
select pmod(-7, 3);
--- check operator precedence.
--- We follow Oracle operator precedence in the table below that lists the levels of precedence
--- among SQL operators from high to low:
-------------------------------------------------------------------------------------------
--- Operator Operation
-------------------------------------------------------------------------------------------
--- +, - identity, negation
--- *, / multiplication, division
--- +, -, || addition, subtraction, concatenation
--- =, !=, <, >, <=, >=, IS NULL, LIKE, BETWEEN, IN comparison
--- NOT exponentiation, logical negation
--- AND conjunction
--- OR disjunction
-------------------------------------------------------------------------------------------
-explain select 'a' || 1 + 2;
-explain select 1 - 2 || 'b';
-explain select 2 * 4 + 3 || 'b';
-explain select 3 + 1 || 'a' || 4 / 2;
-explain select 1 == 1 OR 'a' || 'b' == 'ab';
-explain select 'a' || 'c' == 'ac' AND 2 == 3;
-
-- math functions
select cot(1);
select cot(null);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
index f1461032065ad..1ae49c8bfc76a 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
@@ -12,11 +12,6 @@ SELECT nullif(1, 2.1d), nullif(1, 1.0d);
SELECT nvl(1, 2.1d), nvl(null, 2.1d);
SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d);
--- explain for these functions; use range to avoid constant folding
-explain extended
-select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y')
-from range(2);
-
-- SPARK-16730 cast alias functions for Hive compatibility
SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1);
SELECT float(1), double(1), decimal(1);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
index 2effb43183d75..fbc231627e36f 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql
@@ -5,10 +5,6 @@ select format_string();
-- A pipe operator for string concatenation
select 'a' || 'b' || 'c';
--- Check if catalyst combine nested `Concat`s
-EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
-FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10));
-
-- replace function
select replace('abc', 'b', '123');
select replace('abc', 'b');
@@ -25,29 +21,6 @@ select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a');
select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null);
select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a');
--- turn off concatBinaryAsString
-set spark.sql.function.concatBinaryAsString=false;
-
--- Check if catalyst combine nested `Concat`s if concatBinaryAsString=false
-EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
-FROM (
- SELECT
- string(id) col1,
- string(id + 1) col2,
- encode(string(id + 2), 'utf-8') col3,
- encode(string(id + 3), 'utf-8') col4
- FROM range(10)
-);
-
-EXPLAIN SELECT (col1 || (col3 || col4)) col
-FROM (
- SELECT
- string(id) col1,
- encode(string(id + 2), 'utf-8') col3,
- encode(string(id + 3), 'utf-8') col4
- FROM range(10)
-);
-
-- split function
SELECT split('aa1cc2ee3', '[1-9]+');
SELECT split('aa1cc2ee3', '[1-9]+', 2);
diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
index 72cd8ca9d8722..6f14c8ca87821 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
@@ -21,9 +21,3 @@ select * from range(1, null);
-- range call with a mixed-case function name
select * from RaNgE(2);
-
--- Explain
-EXPLAIN select * from RaNgE(2);
-
--- cross-join table valued functions
-EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3);
diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
index 9a8d025331b67..daf47c4d0a39a 100644
--- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 47
+-- Number of queries: 46
-- !query 0
@@ -459,31 +459,3 @@ struct
--- !query 46 output
-== Parsed Logical Plan ==
-'Aggregate ['k], ['k, unresolvedalias('every('v), None), unresolvedalias('some('v), None), unresolvedalias('any('v), None)]
-+- 'UnresolvedRelation `test_agg`
-
-== Analyzed Logical Plan ==
-k: int, every(v): boolean, some(v): boolean, any(v): boolean
-Aggregate [k#x], [k#x, every(v#x) AS every(v)#x, some(v#x) AS some(v)#x, any(v#x) AS any(v)#x]
-+- SubqueryAlias `test_agg`
- +- Project [k#x, v#x]
- +- SubqueryAlias `test_agg`
- +- LocalRelation [k#x, v#x]
-
-== Optimized Logical Plan ==
-Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, max(v#x) AS any(v)#x]
-+- LocalRelation [k#x, v#x]
-
-== Physical Plan ==
-*HashAggregate(keys=[k#x], functions=[min(v#x), max(v#x)], output=[k#x, every(v)#x, some(v)#x, any(v)#x])
-+- Exchange hashpartitioning(k#x, 200)
- +- *HashAggregate(keys=[k#x], functions=[partial_min(v#x), partial_max(v#x)], output=[k#x, min#x, max#x])
- +- LocalTableScan [k#x, v#x]
diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
index c065ce5012929..4e80f0bda5513 100644
--- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 18
+-- Number of queries: 17
-- !query 0
@@ -151,33 +151,3 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-
struct>
-- !query 16 output
1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0]
-
-
--- !query 17
-EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null)
--- !query 17 schema
-struct
--- !query 17 output
-== Parsed Logical Plan ==
-'Project [*]
-+- 'Join Cross
- :- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)]
- +- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)]
-
-== Analyzed Logical Plan ==
-col1: string, col2: int, col1: string, col2: int
-Project [col1#x, col2#x, col1#x, col2#x]
-+- Join Cross
- :- LocalRelation [col1#x, col2#x]
- +- LocalRelation [col1#x, col2#x]
-
-== Optimized Logical Plan ==
-Join Cross
-:- LocalRelation [col1#x, col2#x]
-+- LocalRelation [col1#x, col2#x]
-
-== Physical Plan ==
-BroadcastNestedLoopJoin BuildRight, Cross
-:- LocalTableScan [col1#x, col2#x]
-+- BroadcastExchange IdentityBroadcastMode
- +- LocalTableScan [col1#x, col2#x]
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index 570b281353f3d..e0cbd575bc346 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 55
+-- Number of queries: 49
-- !query 0
@@ -195,260 +195,200 @@ struct
-- !query 24
-explain select 'a' || 1 + 2
+select cot(1)
-- !query 24 schema
-struct
+struct
-- !query 24 output
-== Physical Plan ==
-*Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x]
-+- *Scan OneRowRelation[]
+0.6420926159343306
-- !query 25
-explain select 1 - 2 || 'b'
+select cot(null)
-- !query 25 schema
-struct
+struct
-- !query 25 output
-== Physical Plan ==
-*Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x]
-+- *Scan OneRowRelation[]
+NULL
-- !query 26
-explain select 2 * 4 + 3 || 'b'
+select cot(0)
-- !query 26 schema
-struct
+struct
-- !query 26 output
-== Physical Plan ==
-*Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x]
-+- *Scan OneRowRelation[]
+Infinity
-- !query 27
-explain select 3 + 1 || 'a' || 4 / 2
+select cot(-1)
-- !query 27 schema
-struct
+struct
-- !query 27 output
-== Physical Plan ==
-*Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x]
-+- *Scan OneRowRelation[]
+-0.6420926159343306
-- !query 28
-explain select 1 == 1 OR 'a' || 'b' == 'ab'
+select ceiling(0)
-- !query 28 schema
-struct
+struct
-- !query 28 output
-== Physical Plan ==
-*Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x]
-+- *Scan OneRowRelation[]
+0
-- !query 29
-explain select 'a' || 'c' == 'ac' AND 2 == 3
+select ceiling(1)
-- !query 29 schema
-struct
+struct
-- !query 29 output
-== Physical Plan ==
-*Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x]
-+- *Scan OneRowRelation[]
+1
-- !query 30
-select cot(1)
+select ceil(1234567890123456)
-- !query 30 schema
-struct
+struct
-- !query 30 output
-0.6420926159343306
+1234567890123456
-- !query 31
-select cot(null)
+select ceiling(1234567890123456)
-- !query 31 schema
-struct
+struct
-- !query 31 output
-NULL
+1234567890123456
-- !query 32
-select cot(0)
+select ceil(0.01)
-- !query 32 schema
-struct
+struct
-- !query 32 output
-Infinity
+1
-- !query 33
-select cot(-1)
+select ceiling(-0.10)
-- !query 33 schema
-struct
+struct
-- !query 33 output
--0.6420926159343306
+0
-- !query 34
-select ceiling(0)
+select floor(0)
-- !query 34 schema
-struct
+struct
-- !query 34 output
0
-- !query 35
-select ceiling(1)
+select floor(1)
-- !query 35 schema
-struct
+struct
-- !query 35 output
1
-- !query 36
-select ceil(1234567890123456)
+select floor(1234567890123456)
-- !query 36 schema
-struct
+struct
-- !query 36 output
1234567890123456
-- !query 37
-select ceiling(1234567890123456)
--- !query 37 schema
-struct
--- !query 37 output
-1234567890123456
-
-
--- !query 38
-select ceil(0.01)
--- !query 38 schema
-struct
--- !query 38 output
-1
-
-
--- !query 39
-select ceiling(-0.10)
--- !query 39 schema
-struct
--- !query 39 output
-0
-
-
--- !query 40
-select floor(0)
--- !query 40 schema
-struct
--- !query 40 output
-0
-
-
--- !query 41
-select floor(1)
--- !query 41 schema
-struct
--- !query 41 output
-1
-
-
--- !query 42
-select floor(1234567890123456)
--- !query 42 schema
-struct
--- !query 42 output
-1234567890123456
-
-
--- !query 43
select floor(0.01)
--- !query 43 schema
+-- !query 37 schema
struct
--- !query 43 output
+-- !query 37 output
0
--- !query 44
+-- !query 38
select floor(-0.10)
--- !query 44 schema
+-- !query 38 schema
struct
--- !query 44 output
+-- !query 38 output
-1
--- !query 45
+-- !query 39
select 1 > 0.00001
--- !query 45 schema
+-- !query 39 schema
struct<(CAST(1 AS BIGINT) > 0):boolean>
--- !query 45 output
+-- !query 39 output
true
--- !query 46
+-- !query 40
select mod(7, 2), mod(7, 0), mod(0, 2), mod(7, null), mod(null, 2), mod(null, null)
--- !query 46 schema
+-- !query 40 schema
struct<(7 % 2):int,(7 % 0):int,(0 % 2):int,(7 % CAST(NULL AS INT)):int,(CAST(NULL AS INT) % 2):int,(CAST(NULL AS DOUBLE) % CAST(NULL AS DOUBLE)):double>
--- !query 46 output
+-- !query 40 output
1 NULL 0 NULL NULL NULL
--- !query 47
+-- !query 41
select BIT_LENGTH('abc')
--- !query 47 schema
+-- !query 41 schema
struct
--- !query 47 output
+-- !query 41 output
24
--- !query 48
+-- !query 42
select CHAR_LENGTH('abc')
--- !query 48 schema
+-- !query 42 schema
struct
--- !query 48 output
+-- !query 42 output
3
--- !query 49
+-- !query 43
select CHARACTER_LENGTH('abc')
--- !query 49 schema
+-- !query 43 schema
struct
--- !query 49 output
+-- !query 43 output
3
--- !query 50
+-- !query 44
select OCTET_LENGTH('abc')
--- !query 50 schema
+-- !query 44 schema
struct
--- !query 50 output
+-- !query 44 output
3
--- !query 51
+-- !query 45
select abs(-3.13), abs('-2.19')
--- !query 51 schema
+-- !query 45 schema
struct
--- !query 51 output
+-- !query 45 output
3.13 2.19
--- !query 52
+-- !query 46
select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11)
--- !query 52 schema
+-- !query 46 schema
struct<(+ CAST(-1.11 AS DOUBLE)):double,(+ -1.11):decimal(3,2),(- CAST(-1.11 AS DOUBLE)):double,(- -1.11):decimal(3,2)>
--- !query 52 output
+-- !query 46 output
-1.11 -1.11 1.11 1.11
--- !query 53
+-- !query 47
select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null)
--- !query 53 schema
+-- !query 47 schema
struct
--- !query 53 output
+-- !query 47 output
1 0 NULL NULL NULL NULL
--- !query 54
+-- !query 48
select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint))
--- !query 54 schema
+-- !query 48 schema
struct
--- !query 54 output
+-- !query 48 output
NULL NULL
diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
index e035505f15d28..69a8e958000db 100644
--- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 15
+-- Number of queries: 14
-- !query 0
@@ -67,74 +67,49 @@ struct
-- !query 8
-explain extended
-select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y')
-from range(2)
--- !query 8 schema
-struct
--- !query 8 output
-== Parsed Logical Plan ==
-'Project [unresolvedalias('ifnull('id, x), None), unresolvedalias('nullif('id, x), None), unresolvedalias('nvl('id, x), None), unresolvedalias('nvl2('id, x, y), None)]
-+- 'UnresolvedTableValuedFunction range, [2]
-
-== Analyzed Logical Plan ==
-ifnull(`id`, 'x'): string, nullif(`id`, 'x'): bigint, nvl(`id`, 'x'): string, nvl2(`id`, 'x', 'y'): string
-Project [ifnull(id#xL, x) AS ifnull(`id`, 'x')#x, nullif(id#xL, x) AS nullif(`id`, 'x')#xL, nvl(id#xL, x) AS nvl(`id`, 'x')#x, nvl2(id#xL, x, y) AS nvl2(`id`, 'x', 'y')#x]
-+- Range (0, 2, step=1, splits=None)
-
-== Optimized Logical Plan ==
-Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x]
-+- Range (0, 2, step=1, splits=None)
-
-== Physical Plan ==
-*Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x]
-+- *Range (0, 2, step=1, splits=2)
-
-
--- !query 9
SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)
--- !query 9 schema
+-- !query 8 schema
struct
--- !query 9 output
+-- !query 8 output
true 1 1 1 1
--- !query 10
+-- !query 9
SELECT float(1), double(1), decimal(1)
--- !query 10 schema
+-- !query 9 schema
struct
--- !query 10 output
+-- !query 9 output
1.0 1.0 1
--- !query 11
+-- !query 10
SELECT date("2014-04-04"), timestamp(date("2014-04-04"))
--- !query 11 schema
+-- !query 10 schema
struct
--- !query 11 output
+-- !query 10 output
2014-04-04 2014-04-04 00:00:00
--- !query 12
+-- !query 11
SELECT string(1, 2)
--- !query 12 schema
+-- !query 11 schema
struct<>
--- !query 12 output
+-- !query 11 output
org.apache.spark.sql.AnalysisException
Function string accepts only one argument; line 1 pos 7
--- !query 13
+-- !query 12
CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 'col2', 'delta')) AS T(id, st)
--- !query 13 schema
+-- !query 12 schema
struct<>
--- !query 13 output
+-- !query 12 output
--- !query 14
+-- !query 13
SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY nvl(st.col1, "value")
--- !query 14 schema
+-- !query 13 schema
struct
--- !query 14 output
+-- !query 13 output
gamma 1
diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
index e8f2e0a81455a..25d93b2063146 100644
--- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 17
+-- Number of queries: 13
-- !query 0
@@ -29,151 +29,80 @@ abc
-- !query 3
-EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
-FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10))
--- !query 3 schema
-struct
--- !query 3 output
-== Parsed Logical Plan ==
-'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x]
-+- 'SubqueryAlias `__auto_generated_subquery_name`
- +- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x]
- +- 'UnresolvedTableValuedFunction range, [10]
-
-== Analyzed Logical Plan ==
-col: string
-Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x]
-+- SubqueryAlias `__auto_generated_subquery_name`
- +- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL]
- +- Range (0, 10, step=1, splits=None)
-
-== Optimized Logical Plan ==
-Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
-+- Range (0, 10, step=1, splits=None)
-
-== Physical Plan ==
-*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
-+- *Range (0, 10, step=1, splits=2)
-
-
--- !query 4
select replace('abc', 'b', '123')
--- !query 4 schema
+-- !query 3 schema
struct
--- !query 4 output
+-- !query 3 output
a123c
--- !query 5
+-- !query 4
select replace('abc', 'b')
--- !query 5 schema
+-- !query 4 schema
struct
--- !query 5 output
+-- !query 4 output
ac
--- !query 6
+-- !query 5
select length(uuid()), (uuid() <> uuid())
--- !query 6 schema
+-- !query 5 schema
struct
--- !query 6 output
+-- !query 5 output
36 true
--- !query 7
+-- !query 6
select position('bar' in 'foobarbar'), position(null, 'foobarbar'), position('aaads', null)
--- !query 7 schema
+-- !query 6 schema
struct
--- !query 7 output
+-- !query 6 output
4 NULL NULL
--- !query 8
+-- !query 7
select left("abcd", 2), left("abcd", 5), left("abcd", '2'), left("abcd", null)
--- !query 8 schema
+-- !query 7 schema
struct
--- !query 8 output
+-- !query 7 output
ab abcd ab NULL
--- !query 9
+-- !query 8
select left(null, -2), left("abcd", -2), left("abcd", 0), left("abcd", 'a')
--- !query 9 schema
+-- !query 8 schema
struct
--- !query 9 output
+-- !query 8 output
NULL NULL
--- !query 10
+-- !query 9
select right("abcd", 2), right("abcd", 5), right("abcd", '2'), right("abcd", null)
--- !query 10 schema
+-- !query 9 schema
struct
--- !query 10 output
+-- !query 9 output
cd abcd cd NULL
--- !query 11
+-- !query 10
select right(null, -2), right("abcd", -2), right("abcd", 0), right("abcd", 'a')
--- !query 11 schema
+-- !query 10 schema
struct
--- !query 11 output
+-- !query 10 output
NULL NULL
--- !query 12
-set spark.sql.function.concatBinaryAsString=false
--- !query 12 schema
-struct
--- !query 12 output
-spark.sql.function.concatBinaryAsString false
-
-
--- !query 13
-EXPLAIN SELECT ((col1 || col2) || (col3 || col4)) col
-FROM (
- SELECT
- string(id) col1,
- string(id + 1) col2,
- encode(string(id + 2), 'utf-8') col3,
- encode(string(id + 3), 'utf-8') col4
- FROM range(10)
-)
--- !query 13 schema
-struct
--- !query 13 output
-== Physical Plan ==
-*Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
-+- *Range (0, 10, step=1, splits=2)
-
-
--- !query 14
-EXPLAIN SELECT (col1 || (col3 || col4)) col
-FROM (
- SELECT
- string(id) col1,
- encode(string(id + 2), 'utf-8') col3,
- encode(string(id + 3), 'utf-8') col4
- FROM range(10)
-)
--- !query 14 schema
-struct
--- !query 14 output
-== Physical Plan ==
-*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
-+- *Range (0, 10, step=1, splits=2)
-
-
--- !query 15
+-- !query 11
SELECT split('aa1cc2ee3', '[1-9]+')
--- !query 15 schema
+-- !query 11 schema
struct>
--- !query 15 output
+-- !query 11 output
["aa","cc","ee",""]
--- !query 16
+-- !query 12
SELECT split('aa1cc2ee3', '[1-9]+', 2)
--- !query 16 schema
+-- !query 12 schema
struct>
--- !query 16 output
+-- !query 12 output
["aa","cc2ee3"]
diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
index 94af9181225d6..fdbea0ee90720 100644
--- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
--- Number of queries: 10
+-- Number of queries: 8
-- !query 0
@@ -99,42 +99,3 @@ struct
-- !query 7 output
0
1
-
-
--- !query 8
-EXPLAIN select * from RaNgE(2)
--- !query 8 schema
-struct
--- !query 8 output
-== Physical Plan ==
-*Range (0, 2, step=1, splits=2)
-
-
--- !query 9
-EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3)
--- !query 9 schema
-struct
--- !query 9 output
-== Parsed Logical Plan ==
-'Project [*]
-+- 'Join Cross
- :- 'UnresolvedTableValuedFunction range, [3]
- +- 'UnresolvedTableValuedFunction range, [3]
-
-== Analyzed Logical Plan ==
-id: bigint, id: bigint
-Project [id#xL, id#xL]
-+- Join Cross
- :- Range (0, 3, step=1, splits=None)
- +- Range (0, 3, step=1, splits=None)
-
-== Optimized Logical Plan ==
-Join Cross
-:- Range (0, 3, step=1, splits=None)
-+- Range (0, 3, step=1, splits=None)
-
-== Physical Plan ==
-BroadcastNestedLoopJoin BuildRight, Cross
-:- *Range (0, 3, step=1, splits=2)
-+- BroadcastExchange IdentityBroadcastMode
- +- *Range (0, 3, step=1, splits=2)
diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
index 35740094ba53e..86a578ca013df 100644
--- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out
@@ -85,7 +85,7 @@ FROM various_maps
struct<>
-- !query 5 output
org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_map2`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,0), decimal(36,35)].; line 1 pos 7
-- !query 6
@@ -113,7 +113,7 @@ FROM various_maps
struct<>
-- !query 8 output
org.apache.spark.sql.AnalysisException
-cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
+cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), k, NamePlaceholder(), v1, NamePlaceholder(), v2), k, v1, v2))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7
-- !query 9
diff --git a/sql/core/src/test/resources/test-data/bad_after_good.csv b/sql/core/src/test/resources/test-data/bad_after_good.csv
new file mode 100644
index 0000000000000..4621a7d23714d
--- /dev/null
+++ b/sql/core/src/test/resources/test-data/bad_after_good.csv
@@ -0,0 +1,2 @@
+"good record",1999-08-01
+"bad record",1999-088-01
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index e6d1a038a5918..b7fc9570af919 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -2908,6 +2908,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
}
assert(ex.getMessage.contains("Cannot use null as map key"))
}
+
+ test("SPARK-26370: Fix resolution of higher-order function for the same identifier") {
+ val df = Seq(
+ (Seq(1, 9, 8, 7), 1, 2),
+ (Seq(5, 9, 7), 2, 2),
+ (Seq.empty, 3, 2),
+ (null, 4, 2)
+ ).toDF("i", "x", "d")
+
+ checkAnswer(df.selectExpr("x", "exists(i, x -> x % d == 0)"),
+ Seq(
+ Row(1, true),
+ Row(2, false),
+ Row(3, false),
+ Row(4, null)))
+ checkAnswer(df.filter("exists(i, x -> x % d == 0)"),
+ Seq(Row(Seq(1, 9, 8, 7), 1, 2)))
+ checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
+ Seq(Row(1)))
+ }
}
object DataFrameFunctionsSuite {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index e6b30f9956daf..c9f41ab1c0179 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -295,4 +295,16 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
df.join(df, df("id") <=> df("id")).queryExecution.optimizedPlan
}
}
+
+ test("NaN and -0.0 in join keys") {
+ val df1 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+ val df2 = Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d")
+ val joined = df1.join(df2, Seq("f", "d"))
+ checkAnswer(joined, Seq(
+ Row(Float.NaN, Double.NaN),
+ Row(0.0f, 0.0),
+ Row(0.0f, 0.0),
+ Row(0.0f, 0.0),
+ Row(0.0f, 0.0)))
+ }
}
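
The duplicated 0.0 rows above come from normalizing -0.0 to 0.0 (and treating NaN as equal to itself) in join keys. A plain-Scala sketch of the IEEE-754 corner cases that make this normalization necessary; it illustrates the underlying semantics and is not Spark code:

    // Sketch only: value comparison, bit patterns, and total ordering disagree
    // on 0.0 / -0.0 and NaN, so key columns need normalization before hashing.
    object FloatKeySketch {
      def main(args: Array[String]): Unit = {
        println(0.0 == -0.0)                                           // true  (value comparison)
        println(java.lang.Double.doubleToRawLongBits(0.0) ==
                java.lang.Double.doubleToRawLongBits(-0.0))            // false (distinct bit patterns)
        println(Double.NaN == Double.NaN)                              // false (value comparison)
        println(java.lang.Double.compare(Double.NaN, Double.NaN) == 0) // true  (total ordering)
      }
    }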
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index fc3faa08d55f4..b51c51e663503 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1904,7 +1904,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val e = intercept[SparkException] {
df.filter(filter).count()
}.getMessage
- assert(e.contains("grows beyond 64 KB"))
+ assert(e.contains("grows beyond 64 KiB"))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 78277d7dcf757..9277dc6859247 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql
import org.scalatest.Matchers.the
import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled}
+import org.apache.spark.sql.catalyst.optimizer.TransposeWindow
+import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -668,17 +670,43 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
("S2", "P2", 300)
).toDF("sno", "pno", "qty")
- val w1 = Window.partitionBy("sno")
- val w2 = Window.partitionBy("sno", "pno")
-
- checkAnswer(
- df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2"))
- .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1")),
- Seq(
- Row("S1", "P1", 100, 800, 800),
- Row("S1", "P1", 700, 800, 800),
- Row("S2", "P1", 200, 200, 500),
- Row("S2", "P2", 300, 300, 500)))
+ Seq(true, false).foreach { transposeWindowEnabled =>
+ val excludedRules = if (transposeWindowEnabled) "" else TransposeWindow.ruleName
+ withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> excludedRules) {
+ val w1 = Window.partitionBy("sno")
+ val w2 = Window.partitionBy("sno", "pno")
+
+ val select = df.select($"sno", $"pno", $"qty", sum($"qty").over(w2).alias("sum_qty_2"))
+ .select($"sno", $"pno", $"qty", col("sum_qty_2"), sum("qty").over(w1).alias("sum_qty_1"))
+
+ val expectedNumExchanges = if (transposeWindowEnabled) 1 else 2
+ val actualNumExchanges = select.queryExecution.executedPlan.collect {
+ case e: Exchange => e
+ }.length
+ assert(actualNumExchanges == expectedNumExchanges)
+
+ checkAnswer(
+ select,
+ Seq(
+ Row("S1", "P1", 100, 800, 800),
+ Row("S1", "P1", 700, 800, 800),
+ Row("S2", "P1", 200, 200, 500),
+ Row("S2", "P2", 300, 300, 500)))
+ }
+ }
+ }
+ test("NaN and -0.0 in window partition keys") {
+ val df = Seq(
+ (Float.NaN, Double.NaN, 1),
+ (0.0f/0.0f, 0.0/0.0, 1),
+ (0.0f, 0.0, 1),
+ (-0.0f, -0.0, 1)).toDF("f", "d", "i")
+ val result = df.select($"f", count("i").over(Window.partitionBy("f", "d")))
+ checkAnswer(result, Seq(
+ Row(Float.NaN, 2),
+ Row(Float.NaN, 2),
+ Row(0.0f, 2),
+ Row(0.0f, 2)))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0f900833d2cfe..c90b15814a534 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1647,6 +1647,26 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkDataset(ds, data: _*)
checkAnswer(ds.select("x"), Seq(Row(1), Row(2)))
}
+
+ test("SPARK-26233: serializer should enforce decimal precision and scale") {
+ val s = StructType(Seq(StructField("a", StringType), StructField("b", DecimalType(38, 8))))
+ val encoder = RowEncoder(s)
+ implicit val uEnc = encoder
+ val df = spark.range(2).map(l => Row(l.toString, BigDecimal.valueOf(l + 0.1111)))
+ checkAnswer(df.groupBy(col("a")).agg(first(col("b"))),
+ Seq(Row("0", BigDecimal.valueOf(0.1111)), Row("1", BigDecimal.valueOf(1.1111))))
+ }
+
+ test("SPARK-26366: return nulls which are not filtered in except") {
+ val inputDF = sqlContext.createDataFrame(
+ sparkContext.parallelize(Seq(Row("0", "a"), Row("1", null))),
+ StructType(Seq(
+ StructField("a", StringType, nullable = true),
+ StructField("b", StringType, nullable = true))))
+
+ val exceptDF = inputDF.filter(col("a").isin("0") or col("b") > "c")
+ checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
+ }
}
case class TestDataUnion(x: Int, y: Int, z: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
index 56d300e30a58e..ce475922eb5e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StructType
@@ -29,10 +30,11 @@ class ExplainSuite extends QueryTest with SharedSQLContext {
private def checkKeywordsExistsInExplain(df: DataFrame, keywords: String*): Unit = {
val output = new java.io.ByteArrayOutputStream()
Console.withOut(output) {
- df.explain(extended = false)
+ df.explain(extended = true)
}
+ val normalizedOutput = output.toString.replaceAll("#\\d+", "#x")
for (key <- keywords) {
- assert(output.toString.contains(key))
+ assert(normalizedOutput.contains(key))
}
}
@@ -53,6 +55,133 @@ class ExplainSuite extends QueryTest with SharedSQLContext {
checkKeywordsExistsInExplain(df,
keywords = "InMemoryRelation", "StorageLevel(disk, memory, deserialized, 1 replicas)")
}
+
+ test("optimized plan should show the rewritten aggregate expression") {
+ withTempView("test_agg") {
+ sql(
+ """
+ |CREATE TEMPORARY VIEW test_agg AS SELECT * FROM VALUES
+ | (1, true), (1, false),
+ | (2, true),
+ | (3, false), (3, null),
+ | (4, null), (4, null),
+ | (5, null), (5, true), (5, false) AS test_agg(k, v)
+ """.stripMargin)
+
+ // simple explain of queries having every/some/any aggregates. Optimized
+ // plan should show the rewritten aggregate expression.
+ val df = sql("SELECT k, every(v), some(v), any(v) FROM test_agg GROUP BY k")
+ checkKeywordsExistsInExplain(df,
+ "Aggregate [k#x], [k#x, min(v#x) AS every(v)#x, max(v#x) AS some(v)#x, " +
+ "max(v#x) AS any(v)#x]")
+ }
+ }
+
+ test("explain inline tables cross-joins") {
+ val df = sql(
+ """
+ |SELECT * FROM VALUES ('one', 1), ('three', null)
+ | CROSS JOIN VALUES ('one', 1), ('three', null)
+ """.stripMargin)
+ checkKeywordsExistsInExplain(df,
+ "Join Cross",
+ ":- LocalRelation [col1#x, col2#x]",
+ "+- LocalRelation [col1#x, col2#x]")
+ }
+
+ test("explain table valued functions") {
+ checkKeywordsExistsInExplain(sql("select * from RaNgE(2)"), "Range (0, 2, step=1, splits=None)")
+ checkKeywordsExistsInExplain(sql("SELECT * FROM range(3) CROSS JOIN range(3)"),
+ "Join Cross",
+ ":- Range (0, 3, step=1, splits=None)",
+ "+- Range (0, 3, step=1, splits=None)")
+ }
+
+ test("explain string functions") {
+ // Check if catalyst combine nested `Concat`s
+ val df1 = sql(
+ """
+ |SELECT (col1 || col2 || col3 || col4) col
+ | FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10))
+ """.stripMargin)
+ checkKeywordsExistsInExplain(df1,
+ "Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)" +
+ ", cast(id#xL as string)) AS col#x]")
+
+ // Check if catalyst combine nested `Concat`s if concatBinaryAsString=false
+ withSQLConf(SQLConf.CONCAT_BINARY_AS_STRING.key -> "false") {
+ val df2 = sql(
+ """
+ |SELECT ((col1 || col2) || (col3 || col4)) col
+ |FROM (
+ | SELECT
+ | string(id) col1,
+ | string(id + 1) col2,
+ | encode(string(id + 2), 'utf-8') col3,
+ | encode(string(id + 3), 'utf-8') col4
+ | FROM range(10)
+ |)
+ """.stripMargin)
+ checkKeywordsExistsInExplain(df2,
+ "Project [concat(cast(id#xL as string), cast((id#xL + 1) as string), " +
+ "cast(encode(cast((id#xL + 2) as string), utf-8) as string), " +
+ "cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]")
+
+ val df3 = sql(
+ """
+ |SELECT (col1 || (col3 || col4)) col
+ |FROM (
+ | SELECT
+ | string(id) col1,
+ | encode(string(id + 2), 'utf-8') col3,
+ | encode(string(id + 3), 'utf-8') col4
+ | FROM range(10)
+ |)
+ """.stripMargin)
+ checkKeywordsExistsInExplain(df3,
+ "Project [concat(cast(id#xL as string), " +
+ "cast(encode(cast((id#xL + 2) as string), utf-8) as string), " +
+ "cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]")
+ }
+ }
+
+ test("check operator precedence") {
+ // We follow Oracle operator precedence in the table below that lists the levels
+ // of precedence among SQL operators from high to low:
+ // ---------------------------------------------------------------------------------------
+ // Operator Operation
+ // ---------------------------------------------------------------------------------------
+ // +, - identity, negation
+ // *, / multiplication, division
+ // +, -, || addition, subtraction, concatenation
+ // =, !=, <, >, <=, >=, IS NULL, LIKE, BETWEEN, IN comparison
+ // NOT exponentiation, logical negation
+ // AND conjunction
+ // OR disjunction
+ // ---------------------------------------------------------------------------------------
+ checkKeywordsExistsInExplain(sql("select 'a' || 1 + 2"),
+ "Project [null AS (CAST(concat(a, CAST(1 AS STRING)) AS DOUBLE) + CAST(2 AS DOUBLE))#x]")
+ checkKeywordsExistsInExplain(sql("select 1 - 2 || 'b'"),
+ "Project [-1b AS concat(CAST((1 - 2) AS STRING), b)#x]")
+ checkKeywordsExistsInExplain(sql("select 2 * 4 + 3 || 'b'"),
+ "Project [11b AS concat(CAST(((2 * 4) + 3) AS STRING), b)#x]")
+ checkKeywordsExistsInExplain(sql("select 3 + 1 || 'a' || 4 / 2"),
+ "Project [4a2.0 AS concat(concat(CAST((3 + 1) AS STRING), a), " +
+ "CAST((CAST(4 AS DOUBLE) / CAST(2 AS DOUBLE)) AS STRING))#x]")
+ checkKeywordsExistsInExplain(sql("select 1 == 1 OR 'a' || 'b' == 'ab'"),
+ "Project [true AS ((1 = 1) OR (concat(a, b) = ab))#x]")
+ checkKeywordsExistsInExplain(sql("select 'a' || 'c' == 'ac' AND 2 == 3"),
+ "Project [false AS ((concat(a, c) = ac) AND (2 = 3))#x]")
+ }
+
+ test("explain for these functions; use range to avoid constant folding") {
+ val df = sql("select ifnull(id, 'x'), nullif(id, 'x'), nvl(id, 'x'), nvl2(id, 'x', 'y') " +
+ "from range(2)")
+ checkKeywordsExistsInExplain(df,
+ "Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, " +
+ "id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, " +
+ "x AS nvl2(`id`, 'x', 'y')#x]")
+ }
}
case class ExplainSingleData(id: Int)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index aa2162c9d2cda..91445c8d96d85 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -895,4 +895,18 @@ class JoinSuite extends QueryTest with SharedSQLContext {
checkAnswer(res, Row(0, 0, 0))
}
}
+
+ test("SPARK-26352: join reordering should not change the order of columns") {
+ withTable("tab1", "tab2", "tab3") {
+ spark.sql("select 1 as x, 100 as y").write.saveAsTable("tab1")
+ spark.sql("select 42 as i, 200 as j").write.saveAsTable("tab2")
+ spark.sql("select 1 as a, 42 as b").write.saveAsTable("tab3")
+
+ val df = spark.sql("""
+ with tmp as (select * from tab1 cross join tab2)
+ select * from tmp join tab3 on a = x and b = i
+ """)
+ checkAnswer(df, Row(1, 100, 42, 200, 1, 42))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4cc8a45391996..37a8815350a53 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2899,6 +2899,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
}
+
+ test("SPARK-26366: verify ReplaceExceptWithFilter") {
+ Seq(true, false).foreach { enabled =>
+ withSQLConf(SQLConf.REPLACE_EXCEPT_WITH_FILTER.key -> enabled.toString) {
+ val df = spark.createDataFrame(
+ sparkContext.parallelize(Seq(Row(0, 3, 5),
+ Row(0, 3, null),
+ Row(null, 3, 5),
+ Row(0, null, 5),
+ Row(0, null, null),
+ Row(null, null, 5),
+ Row(null, 3, null),
+ Row(null, null, null))),
+ StructType(Seq(StructField("c1", IntegerType),
+ StructField("c2", IntegerType),
+ StructField("c3", IntegerType))))
+ val where = "c2 >= 3 OR c1 >= 0"
+ val whereNullSafe =
+ """
+ |(c2 IS NOT NULL AND c2 >= 3)
+ |OR (c1 IS NOT NULL AND c1 >= 0)
+ """.stripMargin
+
+ val df_a = df.filter(where)
+ val df_b = df.filter(whereNullSafe)
+ checkAnswer(df.except(df_a), df.except(df_b))
+
+ val whereWithIn = "c2 >= 3 OR c1 in (2)"
+ val whereWithInNullSafe =
+ """
+ |(c2 IS NOT NULL AND c2 >= 3)
+ """.stripMargin
+ val dfIn_a = df.filter(whereWithIn)
+ val dfIn_b = df.filter(whereWithInNullSafe)
+ checkAnswer(df.except(dfIn_a), df.except(dfIn_b))
+ }
+ }
+ }
}
case class Foo(bar: Option[String])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
index cf4585bf7ac6c..b2515226d9a14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -137,28 +137,39 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
}
}
+ // For better test coverage, runs the tests on mixed config sets: WHOLESTAGE_CODEGEN_ENABLED
+ // and CODEGEN_FACTORY_MODE.
+ private lazy val codegenConfigSets = Array(
+ ("true", "CODEGEN_ONLY"),
+ ("false", "CODEGEN_ONLY"),
+ ("false", "NO_CODEGEN")
+ ).map { case (wholeStageCodegenEnabled, codegenFactoryMode) =>
+ Array(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> wholeStageCodegenEnabled,
+ SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode)
+ }
+
/** Run a test case. */
private def runTest(testCase: TestCase): Unit = {
val input = fileToString(new File(testCase.inputFile))
val (comments, code) = input.split("\n").partition(_.startsWith("--"))
- // Runs all the tests on both codegen-only and interpreter modes
- val codegenConfigSets = Array(CODEGEN_ONLY, NO_CODEGEN).map {
- case codegenFactoryMode =>
- Array(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenFactoryMode.toString)
- }
- val configSets = {
- val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5))
- val configs = configLines.map(_.split(",").map { confAndValue =>
- val (conf, value) = confAndValue.span(_ != '=')
- conf.trim -> value.substring(1).trim
- })
- // When we are regenerating the golden files, we don't need to set any config as they
- // all need to return the same result
- if (regenerateGoldenFiles) {
- Array.empty[Array[(String, String)]]
- } else {
+ // List of SQL queries to run
+ // note: this is not a robust way to split queries using semicolon, but works for now.
+ val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq
+
+ // When we are regenerating the golden files, we don't need to set any config as they
+ // all need to return the same result
+ if (regenerateGoldenFiles) {
+ runQueries(queries, testCase.resultFile, None)
+ } else {
+ val configSets = {
+ val configLines = comments.filter(_.startsWith("--SET")).map(_.substring(5))
+ val configs = configLines.map(_.split(",").map { confAndValue =>
+ val (conf, value) = confAndValue.span(_ != '=')
+ conf.trim -> value.substring(1).trim
+ })
+
if (configs.nonEmpty) {
codegenConfigSets.flatMap { codegenConfig =>
configs.map { config =>
@@ -169,15 +180,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext {
codegenConfigSets
}
}
- }
- // List of SQL queries to run
- // note: this is not a robust way to split queries using semicolon, but works for now.
- val queries = code.mkString("\n").split("(?<=[^\\\\]);").map(_.trim).filter(_ != "").toSeq
-
- if (configSets.isEmpty) {
- runQueries(queries, testCase.resultFile, None)
- } else {
configSets.foreach { configSet =>
try {
runQueries(queries, testCase.resultFile, Some(configSet))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index cd6b2647e0be6..1a1c956aed3d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -27,4 +27,9 @@ class SerializationSuite extends SparkFunSuite with SharedSQLContext {
val spark = SparkSession.builder.getOrCreate()
new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sqlContext)
}
+
+ test("[SPARK-26409] SQLConf should be serializable") {
+ val spark = SparkSession.builder.getOrCreate()
+ new JavaSerializer(new SparkConf()).newInstance().serialize(spark.sessionState.conf)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index cb562d65b6147..02dc32d5f90ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -227,12 +227,12 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared
BigInt(0) -> (("0.0 B", "0")),
BigInt(100) -> (("100.0 B", "100")),
BigInt(2047) -> (("2047.0 B", "2.05E+3")),
- BigInt(2048) -> (("2.0 KB", "2.05E+3")),
- BigInt(3333333) -> (("3.2 MB", "3.33E+6")),
- BigInt(4444444444L) -> (("4.1 GB", "4.44E+9")),
- BigInt(5555555555555L) -> (("5.1 TB", "5.56E+12")),
- BigInt(6666666666666666L) -> (("5.9 PB", "6.67E+15")),
- BigInt(1L << 10 ) * (1L << 60) -> (("1024.0 EB", "1.18E+21")),
+ BigInt(2048) -> (("2.0 KiB", "2.05E+3")),
+ BigInt(3333333) -> (("3.2 MiB", "3.33E+6")),
+ BigInt(4444444444L) -> (("4.1 GiB", "4.44E+9")),
+ BigInt(5555555555555L) -> (("5.1 TiB", "5.56E+12")),
+ BigInt(6666666666666666L) -> (("5.9 PiB", "6.67E+15")),
+ BigInt(1L << 10 ) * (1L << 60) -> (("1024.0 EiB", "1.18E+21")),
BigInt(1L << 11) * (1L << 60) -> (("2.36E+21 B", "2.36E+21"))
)
numbers.foreach { case (input, (expectedSize, expectedRows)) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 5088821ad7361..c95c52f1d3a9c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -22,6 +22,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Sort}
import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
class SubquerySuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -1280,4 +1281,40 @@ class SubquerySuite extends QueryTest with SharedSQLContext {
assert(subqueries.length == 1)
}
}
+
+ test("SPARK-26078: deduplicate fake self joins for IN subqueries") {
+ withTempView("a", "b") {
+ Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("a")
+ Seq("a" -> 2, "b" -> 1).toDF("id", "num").createTempView("b")
+
+ val df1 = spark.sql(
+ """
+ |SELECT id,num,source FROM (
+ | SELECT id, num, 'a' as source FROM a
+ | UNION ALL
+ | SELECT id, num, 'b' as source FROM b
+ |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2)
+ """.stripMargin)
+ checkAnswer(df1, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
+ val df2 = spark.sql(
+ """
+ |SELECT id,num,source FROM (
+ | SELECT id, num, 'a' as source FROM a
+ | UNION ALL
+ | SELECT id, num, 'b' as source FROM b
+ |) AS c WHERE c.id NOT IN (SELECT id FROM b WHERE num = 2)
+ """.stripMargin)
+ checkAnswer(df2, Seq(Row("b", 1, "a"), Row("b", 1, "b")))
+ val df3 = spark.sql(
+ """
+ |SELECT id,num,source FROM (
+ | SELECT id, num, 'a' as source FROM a
+ | UNION ALL
+ | SELECT id, num, 'b' as source FROM b
+ |) AS c WHERE c.id IN (SELECT id FROM b WHERE num = 2) OR
+ |c.id IN (SELECT id FROM b WHERE num = 3)
+ """.stripMargin)
+ checkAnswer(df3, Seq(Row("a", 2, "a"), Row("a", 2, "b")))
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 20dcefa7e3cad..a26d306cff6b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql
+import java.math.BigDecimal
+
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.execution.QueryExecution
@@ -26,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.InsertIntoHadoopFsRelationComm
import org.apache.spark.sql.functions.{lit, udf}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.test.SQLTestData._
-import org.apache.spark.sql.types.{DataTypes, DoubleType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.util.QueryExecutionListener
@@ -420,4 +422,32 @@ class UDFSuite extends QueryTest with SharedSQLContext {
checkAnswer(df, Seq(Row("null1x"), Row(null), Row("N3null")))
}
}
+
+ test("SPARK-26308: udf with decimal") {
+ val df1 = spark.createDataFrame(
+ sparkContext.parallelize(Seq(Row(new BigDecimal("2011000000000002456556")))),
+ StructType(Seq(StructField("col1", DecimalType(30, 0)))))
+ val udf1 = org.apache.spark.sql.functions.udf((value: BigDecimal) => {
+ if (value == null) null else value.toBigInteger.toString
+ })
+ checkAnswer(df1.select(udf1(df1.col("col1"))), Seq(Row("2011000000000002456556")))
+ }
+
+ test("SPARK-26308: udf with complex types of decimal") {
+ val df1 = spark.createDataFrame(
+ sparkContext.parallelize(Seq(Row(Array(new BigDecimal("2011000000000002456556"))))),
+ StructType(Seq(StructField("col1", ArrayType(DecimalType(30, 0))))))
+ val udf1 = org.apache.spark.sql.functions.udf((arr: Seq[BigDecimal]) => {
+ arr.map(value => if (value == null) null else value.toBigInteger.toString)
+ })
+ checkAnswer(df1.select(udf1($"col1")), Seq(Row(Array("2011000000000002456556"))))
+
+ val df2 = spark.createDataFrame(
+ sparkContext.parallelize(Seq(Row(Map("a" -> new BigDecimal("2011000000000002456556"))))),
+ StructType(Seq(StructField("col1", MapType(StringType, DecimalType(30, 0))))))
+ val udf2 = org.apache.spark.sql.functions.udf((map: Map[String, BigDecimal]) => {
+ map.mapValues(value => if (value == null) null else value.toBigInteger.toString)
+ })
+ checkAnswer(df2.select(udf2($"col1")), Seq(Row(Map("a" -> "2011000000000002456556"))))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 6ad025f37e440..4a439940beb74 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -263,7 +263,6 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
.setMaster("local[*]")
.setAppName("test")
.set("spark.ui.enabled", "false")
- .set("spark.driver.allowMultipleContexts", "true")
.set(SQLConf.SHUFFLE_PARTITIONS.key, "5")
.set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true")
.set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
index 1ad5713ab8ae6..ca8692290edb2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{LocalSparkSession, Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
-import org.apache.spark.sql.execution.metric.SQLShuffleMetricsReporter
+import org.apache.spark.sql.execution.metric.SQLShuffleReadMetricsReporter
import org.apache.spark.sql.types._
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter
@@ -140,7 +140,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkSession {
new UnsafeRowSerializer(2))
val shuffled = new ShuffledRowRDD(
dependency,
- SQLShuffleMetricsReporter.createShuffleReadMetrics(spark.sparkContext))
+ SQLShuffleReadMetricsReporter.createShuffleReadMetrics(spark.sparkContext))
shuffled.count()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala
new file mode 100644
index 0000000000000..bdf753debe62a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/HashedRelationMetricsBenchmark.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.benchmark
+
+import org.apache.spark.SparkConf
+import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.internal.config.MEMORY_OFFHEAP_ENABLED
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
+import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap
+import org.apache.spark.sql.types.LongType
+
+/**
+ * Benchmark to measure metrics performance at HashedRelation.
+ * To run this benchmark:
+ * {{{
+ *   1. without sbt: bin/spark-submit --class <this class> <spark sql test jar>
+ *   2. build/sbt "sql/test:runMain <this class>"
+ *   3. generate result: SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain <this class>"
+ * Results will be written to "benchmarks/HashedRelationMetricsBenchmark-results.txt".
+ * }}}
+ */
+object HashedRelationMetricsBenchmark extends SqlBasedBenchmark {
+
+ def benchmarkLongToUnsafeRowMapMetrics(numRows: Int): Unit = {
+ runBenchmark("LongToUnsafeRowMap metrics") {
+ val benchmark = new Benchmark("LongToUnsafeRowMap metrics", numRows, output = output)
+ benchmark.addCase("LongToUnsafeRowMap") { iter =>
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
+
+ val keys = Range.Long(0, numRows, 1)
+ val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+ keys.foreach { k =>
+ map.append(k, unsafeProj(InternalRow(k)))
+ }
+ map.optimize()
+
+ val threads = (0 to 100).map { _ =>
+ val thread = new Thread {
+ override def run: Unit = {
+ val row = unsafeProj(InternalRow(0L)).copy()
+ keys.foreach { k =>
+ assert(map.getValue(k, row) eq row)
+ assert(row.getLong(0) == k)
+ }
+ }
+ }
+ thread.start()
+ thread
+ }
+ threads.map(_.join())
+ map.free()
+ }
+ benchmark.run()
+ }
+ }
+
+ override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
+ benchmarkLongToUnsafeRowMapMetrics(500000)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
index ec552f7ddf47a..6bd0a2591fc1f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator}
class FileIndexSuite extends SharedSQLContext {
@@ -95,6 +95,31 @@ class FileIndexSuite extends SharedSQLContext {
}
}
+ test("SPARK-26263: Throw exception when partition value can't be casted to user-specified type") {
+ withTempDir { dir =>
+ val partitionDirectory = new File(dir, "a=foo")
+ partitionDirectory.mkdir()
+ val file = new File(partitionDirectory, "text.txt")
+ stringToFile(file, "text")
+ val path = new Path(dir.getCanonicalPath)
+ val schema = StructType(Seq(StructField("a", IntegerType, false)))
+ withSQLConf(SQLConf.VALIDATE_PARTITION_COLUMNS.key -> "true") {
+ val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema))
+ val msg = intercept[RuntimeException] {
+ fileIndex.partitionSpec()
+ }.getMessage
+ assert(msg == "Failed to cast value `foo` to `IntegerType` for partition column `a`")
+ }
+
+ withSQLConf(SQLConf.VALIDATE_PARTITION_COLUMNS.key -> "false") {
+ val fileIndex = new InMemoryFileIndex(spark, Seq(path), Map.empty, Some(schema))
+ val partitionValues = fileIndex.partitionSpec().partitions.map(_.values)
+ assert(partitionValues.length == 1 && partitionValues(0).numFields == 1 &&
+ partitionValues(0).isNullAt(0))
+ }
+ }
+ }
+
test("InMemoryFileIndex: input paths are converted to qualified paths") {
withTempDir { dir =>
val file = new File(dir, "text.txt")
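
A hedged sketch of the user-facing switch the SPARK-26263 test exercises; it assumes the documented key for SQLConf.VALIDATE_PARTITION_COLUMNS is spark.sql.sources.validatePartitionColumns and uses an illustrative path:

// Layout: /tmp/part_demo/a=foo/data.csv, read with a user-specified schema where
// partition column `a` is an INT, so the directory value `foo` cannot be cast.
spark.conf.set("spark.sql.sources.validatePartitionColumns", "true")
// spark.read.schema("value STRING, a INT").csv("/tmp/part_demo").show()  // fails fast on `foo`

spark.conf.set("spark.sql.sources.validatePartitionColumns", "false")
spark.read.schema("value STRING, a INT").csv("/tmp/part_demo").show()     // `a` is read back as null
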
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index 3b977d74053e6..d9e5d7af19671 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -63,6 +63,7 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
private val datesFile = "test-data/dates.csv"
private val unescapedQuotesFile = "test-data/unescaped-quotes.csv"
private val valueMalformedFile = "test-data/value-malformed.csv"
+ private val badAfterGoodFile = "test-data/bad_after_good.csv"
/** Verifies data and schema. */
private def verifyCars(
@@ -2012,4 +2013,22 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
assert(!files.exists(_.getName.endsWith("csv")))
}
}
+
+ test("Do not reuse last good value for bad input field") {
+ val schema = StructType(
+ StructField("col1", StringType) ::
+ StructField("col2", DateType) ::
+ Nil
+ )
+ val rows = spark.read
+ .schema(schema)
+ .format("csv")
+ .load(testFile(badAfterGoodFile))
+
+ val expectedRows = Seq(
+ Row("good record", java.sql.Date.valueOf("1999-08-01")),
+ Row("bad record", null))
+
+ checkAnswer(rows, expectedRows)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index dff37ca2d40f0..8f575a371c98e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -38,6 +38,7 @@ import org.apache.spark.sql.execution.datasources.DataSource
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
+import org.apache.spark.sql.types.StructType.fromDDL
import org.apache.spark.util.Utils
class TestFileFilter extends PathFilter {
@@ -57,14 +58,17 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
val factory = new JsonFactory()
- def enforceCorrectType(value: Any, dataType: DataType): Any = {
+ def enforceCorrectType(
+ value: Any,
+ dataType: DataType,
+ options: Map[String, String] = Map.empty): Any = {
val writer = new StringWriter()
Utils.tryWithResource(factory.createGenerator(writer)) { generator =>
generator.writeObject(value)
generator.flush()
}
- val dummyOption = new JSONOptions(Map.empty[String, String], "GMT")
+ val dummyOption = new JSONOptions(options, SQLConf.get.sessionLocalTimeZone)
val dummySchema = StructType(Seq.empty)
val parser = new JacksonParser(dummySchema, dummyOption, allowArrayAsStructs = true)
@@ -96,19 +100,27 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(intNumber.toLong * 1000L)),
enforceCorrectType(intNumber.toLong, TimestampType))
val strTime = "2014-09-30 12:34:56"
- checkTypePromotion(DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)),
- enforceCorrectType(strTime, TimestampType))
+ checkTypePromotion(
+ expected = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf(strTime)),
+ enforceCorrectType(strTime, TimestampType,
+ Map("timestampFormat" -> "yyyy-MM-dd HH:mm:ss")))
val strDate = "2014-10-15"
checkTypePromotion(
DateTimeUtils.fromJavaDate(Date.valueOf(strDate)), enforceCorrectType(strDate, DateType))
val ISO8601Time1 = "1970-01-01T01:00:01.0Z"
- val ISO8601Time2 = "1970-01-01T02:00:01-01:00"
checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(3601000)),
- enforceCorrectType(ISO8601Time1, TimestampType))
+ enforceCorrectType(
+ ISO8601Time1,
+ TimestampType,
+ Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss.SX")))
+ val ISO8601Time2 = "1970-01-01T02:00:01-01:00"
checkTypePromotion(DateTimeUtils.fromJavaTimestamp(new Timestamp(10801000)),
- enforceCorrectType(ISO8601Time2, TimestampType))
+ enforceCorrectType(
+ ISO8601Time2,
+ TimestampType,
+ Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ssXXX")))
val ISO8601Date = "1970-01-01"
checkTypePromotion(DateTimeUtils.millisToDays(32400000),
@@ -248,7 +260,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
checkAnswer(
sql("select nullstr, headers.Host from jsonTable"),
- Seq(Row("", "1.abc.com"), Row("", null), Row(null, null), Row(null, null))
+ Seq(Row("", "1.abc.com"), Row("", null), Row("", null), Row(null, null))
)
}
@@ -1440,103 +1452,105 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
}
test("backward compatibility") {
- // This test we make sure our JSON support can read JSON data generated by previous version
- // of Spark generated through toJSON method and JSON data source.
- // The data is generated by the following program.
- // Here are a few notes:
- // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13)
- // in the JSON object.
- // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to
- // JSON objects generated by those Spark versions (col17).
- // - If the type is NullType, we do not write data out.
-
- // Create the schema.
- val struct =
- StructType(
- StructField("f1", FloatType, true) ::
- StructField("f2", ArrayType(BooleanType), true) :: Nil)
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "true") {
+ // This test makes sure our JSON support can read JSON data generated by previous versions
+ // of Spark through the toJSON method and the JSON data source.
+ // The data is generated by the following program.
+ // Here are a few notes:
+ // - Spark 1.5.0 cannot save timestamp data. So, we manually added timestamp field (col13)
+ // in the JSON object.
+ // - For Spark before 1.5.1, we do not generate UDTs. So, we manually added the UDT value to
+ // JSON objects generated by those Spark versions (col17).
+ // - If the type is NullType, we do not write data out.
+
+ // Create the schema.
+ val struct =
+ StructType(
+ StructField("f1", FloatType, true) ::
+ StructField("f2", ArrayType(BooleanType), true) :: Nil)
- val dataTypes =
- Seq(
- StringType, BinaryType, NullType, BooleanType,
- ByteType, ShortType, IntegerType, LongType,
- FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
- DateType, TimestampType,
- ArrayType(IntegerType), MapType(StringType, LongType), struct,
- new TestUDT.MyDenseVectorUDT())
- val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
- StructField(s"col$index", dataType, nullable = true)
- }
- val schema = StructType(fields)
+ val dataTypes =
+ Seq(
+ StringType, BinaryType, NullType, BooleanType,
+ ByteType, ShortType, IntegerType, LongType,
+ FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
+ DateType, TimestampType,
+ ArrayType(IntegerType), MapType(StringType, LongType), struct,
+ new TestUDT.MyDenseVectorUDT())
+ val fields = dataTypes.zipWithIndex.map { case (dataType, index) =>
+ StructField(s"col$index", dataType, nullable = true)
+ }
+ val schema = StructType(fields)
- val constantValues =
- Seq(
- "a string in binary".getBytes(StandardCharsets.UTF_8),
- null,
- true,
- 1.toByte,
- 2.toShort,
- 3,
- Long.MaxValue,
- 0.25.toFloat,
- 0.75,
- new java.math.BigDecimal(s"1234.23456"),
- new java.math.BigDecimal(s"1.23456"),
- java.sql.Date.valueOf("2015-01-01"),
- java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"),
- Seq(2, 3, 4),
- Map("a string" -> 2000L),
- Row(4.75.toFloat, Seq(false, true)),
- new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))
- val data =
- Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil
+ val constantValues =
+ Seq(
+ "a string in binary".getBytes(StandardCharsets.UTF_8),
+ null,
+ true,
+ 1.toByte,
+ 2.toShort,
+ 3,
+ Long.MaxValue,
+ 0.25.toFloat,
+ 0.75,
+ new java.math.BigDecimal(s"1234.23456"),
+ new java.math.BigDecimal(s"1.23456"),
+ java.sql.Date.valueOf("2015-01-01"),
+ java.sql.Timestamp.valueOf("2015-01-01 23:50:59.123"),
+ Seq(2, 3, 4),
+ Map("a string" -> 2000L),
+ Row(4.75.toFloat, Seq(false, true)),
+ new TestUDT.MyDenseVector(Array(0.25, 2.25, 4.25)))
+ val data =
+ Row.fromSeq(Seq("Spark " + spark.sparkContext.version) ++ constantValues) :: Nil
- // Data generated by previous versions.
- // scalastyle:off
- val existingJSONData =
+ // Data generated by previous versions.
+ // scalastyle:off
+ val existingJSONData =
"""{"col0":"Spark 1.2.2","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
- """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
- """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
- """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
- """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
- """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
- """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil
- // scalastyle:on
-
- // Generate data for the current version.
- val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema)
- withTempPath { path =>
- df.write.format("json").mode("overwrite").save(path.getCanonicalPath)
+ """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
+ """{"col0":"Spark 1.3.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
+ """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
+ """{"col0":"Spark 1.4.1","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
+ """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"2015-01-01","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" ::
+ """{"col0":"Spark 1.5.0","col1":"YSBzdHJpbmcgaW4gYmluYXJ5","col3":true,"col4":1,"col5":2,"col6":3,"col7":9223372036854775807,"col8":0.25,"col9":0.75,"col10":1234.23456,"col11":1.23456,"col12":"16436","col13":"2015-01-01 23:50:59.123","col14":[2,3,4],"col15":{"a string":2000},"col16":{"f1":4.75,"f2":[false,true]},"col17":[0.25,2.25,4.25]}""" :: Nil
+ // scalastyle:on
+
+ // Generate data for the current version.
+ val df = spark.createDataFrame(spark.sparkContext.parallelize(data, 1), schema)
+ withTempPath { path =>
+ df.write.format("json").mode("overwrite").save(path.getCanonicalPath)
- // df.toJSON will convert internal rows to external rows first and then generate
- // JSON objects. While, df.write.format("json") will write internal rows directly.
- val allJSON =
+ // df.toJSON will convert internal rows to external rows first and then generate
+ // JSON objects, while df.write.format("json") will write internal rows directly.
+ val allJSON =
existingJSONData ++
df.toJSON.collect() ++
sparkContext.textFile(path.getCanonicalPath).collect()
- Utils.deleteRecursively(path)
- sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath)
-
- // Read data back with the schema specified.
- val col0Values =
- Seq(
- "Spark 1.2.2",
- "Spark 1.3.1",
- "Spark 1.3.1",
- "Spark 1.4.1",
- "Spark 1.4.1",
- "Spark 1.5.0",
- "Spark 1.5.0",
- "Spark " + spark.sparkContext.version,
- "Spark " + spark.sparkContext.version)
- val expectedResult = col0Values.map { v =>
- Row.fromSeq(Seq(v) ++ constantValues)
+ Utils.deleteRecursively(path)
+ sparkContext.parallelize(allJSON, 1).saveAsTextFile(path.getCanonicalPath)
+
+ // Read data back with the schema specified.
+ val col0Values =
+ Seq(
+ "Spark 1.2.2",
+ "Spark 1.3.1",
+ "Spark 1.3.1",
+ "Spark 1.4.1",
+ "Spark 1.4.1",
+ "Spark 1.5.0",
+ "Spark 1.5.0",
+ "Spark " + spark.sparkContext.version,
+ "Spark " + spark.sparkContext.version)
+ val expectedResult = col0Values.map { v =>
+ Row.fromSeq(Seq(v) ++ constantValues)
+ }
+ checkAnswer(
+ spark.read.format("json").schema(schema).load(path.getCanonicalPath),
+ expectedResult
+ )
}
- checkAnswer(
- spark.read.format("json").schema(schema).load(path.getCanonicalPath),
- expectedResult
- )
}
}
@@ -2563,4 +2577,68 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
assert(!files.exists(_.getName.endsWith("json")))
}
}
+
+ test("return partial result for bad records") {
+ val schema = "a double, b array, c string, _corrupt_record string"
+ val badRecords = Seq(
+ """{"a":"-","b":[0, 1, 2],"c":"abc"}""",
+ """{"a":0.1,"b":{},"c":"def"}""").toDS()
+ val df = spark.read.schema(schema).json(badRecords)
+
+ checkAnswer(
+ df,
+ Row(null, Array(0, 1, 2), "abc", """{"a":"-","b":[0, 1, 2],"c":"abc"}""") ::
+ Row(0.1, null, "def", """{"a":0.1,"b":{},"c":"def"}""") :: Nil)
+ }
+
+ test("inferring timestamp type") {
+ Seq(true, false).foreach { legacyParser =>
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) {
+ def schemaOf(jsons: String*): StructType = spark.read.json(jsons.toDS).schema
+
+ assert(schemaOf(
+ """{"a":"2018-12-17T10:11:12.123-01:00"}""",
+ """{"a":"2018-12-16T22:23:24.123-02:00"}""") === fromDDL("a timestamp"))
+
+ assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":1}""")
+ === fromDDL("a string"))
+ assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":"123"}""")
+ === fromDDL("a string"))
+
+ assert(schemaOf("""{"a":"2018-12-17T10:11:12.123-01:00"}""", """{"a":null}""")
+ === fromDDL("a timestamp"))
+ assert(schemaOf("""{"a":null}""", """{"a":"2018-12-17T10:11:12.123-01:00"}""")
+ === fromDDL("a timestamp"))
+ }
+ }
+ }
+
+ test("roundtrip for timestamp type inferring") {
+ Seq(true, false).foreach { legacyParser =>
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> legacyParser.toString) {
+ val customSchema = new StructType().add("date", TimestampType)
+ withTempDir { dir =>
+ val timestampsWithFormatPath = s"${dir.getCanonicalPath}/timestampsWithFormat.json"
+ val timestampsWithFormat = spark.read
+ .option("timestampFormat", "dd/MM/yyyy HH:mm")
+ .json(datesRecords)
+ assert(timestampsWithFormat.schema === customSchema)
+
+ timestampsWithFormat.write
+ .format("json")
+ .option("timestampFormat", "yyyy-MM-dd HH:mm:ss")
+ .option(DateTimeUtils.TIMEZONE_OPTION, "UTC")
+ .save(timestampsWithFormatPath)
+
+ val readBack = spark.read
+ .option("timestampFormat", "yyyy-MM-dd HH:mm:ss")
+ .option(DateTimeUtils.TIMEZONE_OPTION, "UTC")
+ .json(timestampsWithFormatPath)
+
+ assert(readBack.schema === customSchema)
+ checkAnswer(readBack, timestampsWithFormat)
+ }
+ }
+ }
+ }
}
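
A short, hedged sketch of the reader-side consequence of the parser changes above: timestamp strings that the legacy parser accepted implicitly now want an explicit timestampFormat (or SQLConf.LEGACY_TIME_PARSER_ENABLED set to true). It assumes an active SparkSession `spark` with spark.implicits._ imported:

import org.apache.spark.sql.types._

val schema = new StructType().add("t", TimestampType)
val jsonLines = Seq("""{"t":"2018-12-17T10:11:12.123-01:00"}""").toDS()

val parsed = spark.read
  .schema(schema)
  .option("timestampFormat", "yyyy-MM-dd'T'HH:mm:ss.SSSXXX")
  .json(jsonLines)
parsed.show(false)
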
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index f808ca458aaa7..88067358667c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -101,7 +101,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
"hdfs://host:9000/path/a=10.5/b=hello")
var exception = intercept[AssertionError] {
- parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, timeZoneId)
+ parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], None, true, true, timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
@@ -117,6 +117,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/path/")),
None,
true,
+ true,
timeZoneId)
// Valid
@@ -132,6 +133,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/path/something=true/table")),
None,
true,
+ true,
timeZoneId)
// Valid
@@ -147,6 +149,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/path/table=true")),
None,
true,
+ true,
timeZoneId)
// Invalid
@@ -162,6 +165,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/path/")),
None,
true,
+ true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
@@ -184,6 +188,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
Set(new Path("hdfs://host:9000/tmp/tables/")),
None,
true,
+ true,
timeZoneId)
}
assert(exception.getMessage().contains("Conflicting directory structures detected"))
@@ -191,13 +196,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
test("parse partition") {
def check(path: String, expected: Option[PartitionValues]): Unit = {
- val actual = parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)._1
+ val actual = parsePartition(new Path(path), true, Set.empty[Path],
+ Map.empty, true, timeZone)._1
assert(expected === actual)
}
def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = {
val message = intercept[T] {
- parsePartition(new Path(path), true, Set.empty[Path], Map.empty, timeZone)
+ parsePartition(new Path(path), true, Set.empty[Path], Map.empty, true, timeZone)
}.getMessage
assert(message.contains(expected))
@@ -242,6 +248,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
typeInference = true,
basePaths = Set(new Path("file://path/a=10")),
Map.empty,
+ true,
timeZone = timeZone)._1
assert(partitionSpec1.isEmpty)
@@ -252,6 +259,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
typeInference = true,
basePaths = Set(new Path("file://path")),
Map.empty,
+ true,
timeZone = timeZone)._1
assert(partitionSpec2 ==
@@ -272,6 +280,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
rootPaths,
None,
true,
+ true,
timeZoneId)
assert(actualSpec.partitionColumns === spec.partitionColumns)
assert(actualSpec.partitions.length === spec.partitions.length)
@@ -384,7 +393,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha
test("parse partitions with type inference disabled") {
def check(paths: Seq[String], spec: PartitionSpec): Unit = {
val actualSpec =
- parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None, true, timeZoneId)
+ parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], None,
+ true, true, timeZoneId)
assert(actualSpec === spec)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 0f1d08b6af5d5..47265df4831df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -96,8 +96,9 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))
val shuffleExpected1 = Map(
"records read" -> 2L,
- "local blocks fetched" -> 2L,
- "remote blocks fetched" -> 0L)
+ "local blocks read" -> 2L,
+ "remote blocks read" -> 0L,
+ "shuffle records written" -> 2L)
testSparkPlanMetrics(df, 1, Map(
2L -> (("HashAggregate", expected1(0))),
1L -> (("Exchange", shuffleExpected1)),
@@ -113,8 +114,9 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
"avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))
val shuffleExpected2 = Map(
"records read" -> 4L,
- "local blocks fetched" -> 4L,
- "remote blocks fetched" -> 0L)
+ "local blocks read" -> 4L,
+ "remote blocks read" -> 0L,
+ "shuffle records written" -> 4L)
testSparkPlanMetrics(df2, 1, Map(
2L -> (("HashAggregate", expected2(0))),
1L -> (("Exchange", shuffleExpected2)),
@@ -170,6 +172,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
val df = testData2.groupBy().agg(collect_set('a)) // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> (("ObjectHashAggregate", Map("number of output rows" -> 2L))),
+ 1L -> (("Exchange", Map(
+ "shuffle records written" -> 2L,
+ "records read" -> 2L,
+ "local blocks read" -> 2L,
+ "remote blocks read" -> 0L))),
0L -> (("ObjectHashAggregate", Map("number of output rows" -> 1L))))
)
@@ -177,6 +184,11 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
val df2 = testData2.groupBy('a).agg(collect_set('a))
testSparkPlanMetrics(df2, 1, Map(
2L -> (("ObjectHashAggregate", Map("number of output rows" -> 4L))),
+ 1L -> (("Exchange", Map(
+ "shuffle records written" -> 4L,
+ "records read" -> 4L,
+ "local blocks read" -> 4L,
+ "remote blocks read" -> 0L))),
0L -> (("ObjectHashAggregate", Map("number of output rows" -> 3L))))
)
}
@@ -204,8 +216,9 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
"number of output rows" -> 4L))),
2L -> (("Exchange", Map(
"records read" -> 4L,
- "local blocks fetched" -> 2L,
- "remote blocks fetched" -> 0L))))
+ "local blocks read" -> 2L,
+ "remote blocks read" -> 0L,
+ "shuffle records written" -> 2L))))
)
}
}
@@ -248,50 +261,6 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
)
}
- test("BroadcastHashJoin metrics: track avg probe") {
- // The executed plan looks like:
- // Project [a#210, b#211, b#221]
- // +- BroadcastHashJoin [a#210], [a#220], Inner, BuildRight
- // :- Project [_1#207 AS a#210, _2#208 AS b#211]
- // : +- Filter isnotnull(_1#207)
- // : +- LocalTableScan [_1#207, _2#208]
- // +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, binary, true]))
- // +- Project [_1#217 AS a#220, _2#218 AS b#221]
- // +- Filter isnotnull(_1#217)
- // +- LocalTableScan [_1#217, _2#218]
- //
- // Assume the execution plan with node id is
- // WholeStageCodegen disabled:
- // Project(nodeId = 0)
- // BroadcastHashJoin(nodeId = 1)
- // ...(ignored)
- //
- // WholeStageCodegen enabled:
- // WholeStageCodegen(nodeId = 0)
- // Project(nodeId = 1)
- // BroadcastHashJoin(nodeId = 2)
- // Project(nodeId = 3)
- // Filter(nodeId = 4)
- // ...(ignored)
- Seq(true, false).foreach { enableWholeStage =>
- val df1 = generateRandomBytesDF()
- val df2 = generateRandomBytesDF()
- val df = df1.join(broadcast(df2), "a")
- val nodeIds = if (enableWholeStage) {
- Set(2L)
- } else {
- Set(1L)
- }
- val metrics = getSparkPlanMetrics(df, 2, nodeIds, enableWholeStage).get
- nodeIds.foreach { nodeId =>
- val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
- probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
- assert(probe.toDouble > 1.0)
- }
- }
- }
- }
-
test("ShuffledHashJoin metrics") {
withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "40",
"spark.sql.shuffle.partitions" -> "2",
@@ -299,63 +268,28 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value")
val df2 = (1 to 10).map(i => (i, i.toString)).toSeq.toDF("key", "value")
// Assume the execution plan is
- // ... -> ShuffledHashJoin(nodeId = 1) -> Project(nodeId = 0)
+ // Project(nodeId = 0)
+ // +- ShuffledHashJoin(nodeId = 1)
+ // :- Exchange(nodeId = 2)
+ // : +- Project(nodeId = 3)
+ // : +- LocalTableScan(nodeId = 4)
+ // +- Exchange(nodeId = 5)
+ // +- Project(nodeId = 6)
+ // +- LocalTableScan(nodeId = 7)
val df = df1.join(df2, "key")
testSparkPlanMetrics(df, 1, Map(
1L -> (("ShuffledHashJoin", Map(
- "number of output rows" -> 2L,
- "avg hash probe (min, med, max)" -> "\n(1, 1, 1)"))))
+ "number of output rows" -> 2L))),
+ 2L -> (("Exchange", Map(
+ "shuffle records written" -> 2L,
+ "records read" -> 2L))),
+ 5L -> (("Exchange", Map(
+ "shuffle records written" -> 10L,
+ "records read" -> 10L))))
)
}
}
- test("ShuffledHashJoin metrics: track avg probe") {
- // The executed plan looks like:
- // Project [a#308, b#309, b#319]
- // +- ShuffledHashJoin [a#308], [a#318], Inner, BuildRight
- // :- Exchange hashpartitioning(a#308, 2)
- // : +- Project [_1#305 AS a#308, _2#306 AS b#309]
- // : +- Filter isnotnull(_1#305)
- // : +- LocalTableScan [_1#305, _2#306]
- // +- Exchange hashpartitioning(a#318, 2)
- // +- Project [_1#315 AS a#318, _2#316 AS b#319]
- // +- Filter isnotnull(_1#315)
- // +- LocalTableScan [_1#315, _2#316]
- //
- // Assume the execution plan with node id is
- // WholeStageCodegen disabled:
- // Project(nodeId = 0)
- // ShuffledHashJoin(nodeId = 1)
- // ...(ignored)
- //
- // WholeStageCodegen enabled:
- // WholeStageCodegen(nodeId = 0)
- // Project(nodeId = 1)
- // ShuffledHashJoin(nodeId = 2)
- // ...(ignored)
- withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "5000000",
- "spark.sql.shuffle.partitions" -> "2",
- "spark.sql.join.preferSortMergeJoin" -> "false") {
- Seq(true, false).foreach { enableWholeStage =>
- val df1 = generateRandomBytesDF(65535 * 5)
- val df2 = generateRandomBytesDF(65535)
- val df = df1.join(df2, "a")
- val nodeIds = if (enableWholeStage) {
- Set(2L)
- } else {
- Set(1L)
- }
- val metrics = getSparkPlanMetrics(df, 1, nodeIds, enableWholeStage).get
- nodeIds.foreach { nodeId =>
- val probes = metrics(nodeId)._2("avg hash probe (min, med, max)")
- probes.toString.stripPrefix("\n(").stripSuffix(")").split(", ").foreach { probe =>
- assert(probe.toDouble > 1.0)
- }
- }
- }
- }
- }
-
test("BroadcastHashJoin(outer) metrics") {
val df1 = Seq((1, "a"), (1, "b"), (4, "c")).toDF("key", "value")
val df2 = Seq((1, "a"), (1, "b"), (2, "c"), (3, "d")).toDF("key2", "value")
@@ -610,4 +544,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared
assert(filters.head.metrics("numOutputRows").value == 1)
}
}
+
+ test("SPARK-26327: FileSourceScanExec metrics") {
+ withTable("testDataForScan") {
+ spark.range(10).selectExpr("id", "id % 3 as p")
+ .write.partitionBy("p").saveAsTable("testDataForScan")
+ // The execution plan only has 1 FileScan node.
+ val df = spark.sql(
+ "SELECT * FROM testDataForScan WHERE p = 1")
+ testSparkPlanMetrics(df, 1, Map(
+ 0L -> (("Scan parquet default.testdataforscan", Map(
+ "number of output rows" -> 3L,
+ "number of files" -> 2L))))
+ )
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
index 7bef687e7e43b..2f460b044b237 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryStatusAndProgressSuite.scala
@@ -73,7 +73,8 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
| "inputRowsPerSecond" : 10.0
| } ],
| "sink" : {
- | "description" : "sink"
+ | "description" : "sink",
+ | "numOutputRows" : -1
| }
|}
""".stripMargin.trim)
@@ -105,7 +106,8 @@ class StreamingQueryStatusAndProgressSuite extends StreamTest with Eventually {
| "numInputRows" : 678
| } ],
| "sink" : {
- | "description" : "sink"
+ | "description" : "sink",
+ | "numOutputRows" : -1
| }
|}
""".stripMargin.trim)
@@ -250,7 +252,7 @@ object StreamingQueryStatusAndProgressSuite {
processedRowsPerSecond = Double.PositiveInfinity // should not be present in the json
)
),
- sink = new SinkProgress("sink")
+ sink = SinkProgress("sink", None)
)
val testProgress2 = new StreamingQueryProgress(
@@ -274,7 +276,7 @@ object StreamingQueryStatusAndProgressSuite {
processedRowsPerSecond = Double.NegativeInfinity // should not be present in the json
)
),
- sink = new SinkProgress("sink")
+ sink = SinkProgress("sink", None)
)
val testStatus = new StreamingQueryStatus("active", true, false)
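
The new numOutputRows field on SinkProgress surfaces directly in a running query's progress; a one-line hedged sketch, assuming an active StreamingQuery named `query`:

val sinkRows = query.lastProgress.sink.numOutputRows  // -1 when the sink does not report a row count
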
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala
new file mode 100644
index 0000000000000..10bea7f090571
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming.continuous
+
+import org.apache.spark.sql.execution.streaming.StreamExecution
+import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
+import org.apache.spark.sql.streaming.Trigger
+
+class ContinuousQueryStatusAndProgressSuite extends ContinuousSuiteBase {
+ test("StreamingQueryStatus - ContinuousExecution isDataAvailable and isTriggerActive " +
+ "should be false") {
+ import testImplicits._
+
+ val input = ContinuousMemoryStream[Int]
+
+ def assertStatus(stream: StreamExecution): Unit = {
+ assert(stream.status.isDataAvailable === false)
+ assert(stream.status.isTriggerActive === false)
+ }
+
+ val trigger = Trigger.Continuous(100)
+ testStream(input.toDF(), useV2Sink = true)(
+ StartStream(trigger),
+ Execute(assertStatus),
+ AddData(input, 0, 1, 2),
+ Execute(assertStatus),
+ CheckAnswer(0, 1, 2),
+ Execute(assertStatus),
+ StopStream,
+ Execute(assertStatus),
+ AddData(input, 3, 4, 5),
+ Execute(assertStatus),
+ StartStream(trigger),
+ Execute(assertStatus),
+ CheckAnswer(0, 1, 2, 3, 4, 5),
+ Execute(assertStatus),
+ StopStream,
+ Execute(assertStatus))
+ }
+}
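
A hedged end-to-end sketch of the behavior the new suite pins down, using a source and sink that support continuous processing (assumes an active SparkSession `spark`):

import org.apache.spark.sql.streaming.Trigger

val query = spark.readStream.format("rate").load()
  .writeStream.format("console")
  .trigger(Trigger.Continuous("1 second"))
  .start()

// In continuous mode these stay false; there is no per-trigger planning loop.
assert(!query.status.isDataAvailable)
assert(!query.status.isTriggerActive)
query.stop()
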
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index 237872585e11d..e45ab19aadbfa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -23,6 +23,13 @@ import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.parquet.hadoop.ParquetFileReader
+import org.apache.parquet.hadoop.util.HadoopInputFile
+import org.apache.parquet.schema.PrimitiveType
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+import org.apache.parquet.schema.Type.Repetition
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkContext
@@ -31,6 +38,7 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@@ -522,11 +530,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
Seq("json", "orc", "parquet", "csv").foreach { format =>
val schema = StructType(
StructField("cl1", IntegerType, nullable = false).withComment("test") ::
- StructField("cl2", IntegerType, nullable = true) ::
- StructField("cl3", IntegerType, nullable = true) :: Nil)
+ StructField("cl2", IntegerType, nullable = true) ::
+ StructField("cl3", IntegerType, nullable = true) :: Nil)
val row = Row(3, null, 4)
val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+ // if we write and then read, the read will enforce the schema to be nullable
val tableName = "tab"
withTable(tableName) {
df.write.format(format).mode("overwrite").saveAsTable(tableName)
@@ -536,12 +545,41 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
Row("cl1", "test") :: Nil)
// Verify the schema
val expectedFields = schema.fields.map(f => f.copy(nullable = true))
- assert(spark.table(tableName).schema == schema.copy(fields = expectedFields))
+ assert(spark.table(tableName).schema === schema.copy(fields = expectedFields))
}
}
}
}
+ test("parquet - column nullability -- write only") {
+ val schema = StructType(
+ StructField("cl1", IntegerType, nullable = false) ::
+ StructField("cl2", IntegerType, nullable = true) :: Nil)
+ val row = Row(3, 4)
+ val df = spark.createDataFrame(sparkContext.parallelize(row :: Nil), schema)
+
+ withTempPath { dir =>
+ val path = dir.getAbsolutePath
+ df.write.mode("overwrite").parquet(path)
+ val file = SpecificParquetRecordReaderBase.listDirectory(dir).get(0)
+
+ val hadoopInputFile = HadoopInputFile.fromPath(new Path(file), new Configuration())
+ val f = ParquetFileReader.open(hadoopInputFile)
+ val parquetSchema = f.getFileMetaData.getSchema.getColumns.asScala
+ .map(_.getPrimitiveType)
+ f.close()
+
+ // the write keeps nullable info from the schema
+ val expectedParquetSchema = Seq(
+ new PrimitiveType(Repetition.REQUIRED, PrimitiveTypeName.INT32, "cl1"),
+ new PrimitiveType(Repetition.OPTIONAL, PrimitiveTypeName.INT32, "cl2")
+ )
+
+ assert (expectedParquetSchema === parquetSchema)
+ }
+
+ }
+
test("SPARK-17230: write out results of decimal calculation") {
val df = spark.range(99, 101)
.selectExpr("id", "cast(id as long) * cast('1.0' as decimal(38, 18)) as num")
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 5823548a8063c..03f4b8d83e353 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.hive
+import java.util.Locale
+
import scala.util.control.NonFatal
import com.google.common.util.concurrent.Striped
@@ -29,6 +31,8 @@ import org.apache.spark.sql.catalyst.{QualifiedTableName, TableIdentifier}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.HiveCaseSensitiveInferenceMode._
import org.apache.spark.sql.types._
@@ -113,7 +117,44 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
}
}
- def convertToLogicalRelation(
+ // Return true for Apache ORC and Hive ORC-related configuration names.
+ // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`.
+ private def isOrcProperty(key: String) =
+ key.startsWith("orc.") || key.contains(".orc.")
+
+ private def isParquetProperty(key: String) =
+ key.startsWith("parquet.") || key.contains(".parquet.")
+
+ def convert(relation: HiveTableRelation): LogicalRelation = {
+ val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
+
+ // Consider table and storage properties. For properties existing in both sides, storage
+ // properties will supersede table properties.
+ if (serde.contains("parquet")) {
+ val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++
+ relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA ->
+ SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString)
+ convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet")
+ } else {
+ val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++
+ relation.tableMeta.storage.properties
+ if (SQLConf.get.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") {
+ convertToLogicalRelation(
+ relation,
+ options,
+ classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat],
+ "orc")
+ } else {
+ convertToLogicalRelation(
+ relation,
+ options,
+ classOf[org.apache.spark.sql.hive.orc.OrcFileFormat],
+ "orc")
+ }
+ }
+ }
+
+ private def convertToLogicalRelation(
relation: HiveTableRelation,
options: Map[String, String],
fileFormatClass: Class[_ <: FileFormat],
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 07ee105404311..8a5ab188a949f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -31,8 +31,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoDir, InsertIntoTab
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.command.{CreateTableCommand, DDLUtils}
-import org.apache.spark.sql.execution.datasources.{CreateTable, LogicalRelation}
-import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetOptions}
+import org.apache.spark.sql.execution.datasources.CreateTable
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
@@ -181,49 +180,17 @@ case class RelationConversions(
conf: SQLConf,
sessionCatalog: HiveSessionCatalog) extends Rule[LogicalPlan] {
private def isConvertible(relation: HiveTableRelation): Boolean = {
- val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
- serde.contains("parquet") && conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) ||
- serde.contains("orc") && conf.getConf(HiveUtils.CONVERT_METASTORE_ORC)
+ isConvertible(relation.tableMeta)
}
- // Return true for Apache ORC and Hive ORC-related configuration names.
- // Note that Spark doesn't support configurations like `hive.merge.orcfile.stripe.level`.
- private def isOrcProperty(key: String) =
- key.startsWith("orc.") || key.contains(".orc.")
-
- private def isParquetProperty(key: String) =
- key.startsWith("parquet.") || key.contains(".parquet.")
-
- private def convert(relation: HiveTableRelation): LogicalRelation = {
- val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
-
- // Consider table and storage properties. For properties existing in both sides, storage
- // properties will supersede table properties.
- if (serde.contains("parquet")) {
- val options = relation.tableMeta.properties.filterKeys(isParquetProperty) ++
- relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA ->
- conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString)
- sessionCatalog.metastoreCatalog
- .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet")
- } else {
- val options = relation.tableMeta.properties.filterKeys(isOrcProperty) ++
- relation.tableMeta.storage.properties
- if (conf.getConf(SQLConf.ORC_IMPLEMENTATION) == "native") {
- sessionCatalog.metastoreCatalog.convertToLogicalRelation(
- relation,
- options,
- classOf[org.apache.spark.sql.execution.datasources.orc.OrcFileFormat],
- "orc")
- } else {
- sessionCatalog.metastoreCatalog.convertToLogicalRelation(
- relation,
- options,
- classOf[org.apache.spark.sql.hive.orc.OrcFileFormat],
- "orc")
- }
- }
+ private def isConvertible(tableMeta: CatalogTable): Boolean = {
+ val serde = tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT)
+ serde.contains("parquet") && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_PARQUET) ||
+ serde.contains("orc") && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_ORC)
}
+ private val metastoreCatalog = sessionCatalog.metastoreCatalog
+
override def apply(plan: LogicalPlan): LogicalPlan = {
plan resolveOperators {
// Write path
@@ -231,12 +198,21 @@ case class RelationConversions(
// Inserting into partitioned table is not supported in Parquet/Orc data source (yet).
if query.resolved && DDLUtils.isHiveTable(r.tableMeta) &&
!r.isPartitioned && isConvertible(r) =>
- InsertIntoTable(convert(r), partition, query, overwrite, ifPartitionNotExists)
+ InsertIntoTable(metastoreCatalog.convert(r), partition,
+ query, overwrite, ifPartitionNotExists)
// Read path
case relation: HiveTableRelation
if DDLUtils.isHiveTable(relation.tableMeta) && isConvertible(relation) =>
- convert(relation)
+ metastoreCatalog.convert(relation)
+
+ // CTAS
+ case CreateTable(tableDesc, mode, Some(query))
+ if DDLUtils.isHiveTable(tableDesc) && tableDesc.partitionColumnNames.isEmpty &&
+ isConvertible(tableDesc) && SQLConf.get.getConf(HiveUtils.CONVERT_METASTORE_CTAS) =>
+ DDLUtils.checkDataColNames(tableDesc)
+ OptimizedCreateHiveTableAsSelectCommand(
+ tableDesc, query, query.output.map(_.name), mode)
}
}
}
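
With the `convert` logic consolidated in HiveMetastoreCatalog and the new CTAS case above, an eligible `CREATE TABLE ... STORED AS PARQUET/ORC ... AS SELECT` is planned as OptimizedCreateHiveTableAsSelectCommand instead of the Hive-serde command. The following is a minimal, illustrative sketch (not part of the patch) of how to observe this from a session; it assumes Spark built with Hive support and a local metastore.

import org.apache.spark.sql.SparkSession

object CtasConversionSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("ctas-conversion-sketch")
      .enableHiveSupport()
      .getOrCreate()

    spark.range(3).createOrReplaceTempView("src")

    // With spark.sql.hive.convertMetastoreParquet and spark.sql.hive.convertMetastoreCtas
    // at their defaults (true), the analyzed plan should contain
    // OptimizedCreateHiveTableAsSelectCommand rather than CreateHiveTableAsSelectCommand.
    val df = spark.sql("CREATE TABLE ctas_target STORED AS PARQUET AS SELECT id FROM src")
    println(df.queryExecution.analyzed)

    spark.stop()
  }
}
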
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
index 66067704195dd..b60d4c71f5941 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala
@@ -110,6 +110,14 @@ private[spark] object HiveUtils extends Logging {
.booleanConf
.createWithDefault(true)
+ val CONVERT_METASTORE_CTAS = buildConf("spark.sql.hive.convertMetastoreCtas")
+ .doc("When set to true, Spark will try to use built-in data source writer " +
+ "instead of Hive serde in CTAS. This flag is effective only if " +
+ "`spark.sql.hive.convertMetastoreParquet` or `spark.sql.hive.convertMetastoreOrc` is " +
+ "enabled respectively for Parquet and ORC formats")
+ .booleanConf
+ .createWithDefault(true)
+
val HIVE_METASTORE_SHARED_PREFIXES = buildConf("spark.sql.hive.metastore.sharedPrefixes")
.doc("A comma separated list of class prefixes that should be loaded using the classloader " +
"that is shared between Spark SQL and a specific version of Hive. An example of classes " +
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index bf78edd6105e6..7249eacfbf9a6 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -20,32 +20,26 @@ package org.apache.spark.sql.hive.execution
import scala.util.control.NonFatal
import org.apache.spark.sql.{AnalysisException, Row, SaveMode, SparkSession}
-import org.apache.spark.sql.catalyst.catalog.CatalogTable
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.command.DataWritingCommand
+import org.apache.spark.sql.execution.command.{DataWritingCommand, DDLUtils}
+import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, InsertIntoHadoopFsRelationCommand, LogicalRelation}
+import org.apache.spark.sql.hive.HiveSessionCatalog
+trait CreateHiveTableAsSelectBase extends DataWritingCommand {
+ val tableDesc: CatalogTable
+ val query: LogicalPlan
+ val outputColumnNames: Seq[String]
+ val mode: SaveMode
-/**
- * Create table and insert the query result into it.
- *
- * @param tableDesc the Table Describe, which may contain serde, storage handler etc.
- * @param query the query whose result will be insert into the new relation
- * @param mode SaveMode
- */
-case class CreateHiveTableAsSelectCommand(
- tableDesc: CatalogTable,
- query: LogicalPlan,
- outputColumnNames: Seq[String],
- mode: SaveMode)
- extends DataWritingCommand {
-
- private val tableIdentifier = tableDesc.identifier
+ protected val tableIdentifier = tableDesc.identifier
override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = {
val catalog = sparkSession.sessionState.catalog
- if (catalog.tableExists(tableIdentifier)) {
+ val tableExists = catalog.tableExists(tableIdentifier)
+
+ if (tableExists) {
assert(mode != SaveMode.Overwrite,
s"Expect the table $tableIdentifier has been dropped when the save mode is Overwrite")
@@ -57,13 +51,8 @@ case class CreateHiveTableAsSelectCommand(
return Seq.empty
}
- InsertIntoHiveTable(
- tableDesc,
- Map.empty,
- query,
- overwrite = false,
- ifPartitionNotExists = false,
- outputColumnNames = outputColumnNames).run(sparkSession, child)
+ val command = getWritingCommand(catalog, tableDesc, tableExists = true)
+ command.run(sparkSession, child)
} else {
// TODO ideally, we should get the output data ready first and then
// add the relation into catalog, just in case of failure occurs while data
@@ -75,15 +64,8 @@ case class CreateHiveTableAsSelectCommand(
try {
// Read back the metadata of the table which was created just now.
val createdTableMeta = catalog.getTableMetadata(tableDesc.identifier)
- // For CTAS, there is no static partition values to insert.
- val partition = createdTableMeta.partitionColumnNames.map(_ -> None).toMap
- InsertIntoHiveTable(
- createdTableMeta,
- partition,
- query,
- overwrite = true,
- ifPartitionNotExists = false,
- outputColumnNames = outputColumnNames).run(sparkSession, child)
+ val command = getWritingCommand(catalog, createdTableMeta, tableExists = false)
+ command.run(sparkSession, child)
} catch {
case NonFatal(e) =>
// drop the created table.
@@ -95,9 +77,89 @@ case class CreateHiveTableAsSelectCommand(
Seq.empty[Row]
}
+ // Returns a `DataWritingCommand` that actually writes data into the table.
+ def getWritingCommand(
+ catalog: SessionCatalog,
+ tableDesc: CatalogTable,
+ tableExists: Boolean): DataWritingCommand
+
override def argString(maxFields: Int): String = {
s"[Database:${tableDesc.database}, " +
s"TableName: ${tableDesc.identifier.table}, " +
s"InsertIntoHiveTable]"
}
}
+
+/**
+ * Create table and insert the query result into it.
+ *
+ * @param tableDesc the table description, which may contain serde, storage handler etc.
+ * @param query the query whose result will be inserted into the new relation
+ * @param mode SaveMode
+ */
+case class CreateHiveTableAsSelectCommand(
+ tableDesc: CatalogTable,
+ query: LogicalPlan,
+ outputColumnNames: Seq[String],
+ mode: SaveMode)
+ extends CreateHiveTableAsSelectBase {
+
+ override def getWritingCommand(
+ catalog: SessionCatalog,
+ tableDesc: CatalogTable,
+ tableExists: Boolean): DataWritingCommand = {
+ // For CTAS, there are no static partition values to insert.
+ val partition = tableDesc.partitionColumnNames.map(_ -> None).toMap
+ InsertIntoHiveTable(
+ tableDesc,
+ partition,
+ query,
+ overwrite = !tableExists,
+ ifPartitionNotExists = false,
+ outputColumnNames = outputColumnNames)
+ }
+}
+
+/**
+ * Create table and insert the query result into it. This creates the Hive table, but inserts
+ * the query result into it using a built-in data source writer.
+ *
+ * @param tableDesc the table description, which may contain serde, storage handler etc.
+ * @param query the query whose result will be inserted into the new relation
+ * @param mode SaveMode
+ */
+case class OptimizedCreateHiveTableAsSelectCommand(
+ tableDesc: CatalogTable,
+ query: LogicalPlan,
+ outputColumnNames: Seq[String],
+ mode: SaveMode)
+ extends CreateHiveTableAsSelectBase {
+
+ override def getWritingCommand(
+ catalog: SessionCatalog,
+ tableDesc: CatalogTable,
+ tableExists: Boolean): DataWritingCommand = {
+ val metastoreCatalog = catalog.asInstanceOf[HiveSessionCatalog].metastoreCatalog
+ val hiveTable = DDLUtils.readHiveTable(tableDesc)
+
+ val hadoopRelation = metastoreCatalog.convert(hiveTable) match {
+ case LogicalRelation(t: HadoopFsRelation, _, _, _) => t
+ case _ => throw new AnalysisException(s"$tableIdentifier should be converted to " +
+ "HadoopFsRelation.")
+ }
+
+ InsertIntoHadoopFsRelationCommand(
+ hadoopRelation.location.rootPaths.head,
+ Map.empty, // Converting partitioned tables is not supported.
+ false,
+ Seq.empty, // Converting partitioned tables is not supported.
+ hadoopRelation.bucketSpec,
+ hadoopRelation.fileFormat,
+ hadoopRelation.options,
+ query,
+ if (tableExists) mode else SaveMode.Overwrite,
+ Some(tableDesc),
+ Some(hadoopRelation.location),
+ query.output.map(_.name))
+ }
+}
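
The refactoring is a template-method split: CreateHiveTableAsSelectBase owns the existence check, table creation and rollback, while each concrete command only decides how rows are written via getWritingCommand. The hypothetical subclass below (not in the patch; names are illustrative) shows the minimum a new writer strategy has to provide, and assumes it lives next to the classes above so the same package-private types and imports are in scope.

case class SerdeOnlyCreateHiveTableAsSelectCommand(
    tableDesc: CatalogTable,
    query: LogicalPlan,
    outputColumnNames: Seq[String],
    mode: SaveMode)
  extends CreateHiveTableAsSelectBase {

  override def getWritingCommand(
      catalog: SessionCatalog,
      tableDesc: CatalogTable,
      tableExists: Boolean): DataWritingCommand = {
    // Always route rows through the Hive serde writer; a real subclass may return any
    // DataWritingCommand here, which is exactly what the optimized variant does.
    InsertIntoHiveTable(
      tableDesc,
      tableDesc.partitionColumnNames.map(_ -> None).toMap,
      query,
      overwrite = !tableExists,
      ifPartitionNotExists = false,
      outputColumnNames = outputColumnNames)
  }
}
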
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
index e5c9df05d5674..470c6a342b4dd 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
@@ -92,4 +92,18 @@ class HiveParquetSuite extends QueryTest with ParquetTest with TestHiveSingleton
}
}
}
+
+ test("SPARK-25271: write empty map into hive parquet table") {
+ import testImplicits._
+
+ Seq(Map(1 -> "a"), Map.empty[Int, String]).toDF("m").createOrReplaceTempView("p")
+ withTempView("p") {
+ val targetTable = "targetTable"
+ withTable(targetTable) {
+ sql(s"CREATE TABLE $targetTable STORED AS PARQUET AS SELECT m FROM p")
+ checkAnswer(sql(s"SELECT m FROM $targetTable"),
+ Row(Map(1 -> "a")) :: Row(Map.empty[Int, String]) :: Nil)
+ }
+ }
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShimSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShimSuite.scala
new file mode 100644
index 0000000000000..a716f739b5c20
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShimSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive
+
+import scala.collection.JavaConverters._
+import scala.language.implicitConversions
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.hive.serde2.ColumnProjectionUtils
+
+import org.apache.spark.SparkFunSuite
+
+class HiveShimSuite extends SparkFunSuite {
+
+ test("appendReadColumns") {
+ val conf = new Configuration
+ val ids = Seq(1, 2, 3).map(Int.box)
+ val names = Seq("a", "b", "c")
+ val moreIds = Seq(4, 5).map(Int.box)
+ val moreNames = Seq("d", "e")
+
+ // test when READ_COLUMN_NAMES_CONF_STR is empty
+ HiveShim.appendReadColumns(conf, ids, names)
+ assert(names.asJava === ColumnProjectionUtils.getReadColumnNames(conf))
+
+ // test when READ_COLUMN_NAMES_CONF_STR is non-empty
+ HiveShim.appendReadColumns(conf, moreIds, moreNames)
+ assert((names ++ moreNames).asJava === ColumnProjectionUtils.getReadColumnNames(conf))
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
index 5879748d05b2b..510de3a7eab57 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala
@@ -752,6 +752,17 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter
}
}
+ test("SPARK-26307: CTAS - INSERT a partitioned table using Hive serde") {
+ withTable("tab1") {
+ withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") {
+ val df = Seq(("a", 100)).toDF("part", "id")
+ df.write.format("hive").partitionBy("part").mode("overwrite").saveAsTable("tab1")
+ df.write.format("hive").partitionBy("part").mode("append").saveAsTable("tab1")
+ }
+ }
+ }
+
Seq("LOCAL", "").foreach { local =>
Seq(true, false).foreach { caseSensitivity =>
Seq("orc", "parquet").foreach { format =>
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index fab2a27cdef17..6acf44606cbbe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -2276,6 +2276,46 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
}
}
+ test("SPARK-25271: Hive ctas commands should use data source if it is convertible") {
+ withTempView("p") {
+ Seq(1, 2, 3).toDF("id").createOrReplaceTempView("p")
+
+ Seq("orc", "parquet").foreach { format =>
+ Seq(true, false).foreach { isConverted =>
+ withSQLConf(
+ HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted",
+ HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") {
+ Seq(true, false).foreach { isConvertedCtas =>
+ withSQLConf(HiveUtils.CONVERT_METASTORE_CTAS.key -> s"$isConvertedCtas") {
+
+ val targetTable = "targetTable"
+ withTable(targetTable) {
+ val df = sql(s"CREATE TABLE $targetTable STORED AS $format AS SELECT id FROM p")
+ checkAnswer(sql(s"SELECT id FROM $targetTable"),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+
+ val ctasDSCommand = df.queryExecution.analyzed.collect {
+ case _: OptimizedCreateHiveTableAsSelectCommand => true
+ }.headOption
+ val ctasCommand = df.queryExecution.analyzed.collect {
+ case _: CreateHiveTableAsSelectCommand => true
+ }.headOption
+
+ if (isConverted && isConvertedCtas) {
+ assert(ctasDSCommand.nonEmpty)
+ assert(ctasCommand.isEmpty)
+ } else {
+ assert(ctasDSCommand.isEmpty)
+ assert(ctasCommand.nonEmpty)
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
test("SPARK-26181 hasMinMaxStats method of ColumnStatsMap is not correct") {
withSQLConf(SQLConf.CBO_ENABLED.key -> "true") {
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
index 6075f2c8877d6..f0f62b608785d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/HadoopFsRelationTest.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.sources
import java.io.File
+import java.util.TimeZone
import scala.util.Random
@@ -125,56 +126,62 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes
} else {
Seq(false)
}
- for (dataType <- supportedDataTypes) {
- for (parquetDictionaryEncodingEnabled <- parquetDictionaryEncodingEnabledConfs) {
- val extraMessage = if (isParquetDataSource) {
- s" with parquet.enable.dictionary = $parquetDictionaryEncodingEnabled"
- } else {
- ""
- }
- logInfo(s"Testing $dataType data type$extraMessage")
-
- val extraOptions = Map[String, String](
- "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString
- )
-
- withTempPath { file =>
- val path = file.getCanonicalPath
-
- val dataGenerator = RandomDataGenerator.forType(
- dataType = dataType,
- nullable = true,
- new Random(System.nanoTime())
- ).getOrElse {
- fail(s"Failed to create data generator for schema $dataType")
+ // TODO: Support new parser too, see SPARK-26374.
+ withSQLConf(SQLConf.LEGACY_TIME_PARSER_ENABLED.key -> "true") {
+ for (dataType <- supportedDataTypes) {
+ for (parquetDictionaryEncodingEnabled <- parquetDictionaryEncodingEnabledConfs) {
+ val extraMessage = if (isParquetDataSource) {
+ s" with parquet.enable.dictionary = $parquetDictionaryEncodingEnabled"
+ } else {
+ ""
+ }
+ logInfo(s"Testing $dataType data type$extraMessage")
+
+ val extraOptions = Map[String, String](
+ "parquet.enable.dictionary" -> parquetDictionaryEncodingEnabled.toString
+ )
+
+ withTempPath { file =>
+ val path = file.getCanonicalPath
+
+ val seed = System.nanoTime()
+ withClue(s"Random data generated with the seed: ${seed}") {
+ val dataGenerator = RandomDataGenerator.forType(
+ dataType = dataType,
+ nullable = true,
+ new Random(seed)
+ ).getOrElse {
+ fail(s"Failed to create data generator for schema $dataType")
+ }
+
+ // Create a DF for the schema with random data. The index field is used to sort the
+ // DataFrame. This is a workaround for SPARK-10591.
+ val schema = new StructType()
+ .add("index", IntegerType, nullable = false)
+ .add("col", dataType, nullable = true)
+ val rdd =
+ spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator())))
+ val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
+
+ df.write
+ .mode("overwrite")
+ .format(dataSourceName)
+ .option("dataSchema", df.schema.json)
+ .options(extraOptions)
+ .save(path)
+
+ val loadedDF = spark
+ .read
+ .format(dataSourceName)
+ .option("dataSchema", df.schema.json)
+ .schema(df.schema)
+ .options(extraOptions)
+ .load(path)
+ .orderBy("index")
+
+ checkAnswer(loadedDF, df)
+ }
}
-
- // Create a DF for the schema with random data. The index field is used to sort the
- // DataFrame. This is a workaround for SPARK-10591.
- val schema = new StructType()
- .add("index", IntegerType, nullable = false)
- .add("col", dataType, nullable = true)
- val rdd =
- spark.sparkContext.parallelize((1 to 10).map(i => Row(i, dataGenerator())))
- val df = spark.createDataFrame(rdd, schema).orderBy("index").coalesce(1)
-
- df.write
- .mode("overwrite")
- .format(dataSourceName)
- .option("dataSchema", df.schema.json)
- .options(extraOptions)
- .save(path)
-
- val loadedDF = spark
- .read
- .format(dataSourceName)
- .option("dataSchema", df.schema.json)
- .schema(df.schema)
- .options(extraOptions)
- .load(path)
- .orderBy("index")
-
- checkAnswer(loadedDF, df)
}
}
}
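
The HadoopFsRelationTest change keeps the same random-data coverage but records the RNG seed and reports it through withClue, so a failing run can be replayed with the exact data that broke it. A standalone sketch of that pattern (suite name and property are illustrative only), following the SparkFunSuite style used elsewhere in this patch:

import scala.util.Random

import org.apache.spark.SparkFunSuite

class SeededRandomDataSuite extends SparkFunSuite {

  test("property holds for randomly generated input") {
    val seed = System.nanoTime()
    // On failure, the reported clue carries the seed needed to regenerate the same data.
    withClue(s"Random data generated with the seed: $seed") {
      val rng = new Random(seed)
      val xs = Seq.fill(100)(rng.nextInt(1000))
      assert(xs.sorted.reverse === xs.sorted(Ordering[Int].reverse))
    }
  }
}
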