diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 6d20ef1f98a3c..3e60c50ada59b 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -186,7 +186,17 @@ class HadoopMapReduceCommitProtocol( logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") for (part <- partitionPaths) { val finalPartPath = new Path(path, part) - fs.delete(finalPartPath, true) + if (!fs.delete(finalPartPath, true) && !fs.exists(finalPartPath.getParent)) { + // According to the official hadoop FileSystem API spec, delete op should assume + // the destination is no longer present regardless of return value, thus we do not + // need to double check if finalPartPath exists before rename. + // Also in our case, based on the spec, delete returns false only when finalPartPath + // does not exist. When this happens, we need to take action if parent of finalPartPath + // also does not exist(e.g. the scenario described on SPARK-23815), because + // FileSystem API spec on rename op says the rename dest(finalPartPath) must have + // a parent that exists, otherwise we may get unexpected result on the rename. + fs.mkdirs(finalPartPath.getParent) + } fs.rename(new Path(stagingDir, part), finalPartPath) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala index 7e14938acd8e0..c1fedd63f6a90 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/AsyncEventQueue.scala @@ -166,7 +166,7 @@ private class AsyncEventQueue(val name: String, conf: SparkConf, metrics: LiveLi val prevLastReportTimestamp = lastReportTimestamp lastReportTimestamp = System.currentTimeMillis() val previous = new java.util.Date(prevLastReportTimestamp) - logWarning(s"Dropped $droppedEvents events from $name since $previous.") + logWarning(s"Dropped $droppedCount events from $name since $previous.") } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala index af8ff64d33ffe..adf8145726711 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/stat/KolmogorovSmirnovTest.scala @@ -85,7 +85,7 @@ object KolmogorovSmirnovTest { dataset: Dataset[_], sampleCol: String, cdf: Function[java.lang.Double, java.lang.Double]): DataFrame = { - test(dataset, sampleCol, (x: Double) => cdf.call(x)) + test(dataset, sampleCol, (x: Double) => cdf.call(x).toDouble) } /** diff --git a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java index 43779878890db..35a250955b282 100644 --- a/mllib/src/test/java/org/apache/spark/SharedSparkSession.java +++ b/mllib/src/test/java/org/apache/spark/SharedSparkSession.java @@ -42,7 +42,12 @@ public void setUp() throws IOException { @After public void tearDown() { - spark.stop(); - spark = null; + try { + spark.stop(); + spark = null; + } finally { + SparkSession.clearDefaultSession(); + SparkSession.clearActiveSession(); + } } } diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 
409807216ae12..ab20085ec5f15 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2079,6 +2079,21 @@ def size(col): return Column(sc._jvm.functions.size(_to_java_column(col))) +@since(2.4) +def array_max(col): + """ + Collection function: returns the maximum value of the array. + + :param col: name of column or expression + + >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data']) + >>> df.select(array_max(df.data).alias('max')).collect() + [Row(max=3), Row(max=10)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.array_max(_to_java_column(col))) + + @since(1.5) def sort_array(col, asc=True): """ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 96c2a776a5049..4e99c8e3c6b10 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -186,16 +186,12 @@ def __init__(self, key, value): self.value = value -class ReusedSQLTestCase(ReusedPySparkTestCase): - @classmethod - def setUpClass(cls): - ReusedPySparkTestCase.setUpClass() - cls.spark = SparkSession(cls.sc) - - @classmethod - def tearDownClass(cls): - ReusedPySparkTestCase.tearDownClass() - cls.spark.stop() +class SQLTestUtils(object): + """ + This util assumes the instance of this to have 'spark' attribute, having a spark session. + It is usually used with 'ReusedSQLTestCase' class but can be used if you feel sure the + the implementation of this class has 'spark' attribute. + """ @contextmanager def sql_conf(self, pairs): @@ -204,6 +200,7 @@ def sql_conf(self, pairs): `value` to the configuration `key` and then restores it back when it exits. """ assert isinstance(pairs, dict), "pairs should be a dictionary." + assert hasattr(self, "spark"), "it should have 'spark' attribute, having a spark session." keys = pairs.keys() new_values = pairs.values() @@ -219,6 +216,18 @@ def sql_conf(self, pairs): else: self.spark.conf.set(key, old_value) + +class ReusedSQLTestCase(ReusedPySparkTestCase, SQLTestUtils): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + def assertPandasEqual(self, expected, result): msg = ("DataFrames are not equal: " + "\n\nExpected:\n%s\n%s" % (expected, expected.dtypes) + @@ -3066,6 +3075,64 @@ def test_sparksession_with_stopped_sparkcontext(self): sc.stop() +class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils): + # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is + # static and immutable. This can't be set or unset, for example, via `spark.conf`. + + @classmethod + def setUpClass(cls): + import glob + from pyspark.find_spark_home import _find_spark_home + + SPARK_HOME = _find_spark_home() + filename_pattern = ( + "sql/core/target/scala-*/test-classes/org/apache/spark/sql/" + "TestQueryExecutionListener.class") + if not glob.glob(os.path.join(SPARK_HOME, filename_pattern)): + raise unittest.SkipTest( + "'org.apache.spark.sql.TestQueryExecutionListener' is not " + "available. Will skip the related tests.") + + # Note that 'spark.sql.queryExecutionListeners' is a static immutable configuration. 
+ cls.spark = SparkSession.builder \ + .master("local[4]") \ + .appName(cls.__name__) \ + .config( + "spark.sql.queryExecutionListeners", + "org.apache.spark.sql.TestQueryExecutionListener") \ + .getOrCreate() + + @classmethod + def tearDownClass(cls): + cls.spark.stop() + + def tearDown(self): + self.spark._jvm.OnSuccessCall.clear() + + def test_query_execution_listener_on_collect(self): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be called before 'collect'") + self.spark.sql("SELECT * FROM range(1)").collect() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'collect'") + + @unittest.skipIf( + not _have_pandas or not _have_pyarrow, + _pandas_requirement_message or _pyarrow_requirement_message) + def test_query_execution_listener_on_collect_with_arrow(self): + with self.sql_conf({"spark.sql.execution.arrow.enabled": True}): + self.assertFalse( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should not be " + "called before 'toPandas'") + self.spark.sql("SELECT * FROM range(1)").toPandas() + self.assertTrue( + self.spark._jvm.OnSuccessCall.isCalled(), + "The callback from the query execution listener should be called after 'toPandas'") + + class SparkSessionTests(PySparkTestCase): # This test is separate because it's closely related with session's start and stop. diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala index 82f6c714f3555..4086970ffb256 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/Config.scala @@ -167,5 +167,5 @@ private[spark] object Config extends Logging { val KUBERNETES_EXECUTOR_ANNOTATION_PREFIX = "spark.kubernetes.executor.annotation." val KUBERNETES_EXECUTOR_SECRETS_PREFIX = "spark.kubernetes.executor.secrets." - val KUBERNETES_DRIVER_ENV_KEY = "spark.kubernetes.driverEnv." + val KUBERNETES_DRIVER_ENV_PREFIX = "spark.kubernetes.driverEnv." } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala new file mode 100644 index 0000000000000..77b634ddfabcc --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesConf.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReference, LocalObjectReferenceBuilder, Pod} + +import org.apache.spark.SparkConf +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.{JavaMainAppResource, MainAppResource} +import org.apache.spark.internal.config.ConfigEntry + +private[spark] sealed trait KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark driver. + */ +private[spark] case class KubernetesDriverSpecificConf( + mainAppResource: Option[MainAppResource], + mainClass: String, + appName: String, + appArgs: Seq[String]) extends KubernetesRoleSpecificConf + +/* + * Structure containing metadata for Kubernetes logic that builds a Spark executor. + */ +private[spark] case class KubernetesExecutorSpecificConf( + executorId: String, + driverPod: Pod) + extends KubernetesRoleSpecificConf + +/** + * Structure containing metadata for Kubernetes logic to build Spark pods. + */ +private[spark] case class KubernetesConf[T <: KubernetesRoleSpecificConf]( + sparkConf: SparkConf, + roleSpecificConf: T, + appResourceNamePrefix: String, + appId: String, + roleLabels: Map[String, String], + roleAnnotations: Map[String, String], + roleSecretNamesToMountPaths: Map[String, String], + roleEnvs: Map[String, String]) { + + def namespace(): String = sparkConf.get(KUBERNETES_NAMESPACE) + + def sparkJars(): Seq[String] = sparkConf + .getOption("spark.jars") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def sparkFiles(): Seq[String] = sparkConf + .getOption("spark.files") + .map(str => str.split(",").toSeq) + .getOrElse(Seq.empty[String]) + + def imagePullPolicy(): String = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) + + def imagePullSecrets(): Seq[LocalObjectReference] = { + sparkConf + .get(IMAGE_PULL_SECRETS) + .map(_.split(",")) + .getOrElse(Array.empty[String]) + .map(_.trim) + .map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + } + + def nodeSelector(): Map[String, String] = + KubernetesUtils.parsePrefixedKeyValuePairs(sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) + + def get[T](config: ConfigEntry[T]): T = sparkConf.get(config) + + def get(conf: String): String = sparkConf.get(conf) + + def get(conf: String, defaultValue: String): String = sparkConf.get(conf, defaultValue) + + def getOption(key: String): Option[String] = sparkConf.getOption(key) +} + +private[spark] object KubernetesConf { + def createDriverConf( + sparkConf: SparkConf, + appName: String, + appResourceNamePrefix: String, + appId: String, + mainAppResource: Option[MainAppResource], + mainClass: String, + appArgs: Array[String]): KubernetesConf[KubernetesDriverSpecificConf] = { + val sparkConfWithMainAppJar = sparkConf.clone() + mainAppResource.foreach { + case JavaMainAppResource(res) => + val previousJars = sparkConf + .getOption("spark.jars") + .map(_.split(",")) + .getOrElse(Array.empty) + if (!previousJars.contains(res)) { + sparkConfWithMainAppJar.setJars(previousJars ++ Seq(res)) + } + } + + val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_LABEL_PREFIX) + require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + + s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + + "operations.") + require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + + s"$SPARK_ROLE_LABEL is not 
allowed as it is reserved for Spark bookkeeping " + + "operations.") + val driverLabels = driverCustomLabels ++ Map( + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) + val driverAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) + val driverSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_SECRETS_PREFIX) + val driverEnvs = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_DRIVER_ENV_PREFIX) + + KubernetesConf( + sparkConfWithMainAppJar, + KubernetesDriverSpecificConf(mainAppResource, mainClass, appName, appArgs), + appResourceNamePrefix, + appId, + driverLabels, + driverAnnotations, + driverSecretNamesToMountPaths, + driverEnvs) + } + + def createExecutorConf( + sparkConf: SparkConf, + executorId: String, + appId: String, + driverPod: Pod): KubernetesConf[KubernetesExecutorSpecificConf] = { + val executorCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_LABEL_PREFIX) + require( + !executorCustomLabels.contains(SPARK_APP_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") + require( + !executorCustomLabels.contains(SPARK_EXECUTOR_ID_LABEL), + s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + + " Spark.") + require( + !executorCustomLabels.contains(SPARK_ROLE_LABEL), + s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") + val executorLabels = Map( + SPARK_EXECUTOR_ID_LABEL -> executorId, + SPARK_APP_ID_LABEL -> appId, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ + executorCustomLabels + val executorAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) + val executorSecrets = KubernetesUtils.parsePrefixedKeyValuePairs( + sparkConf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) + val executorEnv = sparkConf.getExecutorEnv.toMap + + KubernetesConf( + sparkConf.clone(), + KubernetesExecutorSpecificConf(executorId, driverPod), + sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX), + appId, + executorLabels, + executorAnnotations, + executorSecrets, + executorEnv) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala similarity index 57% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala index cf41b22e241af..0c5ae022f4070 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesUtilsTest.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesDriverSpec.scala @@ -16,21 +16,16 @@ */ package org.apache.spark.deploy.k8s -import io.fabric8.kubernetes.api.model.LocalObjectReference +import io.fabric8.kubernetes.api.model.HasMetadata -import org.apache.spark.SparkFunSuite - -class KubernetesUtilsTest extends SparkFunSuite { - - test("testParseImagePullSecrets") { - val noSecrets = KubernetesUtils.parseImagePullSecrets(None) - assert(noSecrets === Nil) - - val oneSecret = KubernetesUtils.parseImagePullSecrets(Some("imagePullSecret")) - assert(oneSecret === new LocalObjectReference("imagePullSecret") :: 
Nil) - - val commaSeparatedSecrets = KubernetesUtils.parseImagePullSecrets(Some("s1, s2 , s3,s4")) - assert(commaSeparatedSecrets.map(_.getName) === "s1" :: "s2" :: "s3" :: "s4" :: Nil) - } +private[spark] case class KubernetesDriverSpec( + pod: SparkPod, + driverKubernetesResources: Seq[HasMetadata], + systemProperties: Map[String, String]) +private[spark] object KubernetesDriverSpec { + def initialSpec(initialProps: Map[String, String]): KubernetesDriverSpec = KubernetesDriverSpec( + SparkPod.initialPod(), + Seq.empty, + initialProps) } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala index 5b2bb819cdb14..ee629068ad90d 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/KubernetesUtils.scala @@ -37,17 +37,6 @@ private[spark] object KubernetesUtils { sparkConf.getAllWithPrefix(prefix).toMap } - /** - * Parses comma-separated list of imagePullSecrets into K8s-understandable format - */ - def parseImagePullSecrets(imagePullSecrets: Option[String]): List[LocalObjectReference] = { - imagePullSecrets match { - case Some(secretsCommaSeparated) => - secretsCommaSeparated.split(',').map(_.trim).map(new LocalObjectReference(_)).toList - case None => Nil - } - } - def requireNandDefined(opt1: Option[_], opt2: Option[_], errMessage: String): Unit = { opt1.foreach { _ => require(opt2.isEmpty, errMessage) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala deleted file mode 100644 index c35e7db51d407..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ /dev/null @@ -1,72 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} - -/** - * Bootstraps a driver or executor container or an init-container with needed secrets mounted. - */ -private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { - - /** - * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. - * - * @param pod the pod into which the secret volumes are being added. - * @return the updated pod with the secret volumes added. 
- */ - def addSecretVolumes(pod: Pod): Pod = { - var podBuilder = new PodBuilder(pod) - secretNamesToMountPaths.keys.foreach { name => - podBuilder = podBuilder - .editOrNewSpec() - .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() - .endSpec() - } - - podBuilder.build() - } - - /** - * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the - * given container. - * - * @param container the container into which the secret volumes are being mounted. - * @return the updated container with the secrets mounted. - */ - def mountSecrets(container: Container): Container = { - var containerBuilder = new ContainerBuilder(container) - secretNamesToMountPaths.foreach { case (name, path) => - containerBuilder = containerBuilder - .addNewVolumeMount() - .withName(secretVolumeName(name)) - .withMountPath(path) - .endVolumeMount() - } - - containerBuilder.build() - } - - private def secretVolumeName(secretName: String): String = { - secretName + "-volume" - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala similarity index 64% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala index 17614e040e587..345dd117fd35f 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPod.scala @@ -14,17 +14,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBuilder} -/** - * Represents a step in configuring the Spark driver pod. - */ -private[spark] trait DriverConfigurationStep { +private[spark] case class SparkPod(pod: Pod, container: Container) - /** - * Apply some transformation to the previous state of the driver to add a new feature to it. - */ - def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec +private[spark] object SparkPod { + def initialPod(): SparkPod = { + SparkPod( + new PodBuilder() + .withNewMetadata() + .endMetadata() + .withNewSpec() + .endSpec() + .build(), + new ContainerBuilder().build()) + } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala new file mode 100644 index 0000000000000..07bdccbe0479d --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStep.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, KubernetesUtils, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config._ +import org.apache.spark.launcher.SparkLauncher + +private[spark] class BasicDriverFeatureStep( + conf: KubernetesConf[KubernetesDriverSpecificConf]) + extends KubernetesFeatureConfigStep { + + private val driverPodName = conf + .get(KUBERNETES_DRIVER_POD_NAME) + .getOrElse(s"${conf.appResourceNamePrefix}-driver") + + private val driverContainerImage = conf + .get(DRIVER_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the driver container image")) + + // CPU settings + private val driverCpuCores = conf.get("spark.driver.cores", "1") + private val driverLimitCores = conf.get(KUBERNETES_DRIVER_LIMIT_CORES) + + // Memory settings + private val driverMemoryMiB = conf.get(DRIVER_MEMORY) + private val memoryOverheadMiB = conf + .get(DRIVER_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) + private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB + + override def configurePod(pod: SparkPod): SparkPod = { + val driverCustomEnvs = conf.roleEnvs + .toSeq + .map { env => + new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + } + + val driverCpuQuantity = new QuantityBuilder(false) + .withAmount(driverCpuCores) + .build() + val driverMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${driverMemoryWithOverheadMiB}Mi") + .build() + val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => + ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) + } + + val driverContainer = new ContainerBuilder(pod.container) + .withName(DRIVER_CONTAINER_NAME) + .withImage(driverContainerImage) + .withImagePullPolicy(conf.imagePullPolicy()) + .addAllToEnv(driverCustomEnvs.asJava) + .addNewEnv() + .withName(ENV_DRIVER_BIND_ADDRESS) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .endEnv() + .withNewResources() + .addToRequests("cpu", driverCpuQuantity) + .addToLimits(maybeCpuLimitQuantity.toMap.asJava) + .addToRequests("memory", driverMemoryQuantity) + .addToLimits("memory", driverMemoryQuantity) + .endResources() + .addToArgs("driver") + .addToArgs("--properties-file", SPARK_CONF_PATH) + .addToArgs("--class", conf.roleSpecificConf.mainClass) + // The user application jar is merged into the spark.jars list and managed through that + // property, so there is no need to reference it explicitly here. 
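+ // SparkLauncher.NO_RESOURCE is passed as a placeholder for the primary resource argument.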
+ .addToArgs(SparkLauncher.NO_RESOURCE) + .addToArgs(conf.roleSpecificConf.appArgs: _*) + .build() + + val driverPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(driverPodName) + .addToLabels(conf.roleLabels.asJava) + .addToAnnotations(conf.roleAnnotations.asJava) + .endMetadata() + .withNewSpec() + .withRestartPolicy("Never") + .withNodeSelector(conf.nodeSelector().asJava) + .addToImagePullSecrets(conf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(driverPod, driverContainer) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val additionalProps = mutable.Map( + KUBERNETES_DRIVER_POD_NAME.key -> driverPodName, + "spark.app.id" -> conf.appId, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> conf.appResourceNamePrefix, + KUBERNETES_DRIVER_SUBMIT_CHECK.key -> "true") + + val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkJars()) + val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath( + conf.sparkFiles()) + if (resolvedSparkJars.nonEmpty) { + additionalProps.put("spark.jars", resolvedSparkJars.mkString(",")) + } + if (resolvedSparkFiles.nonEmpty) { + additionalProps.put("spark.files", resolvedSparkFiles.mkString(",")) + } + additionalProps.toMap + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala new file mode 100644 index 0000000000000..d22097587aafe --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStep.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, ContainerPortBuilder, EnvVar, EnvVarBuilder, EnvVarSourceBuilder, HasMetadata, PodBuilder, QuantityBuilder} + +import org.apache.spark.SparkException +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.util.Utils + +private[spark] class BasicExecutorFeatureStep( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]) + extends KubernetesFeatureConfigStep { + + // Consider moving some of these fields to KubernetesConf or KubernetesExecutorSpecificConf + private val executorExtraClasspath = kubernetesConf.get(EXECUTOR_CLASS_PATH) + private val executorContainerImage = kubernetesConf + .get(EXECUTOR_CONTAINER_IMAGE) + .getOrElse(throw new SparkException("Must specify the executor container image")) + private val blockManagerPort = kubernetesConf + .sparkConf + .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) + + private val executorPodNamePrefix = kubernetesConf.appResourceNamePrefix + + private val driverUrl = RpcEndpointAddress( + kubernetesConf.get("spark.driver.host"), + kubernetesConf.sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT), + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString + private val executorMemoryMiB = kubernetesConf.get(EXECUTOR_MEMORY) + private val executorMemoryString = kubernetesConf.get( + EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) + + private val memoryOverheadMiB = kubernetesConf + .get(EXECUTOR_MEMORY_OVERHEAD) + .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, + MEMORY_OVERHEAD_MIN_MIB)) + private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB + + private val executorCores = kubernetesConf.sparkConf.getInt("spark.executor.cores", 1) + private val executorCoresRequest = + if (kubernetesConf.sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { + kubernetesConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get + } else { + executorCores.toString + } + private val executorLimitCores = kubernetesConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) + + override def configurePod(pod: SparkPod): SparkPod = { + val name = s"$executorPodNamePrefix-exec-${kubernetesConf.roleSpecificConf.executorId}" + + // hostname must be no longer than 63 characters, so take the last 63 characters of the pod + // name as the hostname. 
This preserves uniqueness since the end of name contains + // executorId + val hostname = name.substring(Math.max(0, name.length - 63)) + val executorMemoryQuantity = new QuantityBuilder(false) + .withAmount(s"${executorMemoryWithOverhead}Mi") + .build() + val executorCpuQuantity = new QuantityBuilder(false) + .withAmount(executorCoresRequest) + .build() + val executorExtraClasspathEnv = executorExtraClasspath.map { cp => + new EnvVarBuilder() + .withName(ENV_CLASSPATH) + .withValue(cp) + .build() + } + val executorExtraJavaOptionsEnv = kubernetesConf + .get(EXECUTOR_JAVA_OPTIONS) + .map { opts => + val delimitedOpts = Utils.splitCommandString(opts) + delimitedOpts.zipWithIndex.map { + case (opt, index) => + new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() + } + }.getOrElse(Seq.empty[EnvVar]) + val executorEnv = (Seq( + (ENV_DRIVER_URL, driverUrl), + (ENV_EXECUTOR_CORES, executorCores.toString), + (ENV_EXECUTOR_MEMORY, executorMemoryString), + (ENV_APPLICATION_ID, kubernetesConf.appId), + // This is to set the SPARK_CONF_DIR to be /opt/spark/conf + (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), + (ENV_EXECUTOR_ID, kubernetesConf.roleSpecificConf.executorId)) ++ + kubernetesConf.roleEnvs) + .map(env => new EnvVarBuilder() + .withName(env._1) + .withValue(env._2) + .build() + ) ++ Seq( + new EnvVarBuilder() + .withName(ENV_EXECUTOR_POD_IP) + .withValueFrom(new EnvVarSourceBuilder() + .withNewFieldRef("v1", "status.podIP") + .build()) + .build() + ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq + val requiredPorts = Seq( + (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) + .map { case (name, port) => + new ContainerPortBuilder() + .withName(name) + .withContainerPort(port) + .build() + } + + val executorContainer = new ContainerBuilder(pod.container) + .withName("executor") + .withImage(executorContainerImage) + .withImagePullPolicy(kubernetesConf.imagePullPolicy()) + .withNewResources() + .addToRequests("memory", executorMemoryQuantity) + .addToLimits("memory", executorMemoryQuantity) + .addToRequests("cpu", executorCpuQuantity) + .endResources() + .addAllToEnv(executorEnv.asJava) + .withPorts(requiredPorts.asJava) + .addToArgs("executor") + .build() + val containerWithLimitCores = executorLimitCores.map { limitCores => + val executorCpuLimitQuantity = new QuantityBuilder(false) + .withAmount(limitCores) + .build() + new ContainerBuilder(executorContainer) + .editResources() + .addToLimits("cpu", executorCpuLimitQuantity) + .endResources() + .build() + }.getOrElse(executorContainer) + val driverPod = kubernetesConf.roleSpecificConf.driverPod + val executorPod = new PodBuilder(pod.pod) + .editOrNewMetadata() + .withName(name) + .withLabels(kubernetesConf.roleLabels.asJava) + .withAnnotations(kubernetesConf.roleAnnotations.asJava) + .withOwnerReferences() + .addNewOwnerReference() + .withController(true) + .withApiVersion(driverPod.getApiVersion) + .withKind(driverPod.getKind) + .withName(driverPod.getMetadata.getName) + .withUid(driverPod.getMetadata.getUid) + .endOwnerReference() + .endMetadata() + .editOrNewSpec() + .withHostname(hostname) + .withRestartPolicy("Never") + .withNodeSelector(kubernetesConf.nodeSelector().asJava) + .addToImagePullSecrets(kubernetesConf.imagePullSecrets(): _*) + .endSpec() + .build() + SparkPod(executorPod, containerWithLimitCores) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty +} diff --git 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala new file mode 100644 index 0000000000000..ff5ad6673b309 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStep.scala @@ -0,0 +1,216 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import java.io.File +import java.nio.charset.StandardCharsets + +import scala.collection.JavaConverters._ + +import com.google.common.io.{BaseEncoding, Files} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret, SecretBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +private[spark] class DriverKubernetesCredentialsFeatureStep(kubernetesConf: KubernetesConf[_]) + extends KubernetesFeatureConfigStep { + // TODO clean up this class, and credentials in general. See also SparkKubernetesClientFactory. + // We should use a struct to hold all creds-related fields. A lot of the code is very repetitive. 
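+ // Two credential sources are handled below: files already mounted into the driver pod
+ // (the KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX settings) and files on the submission
+ // client (the KUBERNETES_AUTH_DRIVER_CONF_PREFIX settings), which are base64-encoded here
+ // and shipped to the driver as a Kubernetes secret.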
+ + private val maybeMountedOAuthTokenFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") + private val maybeMountedClientKeyFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") + private val maybeMountedClientCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") + private val maybeMountedCaCertFile = kubernetesConf.getOption( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") + private val driverServiceAccount = kubernetesConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) + + private val oauthTokenBase64 = kubernetesConf + .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") + .map { token => + BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) + } + + private val caCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", + "Driver CA cert file") + private val clientKeyDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", + "Driver client key file") + private val clientCertDataBase64 = safeFileConfToBase64( + s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", + "Driver client cert file") + + // TODO decide whether or not to apply this step entirely in the caller, i.e. the builder. + private val shouldMountSecret = oauthTokenBase64.isDefined || + caCertDataBase64.isDefined || + clientKeyDataBase64.isDefined || + clientCertDataBase64.isDefined + + private val driverCredentialsSecretName = + s"${kubernetesConf.appResourceNamePrefix}-kubernetes-credentials" + + override def configurePod(pod: SparkPod): SparkPod = { + if (!shouldMountSecret) { + pod.copy( + pod = driverServiceAccount.map { account => + new PodBuilder(pod.pod) + .editOrNewSpec() + .withServiceAccount(account) + .withServiceAccountName(account) + .endSpec() + .build() + }.getOrElse(pod.pod)) + } else { + val driverPodWithMountedKubernetesCredentials = + new PodBuilder(pod.pod) + .editOrNewSpec() + .addNewVolume() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withNewSecret().withSecretName(driverCredentialsSecretName).endSecret() + .endVolume() + .endSpec() + .build() + + val driverContainerWithMountedSecretVolume = + new ContainerBuilder(pod.container) + .addNewVolumeMount() + .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) + .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) + .endVolumeMount() + .build() + SparkPod(driverPodWithMountedKubernetesCredentials, driverContainerWithMountedSecretVolume) + } + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val resolvedMountedOAuthTokenFile = resolveSecretLocation( + maybeMountedOAuthTokenFile, + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) + val resolvedMountedClientKeyFile = resolveSecretLocation( + maybeMountedClientKeyFile, + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_PATH) + val resolvedMountedClientCertFile = resolveSecretLocation( + maybeMountedClientCertFile, + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_PATH) + val resolvedMountedCaCertFile = resolveSecretLocation( + maybeMountedCaCertFile, + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_PATH) + + val redactedTokens = kubernetesConf.sparkConf.getAll + .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)) + .toMap + .mapValues( _ => "") + redactedTokens ++ + resolvedMountedCaCertFile.map { file => + 
Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientKeyFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedClientCertFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) ++ + resolvedMountedOAuthTokenFile.map { file => + Map( + s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> + file) + }.getOrElse(Map.empty) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + if (shouldMountSecret) { + Seq(createCredentialsSecret()) + } else { + Seq.empty + } + } + + private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { + kubernetesConf.getOption(conf) + .map(new File(_)) + .map { file => + require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", + fileType, file.getAbsolutePath)) + BaseEncoding.base64().encode(Files.toByteArray(file)) + } + } + + /** + * Resolve a Kubernetes secret data entry from an optional client credential used by the + * driver to talk to the Kubernetes API server. + * + * @param userSpecifiedCredential the optional user-specified client credential. + * @param secretName name of the Kubernetes secret storing the client credential. + * @return a secret data entry in the form of a map from the secret name to the secret data, + * which may be empty if the user-specified credential is empty. + */ + private def resolveSecretData( + userSpecifiedCredential: Option[String], + secretName: String): Map[String, String] = { + userSpecifiedCredential.map { valueBase64 => + Map(secretName -> valueBase64) + }.getOrElse(Map.empty[String, String]) + } + + private def resolveSecretLocation( + mountedUserSpecified: Option[String], + valueMountedFromSubmitter: Option[String], + mountedCanonicalLocation: String): Option[String] = { + mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => + mountedCanonicalLocation + }) + } + + private def createCredentialsSecret(): Secret = { + val allSecretData = + resolveSecretData( + clientKeyDataBase64, + DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ + resolveSecretData( + clientCertDataBase64, + DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ + resolveSecretData( + caCertDataBase64, + DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ + resolveSecretData( + oauthTokenBase64, + DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) + + new SecretBuilder() + .withNewMetadata() + .withName(driverCredentialsSecretName) + .endMetadata() + .withData(allSecretData.asJava) + .build() + } + +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala new file mode 100644 index 0000000000000..f2d7bbd08f305 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStep.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.{HasMetadata, ServiceBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.internal.Logging +import org.apache.spark.util.{Clock, SystemClock} + +private[spark] class DriverServiceFeatureStep( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], + clock: Clock = new SystemClock) + extends KubernetesFeatureConfigStep with Logging { + import DriverServiceFeatureStep._ + + require(kubernetesConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, + s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + + "address is managed and set to the driver pod's IP address.") + require(kubernetesConf.getOption(DRIVER_HOST_KEY).isEmpty, + s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + + "managed via a Kubernetes service.") + + private val preferredServiceName = s"${kubernetesConf.appResourceNamePrefix}$DRIVER_SVC_POSTFIX" + private val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { + preferredServiceName + } else { + val randomServiceId = clock.getTimeMillis() + val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" + logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + + s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). 
Falling back to use " + + s"$shorterServiceName as the driver service's name.") + shorterServiceName + } + + private val driverPort = kubernetesConf.sparkConf.getInt( + "spark.driver.port", DEFAULT_DRIVER_PORT) + private val driverBlockManagerPort = kubernetesConf.sparkConf.getInt( + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) + + override def configurePod(pod: SparkPod): SparkPod = pod + + override def getAdditionalPodSystemProperties(): Map[String, String] = { + val driverHostname = s"$resolvedServiceName.${kubernetesConf.namespace()}.svc" + Map(DRIVER_HOST_KEY -> driverHostname, + "spark.driver.port" -> driverPort.toString, + org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key -> + driverBlockManagerPort.toString) + } + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = { + val driverService = new ServiceBuilder() + .withNewMetadata() + .withName(resolvedServiceName) + .endMetadata() + .withNewSpec() + .withClusterIP("None") + .withSelector(kubernetesConf.roleLabels.asJava) + .addNewPort() + .withName(DRIVER_PORT_NAME) + .withPort(driverPort) + .withNewTargetPort(driverPort) + .endPort() + .addNewPort() + .withName(BLOCK_MANAGER_PORT_NAME) + .withPort(driverBlockManagerPort) + .withNewTargetPort(driverBlockManagerPort) + .endPort() + .endSpec() + .build() + Seq(driverService) + } +} + +private[spark] object DriverServiceFeatureStep { + val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key + val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key + val DRIVER_SVC_POSTFIX = "-driver-svc" + val MAX_SERVICE_NAME_LENGTH = 63 +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala new file mode 100644 index 0000000000000..4c1be3bb13293 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/KubernetesFeatureConfigStep.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.HasMetadata + +import org.apache.spark.deploy.k8s.SparkPod + +/** + * A collection of functions that together represent a "feature" in pods that are launched for + * Spark drivers and executors. + */ +private[spark] trait KubernetesFeatureConfigStep { + + /** + * Apply modifications on the given pod in accordance to this feature. This can include attaching + * volumes, adding environment variables, and adding labels/annotations. + *

+   * Note that we should return a SparkPod that keeps all of the properties of the passed SparkPod
+   * object. So this is correct:
+   * <pre>

+   * {@code val configuredPod = new PodBuilder(pod.pod)
+   *     .editSpec()
+   *     ...
+   *     .build()
+   *   val configuredContainer = new ContainerBuilder(pod.container)
+   *     ...
+   *     .build()
+   *   SparkPod(configuredPod, configuredContainer)
+   *  }
+   * </pre>
+   * This is incorrect:
+   * <pre>
+   * {@code val configuredPod = new PodBuilder() // Loses the original state
+   *     .editSpec()
+   *     ...
+   *     .build()
+   *   val configuredContainer = new ContainerBuilder() // Loses the original state
+   *     ...
+   *     .build()
+   *   SparkPod(configuredPod, configuredContainer)
+   *  }
+   * </pre>
+ */ + def configurePod(pod: SparkPod): SparkPod + + /** + * Return any system properties that should be set on the JVM in accordance to this feature. + */ + def getAdditionalPodSystemProperties(): Map[String, String] + + /** + * Return any additional Kubernetes resources that should be added to support this feature. Only + * applicable when creating the driver in cluster mode. + */ + def getAdditionalKubernetesResources(): Seq[HasMetadata] +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala new file mode 100644 index 0000000000000..97fa9499b2edb --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStep.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, VolumeBuilder, VolumeMountBuilder} + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesRoleSpecificConf, SparkPod} + +private[spark] class MountSecretsFeatureStep( + kubernetesConf: KubernetesConf[_ <: KubernetesRoleSpecificConf]) + extends KubernetesFeatureConfigStep { + override def configurePod(pod: SparkPod): SparkPod = { + val addedVolumes = kubernetesConf + .roleSecretNamesToMountPaths + .keys + .map(secretName => + new VolumeBuilder() + .withName(secretVolumeName(secretName)) + .withNewSecret() + .withSecretName(secretName) + .endSecret() + .build()) + val podWithVolumes = new PodBuilder(pod.pod) + .editOrNewSpec() + .addToVolumes(addedVolumes.toSeq: _*) + .endSpec() + .build() + val addedVolumeMounts = kubernetesConf + .roleSecretNamesToMountPaths + .map { + case (secretName, mountPath) => + new VolumeMountBuilder() + .withName(secretVolumeName(secretName)) + .withMountPath(mountPath) + .build() + } + val containerWithMounts = new ContainerBuilder(pod.container) + .addToVolumeMounts(addedVolumeMounts.toSeq: _*) + .build() + SparkPod(podWithVolumes, containerWithMounts) + } + + override def getAdditionalPodSystemProperties(): Map[String, String] = Map.empty + + override def getAdditionalKubernetesResources(): Seq[HasMetadata] = Seq.empty + + private def secretVolumeName(secretName: String): String = s"$secretName-volume" +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala deleted file mode 100644 index b4d3f04a1bc32..0000000000000 --- 
a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps._ -import org.apache.spark.launcher.SparkLauncher -import org.apache.spark.util.SystemClock -import org.apache.spark.util.Utils - -/** - * Figures out and returns the complete ordered list of needed DriverConfigurationSteps to - * configure the Spark driver pod. The returned steps will be applied one by one in the given - * order to produce a final KubernetesDriverSpec that is used in KubernetesClientApplication - * to construct and create the driver pod. - */ -private[spark] class DriverConfigOrchestrator( - kubernetesAppId: String, - kubernetesResourceNamePrefix: String, - mainAppResource: Option[MainAppResource], - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) { - - // The resource name prefix is derived from the Spark application name, making it easy to connect - // the names of the Kubernetes resources from e.g. kubectl or the Kubernetes dashboard to the - // application the user submitted. 
- - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - - def getAllConfigurationSteps: Seq[DriverConfigurationStep] = { - val driverCustomLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_LABEL_PREFIX) - require(!driverCustomLabels.contains(SPARK_APP_ID_LABEL), "Label with key " + - s"$SPARK_APP_ID_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - require(!driverCustomLabels.contains(SPARK_ROLE_LABEL), "Label with key " + - s"$SPARK_ROLE_LABEL is not allowed as it is reserved for Spark bookkeeping " + - "operations.") - - val secretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_DRIVER_SECRETS_PREFIX) - - val allDriverLabels = driverCustomLabels ++ Map( - SPARK_APP_ID_LABEL -> kubernetesAppId, - SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) - - val initialSubmissionStep = new BasicDriverConfigurationStep( - kubernetesAppId, - kubernetesResourceNamePrefix, - allDriverLabels, - imagePullPolicy, - appName, - mainClass, - appArgs, - sparkConf) - - val serviceBootstrapStep = new DriverServiceBootstrapStep( - kubernetesResourceNamePrefix, - allDriverLabels, - sparkConf, - new SystemClock) - - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - sparkConf, kubernetesResourceNamePrefix) - - val additionalMainAppJar = if (mainAppResource.nonEmpty) { - val mayBeResource = mainAppResource.get match { - case JavaMainAppResource(resource) if resource != SparkLauncher.NO_RESOURCE => - Some(resource) - case _ => None - } - mayBeResource - } else { - None - } - - val sparkJars = sparkConf.getOption("spark.jars") - .map(_.split(",")) - .getOrElse(Array.empty[String]) ++ - additionalMainAppJar.toSeq - val sparkFiles = sparkConf.getOption("spark.files") - .map(_.split(",")) - .getOrElse(Array.empty[String]) - - // TODO(SPARK-23153): remove once submission client local dependencies are supported. 
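// Illustrative usage sketch, not part of this change set: the secret mounting that
// MountSecretsFeatureStep now handles is driven purely by configuration. Assuming the
// existing spark.kubernetes.driver.secrets.* and spark.kubernetes.executor.secrets.*
// prefixes, entries of the form <prefix><secretName> = <mountPath> are what surface as
// roleSecretNamesToMountPaths on KubernetesConf and thereby trigger the mount step in
// the new builders. The secret name and mount path below are invented for the example.
import org.apache.spark.SparkConf

object SecretMountConfSketch {
  val conf = new SparkConf(false)
    .set("spark.kubernetes.driver.secrets.db-credentials", "/mnt/secrets/db-credentials")
    .set("spark.kubernetes.executor.secrets.db-credentials", "/mnt/secrets/db-credentials")
}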
- if (existSubmissionLocalFiles(sparkJars) || existSubmissionLocalFiles(sparkFiles)) { - throw new SparkException("The Kubernetes mode does not yet support referencing application " + - "dependencies in the local file system.") - } - - val dependencyResolutionStep = if (sparkJars.nonEmpty || sparkFiles.nonEmpty) { - Seq(new DependencyResolutionStep( - sparkJars, - sparkFiles)) - } else { - Nil - } - - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - - Seq( - initialSubmissionStep, - serviceBootstrapStep, - kubernetesCredentialsStep) ++ - dependencyResolutionStep ++ - mountSecretsStep - } - - private def existSubmissionLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme == "file" - } - } - - private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { - files.exists { uri => - Utils.resolveURI(uri).getScheme != "local" - } - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala index e16d1add600b2..a97f5650fb869 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesClientApplication.scala @@ -27,12 +27,10 @@ import scala.util.control.NonFatal import org.apache.spark.SparkConf import org.apache.spark.deploy.SparkApplication +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.SparkKubernetesClientFactory -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.ConfigBuilder import org.apache.spark.util.Utils /** @@ -43,9 +41,9 @@ import org.apache.spark.util.Utils * @param driverArgs arguments to the driver */ private[spark] case class ClientArguments( - mainAppResource: Option[MainAppResource], - mainClass: String, - driverArgs: Array[String]) + mainAppResource: Option[MainAppResource], + mainClass: String, + driverArgs: Array[String]) private[spark] object ClientArguments { @@ -80,8 +78,9 @@ private[spark] object ClientArguments { * watcher that monitors and logs the application status. Waits for the application to terminate if * spark.kubernetes.submission.waitAppCompletion is true. * - * @param submissionSteps steps that collectively configure the driver - * @param sparkConf the submission client Spark configuration + * @param builder Responsible for building the base driver pod based on a composition of + * implemented features. 
+ * @param kubernetesConf application configuration * @param kubernetesClient the client to talk to the Kubernetes API server * @param waitForAppCompletion a flag indicating whether the client should wait for the application * to complete @@ -89,31 +88,21 @@ private[spark] object ClientArguments { * @param watcher a watcher that monitors and logs the application status */ private[spark] class Client( - submissionSteps: Seq[DriverConfigurationStep], - sparkConf: SparkConf, + builder: KubernetesDriverBuilder, + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf], kubernetesClient: KubernetesClient, waitForAppCompletion: Boolean, appName: String, watcher: LoggingPodStatusWatcher, kubernetesResourceNamePrefix: String) extends Logging { - /** - * Run command that initializes a DriverSpec that will be updated after each - * DriverConfigurationStep in the sequence that is passed in. The final KubernetesDriverSpec - * will be used to build the Driver Container, Driver Pod, and Kubernetes Resources - */ def run(): Unit = { - var currentDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf) - // submissionSteps contain steps necessary to take, to resolve varying - // client arguments that are passed in, created by orchestrator - for (nextStep <- submissionSteps) { - currentDriverSpec = nextStep.configureDriver(currentDriverSpec) - } + val resolvedDriverSpec = builder.buildFromFeatures(kubernetesConf) val configMapName = s"$kubernetesResourceNamePrefix-driver-conf-map" - val configMap = buildConfigMap(configMapName, currentDriverSpec.driverSparkConf) + val configMap = buildConfigMap(configMapName, resolvedDriverSpec.systemProperties) // The include of the ENV_VAR for "SPARK_CONF_DIR" is to allow for the // Spark command builder to pickup on the Java Options present in the ConfigMap - val resolvedDriverContainer = new ContainerBuilder(currentDriverSpec.driverContainer) + val resolvedDriverContainer = new ContainerBuilder(resolvedDriverSpec.pod.container) .addNewEnv() .withName(ENV_SPARK_CONF_DIR) .withValue(SPARK_CONF_DIR_INTERNAL) @@ -123,7 +112,7 @@ private[spark] class Client( .withMountPath(SPARK_CONF_DIR_INTERNAL) .endVolumeMount() .build() - val resolvedDriverPod = new PodBuilder(currentDriverSpec.driverPod) + val resolvedDriverPod = new PodBuilder(resolvedDriverSpec.pod.pod) .editSpec() .addToContainers(resolvedDriverContainer) .addNewVolume() @@ -141,12 +130,10 @@ private[spark] class Client( .watch(watcher)) { _ => val createdDriverPod = kubernetesClient.pods().create(resolvedDriverPod) try { - if (currentDriverSpec.otherKubernetesResources.nonEmpty) { - val otherKubernetesResources = - currentDriverSpec.otherKubernetesResources ++ Seq(configMap) - addDriverOwnerReference(createdDriverPod, otherKubernetesResources) - kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() - } + val otherKubernetesResources = + resolvedDriverSpec.driverKubernetesResources ++ Seq(configMap) + addDriverOwnerReference(createdDriverPod, otherKubernetesResources) + kubernetesClient.resourceList(otherKubernetesResources: _*).createOrReplace() } catch { case NonFatal(e) => kubernetesClient.pods().delete(createdDriverPod) @@ -180,20 +167,17 @@ private[spark] class Client( } // Build a Config Map that will house spark conf properties in a single file for spark-submit - private def buildConfigMap(configMapName: String, conf: SparkConf): ConfigMap = { + private def buildConfigMap(configMapName: String, conf: Map[String, String]): ConfigMap = { val properties = new Properties() - 
conf.getAll.foreach { case (k, v) => + conf.foreach { case (k, v) => properties.setProperty(k, v) } val propertiesWriter = new StringWriter() properties.store(propertiesWriter, s"Java properties built from Kubernetes config map with name: $configMapName") - - val namespace = conf.get(KUBERNETES_NAMESPACE) new ConfigMapBuilder() .withNewMetadata() .withName(configMapName) - .withNamespace(namespace) .endMetadata() .addToData(SPARK_CONF_FILE_NAME, propertiesWriter.toString) .build() @@ -211,7 +195,7 @@ private[spark] class KubernetesClientApplication extends SparkApplication { } private def run(clientArguments: ClientArguments, sparkConf: SparkConf): Unit = { - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) + val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") // For constructing the app ID, we can't use the Spark application name, as the app ID is going // to be added as a label to group resources belonging to the same application. Label values are // considerably restrictive, e.g. must be no longer than 63 characters in length. So we generate @@ -219,10 +203,19 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val kubernetesAppId = s"spark-${UUID.randomUUID().toString.replaceAll("-", "")}" val launchTime = System.currentTimeMillis() val waitForAppCompletion = sparkConf.get(WAIT_FOR_APP_COMPLETION) - val appName = sparkConf.getOption("spark.app.name").getOrElse("spark") val kubernetesResourceNamePrefix = { s"$appName-$launchTime".toLowerCase.replaceAll("\\.", "-") } + val kubernetesConf = KubernetesConf.createDriverConf( + sparkConf, + appName, + kubernetesResourceNamePrefix, + kubernetesAppId, + clientArguments.mainAppResource, + clientArguments.mainClass, + clientArguments.driverArgs) + val builder = new KubernetesDriverBuilder + val namespace = kubernetesConf.namespace() // The master URL has been checked for validity already in SparkSubmit. // We just need to get rid of the "k8s://" prefix here. val master = sparkConf.get("spark.master").substring("k8s://".length) @@ -230,15 +223,6 @@ private[spark] class KubernetesClientApplication extends SparkApplication { val watcher = new LoggingPodStatusWatcherImpl(kubernetesAppId, loggingInterval) - val orchestrator = new DriverConfigOrchestrator( - kubernetesAppId, - kubernetesResourceNamePrefix, - clientArguments.mainAppResource, - appName, - clientArguments.mainClass, - clientArguments.driverArgs, - sparkConf) - Utils.tryWithResource(SparkKubernetesClientFactory.createKubernetesClient( master, Some(namespace), @@ -247,8 +231,8 @@ private[spark] class KubernetesClientApplication extends SparkApplication { None, None)) { kubernetesClient => val client = new Client( - orchestrator.getAllConfigurationSteps, - sparkConf, + builder, + kubernetesConf, kubernetesClient, waitForAppCompletion, appName, diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala new file mode 100644 index 0000000000000..c7579ed8cb689 --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, KubernetesRoleSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesDriverBuilder( + provideBasicStep: (KubernetesConf[KubernetesDriverSpecificConf]) => BasicDriverFeatureStep = + new BasicDriverFeatureStep(_), + provideCredentialsStep: (KubernetesConf[KubernetesDriverSpecificConf]) + => DriverKubernetesCredentialsFeatureStep = + new DriverKubernetesCredentialsFeatureStep(_), + provideServiceStep: (KubernetesConf[KubernetesDriverSpecificConf]) => DriverServiceFeatureStep = + new DriverServiceFeatureStep(_), + provideSecretsStep: (KubernetesConf[_ <: KubernetesRoleSpecificConf] + => MountSecretsFeatureStep) = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf]): KubernetesDriverSpec = { + val baseFeatures = Seq( + provideBasicStep(kubernetesConf), + provideCredentialsStep(kubernetesConf), + provideServiceStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + + var spec = KubernetesDriverSpec.initialSpec(kubernetesConf.sparkConf.getAll.toMap) + for (feature <- allFeatures) { + val configuredPod = feature.configurePod(spec.pod) + val addedSystemProperties = feature.getAdditionalPodSystemProperties() + val addedResources = feature.getAdditionalKubernetesResources() + spec = KubernetesDriverSpec( + configuredPod, + spec.driverKubernetesResources ++ addedResources, + spec.systemProperties ++ addedSystemProperties) + } + spec + } +} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala deleted file mode 100644 index db13f09387ef9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverSpec.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, HasMetadata, Pod, PodBuilder} - -import org.apache.spark.SparkConf - -/** - * Represents the components and characteristics of a Spark driver. The driver can be considered - * as being comprised of the driver pod itself, any other Kubernetes resources that the driver - * pod depends on, and the SparkConf that should be supplied to the Spark application. The driver - * container should be operated on via the specific field of this case class as opposed to trying - * to edit the container directly on the pod. The driver container should be attached at the - * end of executing all submission steps. - */ -private[spark] case class KubernetesDriverSpec( - driverPod: Pod, - driverContainer: Container, - otherKubernetesResources: Seq[HasMetadata], - driverSparkConf: SparkConf) - -private[spark] object KubernetesDriverSpec { - def initialSpec(initialSparkConf: SparkConf): KubernetesDriverSpec = { - KubernetesDriverSpec( - // Set new metadata and a new spec so that submission steps can use - // PodBuilder#editMetadata() and/or PodBuilder#editSpec() safely. - new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build(), - new ContainerBuilder().build(), - Seq.empty[HasMetadata], - initialSparkConf.clone()) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala deleted file mode 100644 index fcb1db8008053..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.config.{DRIVER_CLASS_PATH, DRIVER_MEMORY, DRIVER_MEMORY_OVERHEAD} -import org.apache.spark.launcher.SparkLauncher - -/** - * Performs basic configuration for the driver pod. - */ -private[spark] class BasicDriverConfigurationStep( - kubernetesAppId: String, - resourceNamePrefix: String, - driverLabels: Map[String, String], - imagePullPolicy: String, - appName: String, - mainClass: String, - appArgs: Array[String], - sparkConf: SparkConf) extends DriverConfigurationStep { - - private val driverPodName = sparkConf - .get(KUBERNETES_DRIVER_POD_NAME) - .getOrElse(s"$resourceNamePrefix-driver") - - private val driverExtraClasspath = sparkConf.get(DRIVER_CLASS_PATH) - - private val driverContainerImage = sparkConf - .get(DRIVER_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the driver container image")) - - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - - // CPU settings - private val driverCpuCores = sparkConf.getOption("spark.driver.cores").getOrElse("1") - private val driverLimitCores = sparkConf.get(KUBERNETES_DRIVER_LIMIT_CORES) - - // Memory settings - private val driverMemoryMiB = sparkConf.get(DRIVER_MEMORY) - private val memoryOverheadMiB = sparkConf - .get(DRIVER_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * driverMemoryMiB).toInt, MEMORY_OVERHEAD_MIN_MIB)) - private val driverMemoryWithOverheadMiB = driverMemoryMiB + memoryOverheadMiB - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverExtraClasspathEnv = driverExtraClasspath.map { classPath => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(classPath) - .build() - } - - val driverCustomAnnotations = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_DRIVER_ANNOTATION_PREFIX) - require(!driverCustomAnnotations.contains(SPARK_APP_NAME_ANNOTATION), - s"Annotation with key $SPARK_APP_NAME_ANNOTATION is not allowed as it is reserved for" + - " Spark bookkeeping operations.") - - val driverCustomEnvs = sparkConf.getAllWithPrefix(KUBERNETES_DRIVER_ENV_KEY).toSeq - .map { env => - new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - } - - val driverAnnotations = driverCustomAnnotations ++ Map(SPARK_APP_NAME_ANNOTATION -> appName) - - val nodeSelector = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, KUBERNETES_NODE_SELECTOR_PREFIX) - - val driverCpuQuantity = new QuantityBuilder(false) - .withAmount(driverCpuCores) - .build() - val driverMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${driverMemoryWithOverheadMiB}Mi") - .build() - val maybeCpuLimitQuantity = driverLimitCores.map { limitCores => - ("cpu", new QuantityBuilder(false).withAmount(limitCores).build()) - } - - val driverContainerWithoutArgs = new ContainerBuilder(driverSpec.driverContainer) - .withName(DRIVER_CONTAINER_NAME) - .withImage(driverContainerImage) - .withImagePullPolicy(imagePullPolicy) - .addAllToEnv(driverCustomEnvs.asJava) - .addToEnv(driverExtraClasspathEnv.toSeq: _*) - .addNewEnv() - .withName(ENV_DRIVER_BIND_ADDRESS) - .withValueFrom(new 
EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .endEnv() - .withNewResources() - .addToRequests("cpu", driverCpuQuantity) - .addToRequests("memory", driverMemoryQuantity) - .addToLimits("memory", driverMemoryQuantity) - .addToLimits(maybeCpuLimitQuantity.toMap.asJava) - .endResources() - .addToArgs("driver") - .addToArgs("--properties-file", SPARK_CONF_PATH) - .addToArgs("--class", mainClass) - // The user application jar is merged into the spark.jars list and managed through that - // property, so there is no need to reference it explicitly here. - .addToArgs(SparkLauncher.NO_RESOURCE) - - val driverContainer = appArgs.toList match { - case "" :: Nil | Nil => driverContainerWithoutArgs.build() - case _ => driverContainerWithoutArgs.addToArgs(appArgs: _*).build() - } - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - val baseDriverPod = new PodBuilder(driverSpec.driverPod) - .editOrNewMetadata() - .withName(driverPodName) - .addToLabels(driverLabels.asJava) - .addToAnnotations(driverAnnotations.asJava) - .endMetadata() - .withNewSpec() - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .setIfMissing(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set("spark.app.id", kubernetesAppId) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, resourceNamePrefix) - // to set the config variables to allow client-mode spark-submit from driver - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - - driverSpec.copy( - driverPod = baseDriverPod, - driverSparkConf = resolvedSparkConf, - driverContainer = driverContainer) - } - -} - diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala deleted file mode 100644 index 43de329f239ad..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStep.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import io.fabric8.kubernetes.api.model.ContainerBuilder - -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.KubernetesUtils -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Step that configures the classpath, spark.jars, and spark.files for the driver given that the - * user may provide remote files or files with local:// schemes. 
- */ -private[spark] class DependencyResolutionStep( - sparkJars: Seq[String], - sparkFiles: Seq[String]) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val resolvedSparkJars = KubernetesUtils.resolveFileUrisAndPath(sparkJars) - val resolvedSparkFiles = KubernetesUtils.resolveFileUrisAndPath(sparkFiles) - - val sparkConf = driverSpec.driverSparkConf.clone() - if (resolvedSparkJars.nonEmpty) { - sparkConf.set("spark.jars", resolvedSparkJars.mkString(",")) - } - if (resolvedSparkFiles.nonEmpty) { - sparkConf.set("spark.files", resolvedSparkFiles.mkString(",")) - } - val resolvedDriverContainer = if (resolvedSparkJars.nonEmpty) { - new ContainerBuilder(driverSpec.driverContainer) - .addNewEnv() - .withName(ENV_MOUNTED_CLASSPATH) - .withValue(resolvedSparkJars.mkString(File.pathSeparator)) - .endEnv() - .build() - } else { - driverSpec.driverContainer - } - - driverSpec.copy( - driverContainer = resolvedDriverContainer, - driverSparkConf = sparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala deleted file mode 100644 index 2424e63999a82..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStep.scala +++ /dev/null @@ -1,245 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File -import java.nio.charset.StandardCharsets - -import scala.collection.JavaConverters._ -import scala.language.implicitConversions - -import com.google.common.io.{BaseEncoding, Files} -import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder, Secret, SecretBuilder} - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -/** - * Mounts Kubernetes credentials into the driver pod. The driver will use such mounted credentials - * to request executors. 
- */ -private[spark] class DriverKubernetesCredentialsStep( - submissionSparkConf: SparkConf, - kubernetesResourceNamePrefix: String) extends DriverConfigurationStep { - - private val maybeMountedOAuthTokenFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX") - private val maybeMountedClientKeyFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX") - private val maybeMountedClientCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX") - private val maybeMountedCaCertFile = submissionSparkConf.getOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX") - private val driverServiceAccount = submissionSparkConf.get(KUBERNETES_SERVICE_ACCOUNT_NAME) - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val driverSparkConf = driverSpec.driverSparkConf.clone() - - val oauthTokenBase64 = submissionSparkConf - .getOption(s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX") - .map { token => - BaseEncoding.base64().encode(token.getBytes(StandardCharsets.UTF_8)) - } - val caCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - "Driver CA cert file") - val clientKeyDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - "Driver client key file") - val clientCertDataBase64 = safeFileConfToBase64( - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - "Driver client cert file") - - val driverSparkConfWithCredentialsLocations = setDriverPodKubernetesCredentialLocations( - driverSparkConf, - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val kubernetesCredentialsSecret = createCredentialsSecret( - oauthTokenBase64, - caCertDataBase64, - clientKeyDataBase64, - clientCertDataBase64) - - val driverPodWithMountedKubernetesCredentials = kubernetesCredentialsSecret.map { secret => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .addNewVolume() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withNewSecret().withSecretName(secret.getMetadata.getName).endSecret() - .endVolume() - .endSpec() - .build() - }.getOrElse( - driverServiceAccount.map { account => - new PodBuilder(driverSpec.driverPod) - .editOrNewSpec() - .withServiceAccount(account) - .withServiceAccountName(account) - .endSpec() - .build() - }.getOrElse(driverSpec.driverPod) - ) - - val driverContainerWithMountedSecretVolume = kubernetesCredentialsSecret.map { _ => - new ContainerBuilder(driverSpec.driverContainer) - .addNewVolumeMount() - .withName(DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) - .withMountPath(DRIVER_CREDENTIALS_SECRETS_BASE_DIR) - .endVolumeMount() - .build() - }.getOrElse(driverSpec.driverContainer) - - driverSpec.copy( - driverPod = driverPodWithMountedKubernetesCredentials, - otherKubernetesResources = - driverSpec.otherKubernetesResources ++ kubernetesCredentialsSecret.toSeq, - driverSparkConf = driverSparkConfWithCredentialsLocations, - driverContainer = driverContainerWithMountedSecretVolume) - } - - private def createCredentialsSecret( - driverOAuthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): Option[Secret] = { - val allSecretData = - resolveSecretData( - 
driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME) ++ - resolveSecretData( - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME) ++ - resolveSecretData( - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_SECRET_NAME) ++ - resolveSecretData( - driverOAuthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_SECRET_NAME) - - if (allSecretData.isEmpty) { - None - } else { - Some(new SecretBuilder() - .withNewMetadata() - .withName(s"$kubernetesResourceNamePrefix-kubernetes-credentials") - .endMetadata() - .withData(allSecretData.asJava) - .build()) - } - } - - private def setDriverPodKubernetesCredentialLocations( - driverSparkConf: SparkConf, - driverOauthTokenBase64: Option[String], - driverCaCertDataBase64: Option[String], - driverClientKeyDataBase64: Option[String], - driverClientCertDataBase64: Option[String]): SparkConf = { - val resolvedMountedOAuthTokenFile = resolveSecretLocation( - maybeMountedOAuthTokenFile, - driverOauthTokenBase64, - DRIVER_CREDENTIALS_OAUTH_TOKEN_PATH) - val resolvedMountedClientKeyFile = resolveSecretLocation( - maybeMountedClientKeyFile, - driverClientKeyDataBase64, - DRIVER_CREDENTIALS_CLIENT_KEY_PATH) - val resolvedMountedClientCertFile = resolveSecretLocation( - maybeMountedClientCertFile, - driverClientCertDataBase64, - DRIVER_CREDENTIALS_CLIENT_CERT_PATH) - val resolvedMountedCaCertFile = resolveSecretLocation( - maybeMountedCaCertFile, - driverCaCertDataBase64, - DRIVER_CREDENTIALS_CA_CERT_PATH) - - val sparkConfWithCredentialLocations = driverSparkConf - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", - resolvedMountedCaCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX", - resolvedMountedClientKeyFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX", - resolvedMountedClientCertFile) - .setOption( - s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX", - resolvedMountedOAuthTokenFile) - - // Redact all OAuth token values - sparkConfWithCredentialLocations - .getAll - .filter(_._1.endsWith(OAUTH_TOKEN_CONF_SUFFIX)).map(_._1) - .foreach { - sparkConfWithCredentialLocations.set(_, "") - } - sparkConfWithCredentialLocations - } - - private def safeFileConfToBase64(conf: String, fileType: String): Option[String] = { - submissionSparkConf.getOption(conf) - .map(new File(_)) - .map { file => - require(file.isFile, String.format("%s provided at %s does not exist or is not a file.", - fileType, file.getAbsolutePath)) - BaseEncoding.base64().encode(Files.toByteArray(file)) - } - } - - private def resolveSecretLocation( - mountedUserSpecified: Option[String], - valueMountedFromSubmitter: Option[String], - mountedCanonicalLocation: String): Option[String] = { - mountedUserSpecified.orElse(valueMountedFromSubmitter.map { _ => - mountedCanonicalLocation - }) - } - - /** - * Resolve a Kubernetes secret data entry from an optional client credential used by the - * driver to talk to the Kubernetes API server. - * - * @param userSpecifiedCredential the optional user-specified client credential. - * @param secretName name of the Kubernetes secret storing the client credential. - * @return a secret data entry in the form of a map from the secret name to the secret data, - * which may be empty if the user-specified credential is empty. 
- */ - private def resolveSecretData( - userSpecifiedCredential: Option[String], - secretName: String): Map[String, String] = { - userSpecifiedCredential.map { valueBase64 => - Map(secretName -> valueBase64) - }.getOrElse(Map.empty[String, String]) - } - - private implicit def augmentSparkConf(sparkConf: SparkConf): OptionSettableSparkConf = { - new OptionSettableSparkConf(sparkConf) - } -} - -private class OptionSettableSparkConf(sparkConf: SparkConf) { - def setOption(configEntry: String, option: Option[String]): SparkConf = { - option.foreach { opt => - sparkConf.set(configEntry, opt) - } - sparkConf - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala deleted file mode 100644 index 34af7cde6c1a9..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStep.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.ServiceBuilder - -import org.apache.spark.SparkConf -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.internal.Logging -import org.apache.spark.util.Clock - -/** - * Allows the driver to be reachable by executor pods through a headless service. The service's - * ports should correspond to the ports that the executor will reach the pod at for RPC. 
- */ -private[spark] class DriverServiceBootstrapStep( - resourceNamePrefix: String, - driverLabels: Map[String, String], - sparkConf: SparkConf, - clock: Clock) extends DriverConfigurationStep with Logging { - - import DriverServiceBootstrapStep._ - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - require(sparkConf.getOption(DRIVER_BIND_ADDRESS_KEY).isEmpty, - s"$DRIVER_BIND_ADDRESS_KEY is not supported in Kubernetes mode, as the driver's bind " + - "address is managed and set to the driver pod's IP address.") - require(sparkConf.getOption(DRIVER_HOST_KEY).isEmpty, - s"$DRIVER_HOST_KEY is not supported in Kubernetes mode, as the driver's hostname will be " + - "managed via a Kubernetes service.") - - val preferredServiceName = s"$resourceNamePrefix$DRIVER_SVC_POSTFIX" - val resolvedServiceName = if (preferredServiceName.length <= MAX_SERVICE_NAME_LENGTH) { - preferredServiceName - } else { - val randomServiceId = clock.getTimeMillis() - val shorterServiceName = s"spark-$randomServiceId$DRIVER_SVC_POSTFIX" - logWarning(s"Driver's hostname would preferably be $preferredServiceName, but this is " + - s"too long (must be <= $MAX_SERVICE_NAME_LENGTH characters). Falling back to use " + - s"$shorterServiceName as the driver service's name.") - shorterServiceName - } - - val driverPort = sparkConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT) - val driverBlockManagerPort = sparkConf.getInt( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key, DEFAULT_BLOCKMANAGER_PORT) - val driverService = new ServiceBuilder() - .withNewMetadata() - .withName(resolvedServiceName) - .endMetadata() - .withNewSpec() - .withClusterIP("None") - .withSelector(driverLabels.asJava) - .addNewPort() - .withName(DRIVER_PORT_NAME) - .withPort(driverPort) - .withNewTargetPort(driverPort) - .endPort() - .addNewPort() - .withName(BLOCK_MANAGER_PORT_NAME) - .withPort(driverBlockManagerPort) - .withNewTargetPort(driverBlockManagerPort) - .endPort() - .endSpec() - .build() - - val namespace = sparkConf.get(KUBERNETES_NAMESPACE) - val driverHostname = s"${driverService.getMetadata.getName}.$namespace.svc" - val resolvedSparkConf = driverSpec.driverSparkConf.clone() - .set(DRIVER_HOST_KEY, driverHostname) - .set("spark.driver.port", driverPort.toString) - .set( - org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, driverBlockManagerPort) - - driverSpec.copy( - driverSparkConf = resolvedSparkConf, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(driverService)) - } -} - -private[spark] object DriverServiceBootstrapStep { - val DRIVER_BIND_ADDRESS_KEY = org.apache.spark.internal.config.DRIVER_BIND_ADDRESS.key - val DRIVER_HOST_KEY = org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key - val DRIVER_SVC_POSTFIX = "-driver-svc" - val MAX_SERVICE_NAME_LENGTH = 63 -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala deleted file mode 100644 index 8607d6fba3234..0000000000000 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ - -import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.internal.config.{EXECUTOR_CLASS_PATH, EXECUTOR_JAVA_OPTIONS, EXECUTOR_MEMORY, EXECUTOR_MEMORY_OVERHEAD} -import org.apache.spark.util.Utils - -/** - * A factory class for bootstrapping and creating executor pods with the given bootstrapping - * components. - * - * @param sparkConf Spark configuration - * @param mountSecretsBootstrap an optional component for mounting user-specified secrets onto - * user-specified paths into the executor container - */ -private[spark] class ExecutorPodFactory( - sparkConf: SparkConf, - mountSecretsBootstrap: Option[MountSecretsBootstrap]) { - - private val executorExtraClasspath = sparkConf.get(EXECUTOR_CLASS_PATH) - - private val executorLabels = KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_LABEL_PREFIX) - require( - !executorLabels.contains(SPARK_APP_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_APP_ID_LABEL as it is reserved for Spark.") - require( - !executorLabels.contains(SPARK_EXECUTOR_ID_LABEL), - s"Custom executor labels cannot contain $SPARK_EXECUTOR_ID_LABEL as it is reserved for" + - " Spark.") - require( - !executorLabels.contains(SPARK_ROLE_LABEL), - s"Custom executor labels cannot contain $SPARK_ROLE_LABEL as it is reserved for Spark.") - - private val executorAnnotations = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_EXECUTOR_ANNOTATION_PREFIX) - private val nodeSelector = - KubernetesUtils.parsePrefixedKeyValuePairs( - sparkConf, - KUBERNETES_NODE_SELECTOR_PREFIX) - - private val executorContainerImage = sparkConf - .get(EXECUTOR_CONTAINER_IMAGE) - .getOrElse(throw new SparkException("Must specify the executor container image")) - private val imagePullPolicy = sparkConf.get(CONTAINER_IMAGE_PULL_POLICY) - private val imagePullSecrets = sparkConf.get(IMAGE_PULL_SECRETS) - private val blockManagerPort = sparkConf - .getInt("spark.blockmanager.port", DEFAULT_BLOCKMANAGER_PORT) - - private val executorPodNamePrefix = sparkConf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX) - - private val executorMemoryMiB = sparkConf.get(EXECUTOR_MEMORY) - private val executorMemoryString = sparkConf.get( - EXECUTOR_MEMORY.key, EXECUTOR_MEMORY.defaultValueString) - - private val memoryOverheadMiB = sparkConf - .get(EXECUTOR_MEMORY_OVERHEAD) - .getOrElse(math.max((MEMORY_OVERHEAD_FACTOR * executorMemoryMiB).toInt, - MEMORY_OVERHEAD_MIN_MIB)) - private val executorMemoryWithOverhead = executorMemoryMiB + memoryOverheadMiB - - private val executorCores = sparkConf.getInt("spark.executor.cores", 1) - private val executorCoresRequest = if 
(sparkConf.contains(KUBERNETES_EXECUTOR_REQUEST_CORES)) { - sparkConf.get(KUBERNETES_EXECUTOR_REQUEST_CORES).get - } else { - executorCores.toString - } - private val executorLimitCores = sparkConf.get(KUBERNETES_EXECUTOR_LIMIT_CORES) - - /** - * Configure and construct an executor pod with the given parameters. - */ - def createExecutorPod( - executorId: String, - applicationId: String, - driverUrl: String, - executorEnvs: Seq[(String, String)], - driverPod: Pod, - nodeToLocalTaskCount: Map[String, Int]): Pod = { - val name = s"$executorPodNamePrefix-exec-$executorId" - - val parsedImagePullSecrets = KubernetesUtils.parseImagePullSecrets(imagePullSecrets) - - // hostname must be no longer than 63 characters, so take the last 63 characters of the pod - // name as the hostname. This preserves uniqueness since the end of name contains - // executorId - val hostname = name.substring(Math.max(0, name.length - 63)) - val resolvedExecutorLabels = Map( - SPARK_EXECUTOR_ID_LABEL -> executorId, - SPARK_APP_ID_LABEL -> applicationId, - SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ - executorLabels - val executorMemoryQuantity = new QuantityBuilder(false) - .withAmount(s"${executorMemoryWithOverhead}Mi") - .build() - val executorCpuQuantity = new QuantityBuilder(false) - .withAmount(executorCoresRequest) - .build() - val executorExtraClasspathEnv = executorExtraClasspath.map { cp => - new EnvVarBuilder() - .withName(ENV_CLASSPATH) - .withValue(cp) - .build() - } - val executorExtraJavaOptionsEnv = sparkConf - .get(EXECUTOR_JAVA_OPTIONS) - .map { opts => - val delimitedOpts = Utils.splitCommandString(opts) - delimitedOpts.zipWithIndex.map { - case (opt, index) => - new EnvVarBuilder().withName(s"$ENV_JAVA_OPT_PREFIX$index").withValue(opt).build() - } - }.getOrElse(Seq.empty[EnvVar]) - val executorEnv = (Seq( - (ENV_DRIVER_URL, driverUrl), - (ENV_EXECUTOR_CORES, executorCores.toString), - (ENV_EXECUTOR_MEMORY, executorMemoryString), - (ENV_APPLICATION_ID, applicationId), - // This is to set the SPARK_CONF_DIR to be /opt/spark/conf - (ENV_SPARK_CONF_DIR, SPARK_CONF_DIR_INTERNAL), - (ENV_EXECUTOR_ID, executorId)) ++ executorEnvs) - .map(env => new EnvVarBuilder() - .withName(env._1) - .withValue(env._2) - .build() - ) ++ Seq( - new EnvVarBuilder() - .withName(ENV_EXECUTOR_POD_IP) - .withValueFrom(new EnvVarSourceBuilder() - .withNewFieldRef("v1", "status.podIP") - .build()) - .build() - ) ++ executorExtraJavaOptionsEnv ++ executorExtraClasspathEnv.toSeq - val requiredPorts = Seq( - (BLOCK_MANAGER_PORT_NAME, blockManagerPort)) - .map { case (name, port) => - new ContainerPortBuilder() - .withName(name) - .withContainerPort(port) - .build() - } - - val executorContainer = new ContainerBuilder() - .withName("executor") - .withImage(executorContainerImage) - .withImagePullPolicy(imagePullPolicy) - .withNewResources() - .addToRequests("memory", executorMemoryQuantity) - .addToLimits("memory", executorMemoryQuantity) - .addToRequests("cpu", executorCpuQuantity) - .endResources() - .addAllToEnv(executorEnv.asJava) - .withPorts(requiredPorts.asJava) - .addToArgs("executor") - .build() - - val executorPod = new PodBuilder() - .withNewMetadata() - .withName(name) - .withLabels(resolvedExecutorLabels.asJava) - .withAnnotations(executorAnnotations.asJava) - .withOwnerReferences() - .addNewOwnerReference() - .withController(true) - .withApiVersion(driverPod.getApiVersion) - .withKind(driverPod.getKind) - .withName(driverPod.getMetadata.getName) - .withUid(driverPod.getMetadata.getUid) - .endOwnerReference() - 
.endMetadata() - .withNewSpec() - .withHostname(hostname) - .withRestartPolicy("Never") - .withNodeSelector(nodeSelector.asJava) - .withImagePullSecrets(parsedImagePullSecrets.asJava) - .endSpec() - .build() - - val containerWithLimitCores = executorLimitCores.map { limitCores => - val executorCpuLimitQuantity = new QuantityBuilder(false) - .withAmount(limitCores) - .build() - new ContainerBuilder(executorContainer) - .editResources() - .addToLimits("cpu", executorCpuLimitQuantity) - .endResources() - .build() - }.getOrElse(executorContainer) - - val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = - mountSecretsBootstrap.map { bootstrap => - (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) - }.getOrElse((executorPod, containerWithLimitCores)) - - - new PodBuilder(maybeSecretsMountedPod) - .editSpec() - .addToContainers(maybeSecretsMountedContainer) - .endSpec() - .build() - } -} diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala index ff5f6801da2a3..0ea80dfbc0d97 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterManager.scala @@ -21,7 +21,7 @@ import java.io.File import io.fabric8.kubernetes.client.Config import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.deploy.k8s.{KubernetesUtils, MountSecretsBootstrap, SparkKubernetesClientFactory} +import org.apache.spark.deploy.k8s.{KubernetesUtils, SparkKubernetesClientFactory} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.internal.Logging @@ -48,12 +48,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit scheduler: TaskScheduler): SchedulerBackend = { val executorSecretNamesToMountPaths = KubernetesUtils.parsePrefixedKeyValuePairs( sc.conf, KUBERNETES_EXECUTOR_SECRETS_PREFIX) - val mountSecretBootstrap = if (executorSecretNamesToMountPaths.nonEmpty) { - Some(new MountSecretsBootstrap(executorSecretNamesToMountPaths)) - } else { - None - } - val kubernetesClient = SparkKubernetesClientFactory.createKubernetesClient( KUBERNETES_MASTER_INTERNAL_URL, Some(sc.conf.get(KUBERNETES_NAMESPACE)), @@ -62,8 +56,6 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_TOKEN_PATH)), Some(new File(Config.KUBERNETES_SERVICE_ACCOUNT_CA_CRT_PATH))) - val executorPodFactory = new ExecutorPodFactory(sc.conf, mountSecretBootstrap) - val allocatorExecutor = ThreadUtils .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator") val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool( @@ -71,7 +63,7 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit new KubernetesClusterSchedulerBackend( scheduler.asInstanceOf[TaskSchedulerImpl], sc.env.rpcEnv, - executorPodFactory, + new KubernetesExecutorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala 
b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala index 9de4b16c30d3c..d86664c81071b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackend.scala @@ -32,6 +32,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.SparkException import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.KubernetesConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SchedulerBackendUtils} @@ -40,7 +41,7 @@ import org.apache.spark.util.Utils private[spark] class KubernetesClusterSchedulerBackend( scheduler: TaskSchedulerImpl, rpcEnv: RpcEnv, - executorPodFactory: ExecutorPodFactory, + executorBuilder: KubernetesExecutorBuilder, kubernetesClient: KubernetesClient, allocatorExecutor: ScheduledExecutorService, requestExecutorsService: ExecutorService) @@ -115,14 +116,19 @@ private[spark] class KubernetesClusterSchedulerBackend( for (_ <- 0 until math.min( currentTotalExpectedExecutors - runningExecutorsToPods.size, podAllocationSize)) { val executorId = EXECUTOR_ID_COUNTER.incrementAndGet().toString - val executorPod = executorPodFactory.createExecutorPod( + val executorConf = KubernetesConf.createExecutorConf( + conf, executorId, applicationId(), - driverUrl, - conf.getExecutorEnv, - driverPod, - currentNodeToLocalTaskCount) - executorsToAllocate(executorId) = executorPod + driverPod) + val executorPod = executorBuilder.buildFromFeatures(executorConf) + val podWithAttachedContainer = new PodBuilder(executorPod.pod) + .editOrNewSpec() + .addToContainers(executorPod.container) + .endSpec() + .build() + + executorsToAllocate(executorId) = podWithAttachedContainer logInfo( s"Requesting a new executor, total executors is now ${runningExecutorsToPods.size}") } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala new file mode 100644 index 0000000000000..22568fe7ea3be --- /dev/null +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilder.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, KubernetesRoleSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, MountSecretsFeatureStep} + +private[spark] class KubernetesExecutorBuilder( + provideBasicStep: (KubernetesConf[KubernetesExecutorSpecificConf]) => BasicExecutorFeatureStep = + new BasicExecutorFeatureStep(_), + provideSecretsStep: + (KubernetesConf[_ <: KubernetesRoleSpecificConf]) => MountSecretsFeatureStep = + new MountSecretsFeatureStep(_)) { + + def buildFromFeatures( + kubernetesConf: KubernetesConf[KubernetesExecutorSpecificConf]): SparkPod = { + val baseFeatures = Seq(provideBasicStep(kubernetesConf)) + val allFeatures = if (kubernetesConf.roleSecretNamesToMountPaths.nonEmpty) { + baseFeatures ++ Seq(provideSecretsStep(kubernetesConf)) + } else baseFeatures + var executorPod = SparkPod.initialPod() + for (feature <- allFeatures) { + executorPod = feature.configurePod(executorPod) + } + executorPod + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala new file mode 100644 index 0000000000000..f10202f7a3546 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/KubernetesConfSuite.scala @@ -0,0 +1,175 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.k8s + +import io.fabric8.kubernetes.api.model.{LocalObjectReferenceBuilder, PodBuilder} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.deploy.k8s.submit.JavaMainAppResource + +class KubernetesConfSuite extends SparkFunSuite { + + private val APP_NAME = "test-app" + private val RESOURCE_NAME_PREFIX = "prefix" + private val APP_ID = "test-id" + private val MAIN_CLASS = "test-class" + private val APP_ARGS = Array("arg1", "arg2") + private val CUSTOM_LABELS = Map( + "customLabel1Key" -> "customLabel1Value", + "customLabel2Key" -> "customLabel2Value") + private val CUSTOM_ANNOTATIONS = Map( + "customAnnotation1Key" -> "customAnnotation1Value", + "customAnnotation2Key" -> "customAnnotation2Value") + private val SECRET_NAMES_TO_MOUNT_PATHS = Map( + "secret1" -> "/mnt/secrets/secret1", + "secret2" -> "/mnt/secrets/secret2") + private val CUSTOM_ENVS = Map( + "customEnvKey1" -> "customEnvValue1", + "customEnvKey2" -> "customEnvValue2") + private val DRIVER_POD = new PodBuilder().build() + private val EXECUTOR_ID = "executor-id" + + test("Basic driver translated fields.") { + val sparkConf = new SparkConf(false) + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.appId === APP_ID) + assert(conf.sparkConf.getAll.toMap === sparkConf.getAll.toMap) + assert(conf.appResourceNamePrefix === RESOURCE_NAME_PREFIX) + assert(conf.roleSpecificConf.appName === APP_NAME) + assert(conf.roleSpecificConf.mainAppResource.isEmpty) + assert(conf.roleSpecificConf.mainClass === MAIN_CLASS) + assert(conf.roleSpecificConf.appArgs === APP_ARGS) + } + + test("Creating driver conf with and without the main app jar influences spark.jars") { + val sparkConf = new SparkConf(false) + .setJars(Seq("local:///opt/spark/jar1.jar")) + val mainAppJar = Some(JavaMainAppResource("local:///opt/spark/main.jar")) + val kubernetesConfWithMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + mainAppJar, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithMainJar.sparkConf.get("spark.jars") + .split(",") + === Array("local:///opt/spark/jar1.jar", "local:///opt/spark/main.jar")) + val kubernetesConfWithoutMainJar = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(kubernetesConfWithoutMainJar.sparkConf.get("spark.jars").split(",") + === Array("local:///opt/spark/jar1.jar")) + } + + test("Resolve driver labels, annotations, secret mount paths, and envs.") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$key", value) + } + CUSTOM_ENVS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_DRIVER_ENV_PREFIX$key", value) + } + + val conf = KubernetesConf.createDriverConf( + sparkConf, + APP_NAME, + RESOURCE_NAME_PREFIX, + APP_ID, + None, + MAIN_CLASS, + APP_ARGS) + assert(conf.roleLabels === Map( + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_DRIVER_ROLE) ++ + CUSTOM_LABELS) + 
assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + assert(conf.roleEnvs === CUSTOM_ENVS) + } + + test("Basic executor translated fields.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleSpecificConf.executorId === EXECUTOR_ID) + assert(conf.roleSpecificConf.driverPod === DRIVER_POD) + } + + test("Image pull secrets.") { + val conf = KubernetesConf.createExecutorConf( + new SparkConf(false) + .set(IMAGE_PULL_SECRETS, "my-secret-1,my-secret-2 "), + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.imagePullSecrets() === + Seq( + new LocalObjectReferenceBuilder().withName("my-secret-1").build(), + new LocalObjectReferenceBuilder().withName("my-secret-2").build())) + } + + test("Set executor labels, annotations, and secrets") { + val sparkConf = new SparkConf(false) + CUSTOM_LABELS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_LABEL_PREFIX$key", value) + } + CUSTOM_ANNOTATIONS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_ANNOTATION_PREFIX$key", value) + } + SECRET_NAMES_TO_MOUNT_PATHS.foreach { case (key, value) => + sparkConf.set(s"$KUBERNETES_EXECUTOR_SECRETS_PREFIX$key", value) + } + + val conf = KubernetesConf.createExecutorConf( + sparkConf, + EXECUTOR_ID, + APP_ID, + DRIVER_POD) + assert(conf.roleLabels === Map( + SPARK_EXECUTOR_ID_LABEL -> EXECUTOR_ID, + SPARK_APP_ID_LABEL -> APP_ID, + SPARK_ROLE_LABEL -> SPARK_POD_EXECUTOR_ROLE) ++ CUSTOM_LABELS) + assert(conf.roleAnnotations === CUSTOM_ANNOTATIONS) + assert(conf.roleSecretNamesToMountPaths === SECRET_NAMES_TO_MOUNT_PATHS) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala new file mode 100644 index 0000000000000..eee85b8baa730 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicDriverFeatureStepSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
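To make the prefix convention these tests exercise concrete: any Spark conf key formed from one of the role prefixes plus an arbitrary suffix ends up, keyed by that suffix, in the corresponding role-scoped map on KubernetesConf. A minimal sketch (the label and annotation values here are placeholders):

    package org.apache.spark.deploy.k8s

    import io.fabric8.kubernetes.api.model.PodBuilder

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.k8s.Config._

    object PrefixedConfSketch {
      val sparkConf = new SparkConf(false)
        .set(s"${KUBERNETES_EXECUTOR_LABEL_PREFIX}team", "data-eng")
        .set(s"${KUBERNETES_EXECUTOR_ANNOTATION_PREFIX}owner", "spark")
      val conf = KubernetesConf.createExecutorConf(
        sparkConf, "1", "app-id", new PodBuilder().build())
      // conf.roleLabels now holds "team" -> "data-eng" in addition to the Spark-managed
      // SPARK_APP_ID_LABEL, SPARK_ROLE_LABEL and SPARK_EXECUTOR_ID_LABEL entries asserted
      // in the suite above; conf.roleAnnotations holds "owner" -> "spark".
    }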
+ */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model.LocalObjectReferenceBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ + +class BasicDriverFeatureStepSuite extends SparkFunSuite { + + private val APP_ID = "spark-app-id" + private val RESOURCE_NAME_PREFIX = "spark" + private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") + private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" + private val APP_NAME = "spark-test" + private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") + private val CUSTOM_ANNOTATION_KEY = "customAnnotation" + private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" + private val DRIVER_ANNOTATIONS = Map(CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE) + private val DRIVER_CUSTOM_ENV1 = "customDriverEnv1" + private val DRIVER_CUSTOM_ENV2 = "customDriverEnv2" + private val DRIVER_ENVS = Map( + DRIVER_CUSTOM_ENV1 -> DRIVER_CUSTOM_ENV1, + DRIVER_CUSTOM_ENV2 -> DRIVER_CUSTOM_ENV2) + private val TEST_IMAGE_PULL_SECRETS = Seq("my-secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + + test("Check the pod respects all configurations from the user.") { + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .set("spark.driver.cores", "2") + .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") + .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") + .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) + .set(CONTAINER_IMAGE, "spark-driver:latest") + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + DRIVER_ENVS) + + val featureStep = new BasicDriverFeatureStep(kubernetesConf) + val basePod = SparkPod.initialPod() + val configuredPod = featureStep.configurePod(basePod) + + assert(configuredPod.container.getName === DRIVER_CONTAINER_NAME) + assert(configuredPod.container.getImage === "spark-driver:latest") + assert(configuredPod.container.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) + + assert(configuredPod.container.getEnv.size === 3) + val envs = configuredPod.container + .getEnv + .asScala + .map(env => (env.getName, env.getValue)) + .toMap + assert(envs(DRIVER_CUSTOM_ENV1) === DRIVER_ENVS(DRIVER_CUSTOM_ENV1)) + assert(envs(DRIVER_CUSTOM_ENV2) === DRIVER_ENVS(DRIVER_CUSTOM_ENV2)) + + assert(configuredPod.pod.getSpec().getImagePullSecrets.asScala === + TEST_IMAGE_PULL_SECRET_OBJECTS) + + assert(configuredPod.container.getEnv.asScala.exists(envVar => + envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && + envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && + envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) + + val resourceRequirements = configuredPod.container.getResources + val requests = resourceRequirements.getRequests.asScala + assert(requests("cpu").getAmount === "2") + assert(requests("memory").getAmount === "456Mi") + val limits = resourceRequirements.getLimits.asScala + 
assert(limits("memory").getAmount === "456Mi") + assert(limits("cpu").getAmount === "4") + + val driverPodMetadata = configuredPod.pod.getMetadata + assert(driverPodMetadata.getName === "spark-driver-pod") + assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) + assert(driverPodMetadata.getAnnotations.asScala === DRIVER_ANNOTATIONS) + assert(configuredPod.pod.getSpec.getRestartPolicy === "Never") + + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true") + assert(featureStep.getAdditionalPodSystemProperties() === expectedSparkConf) + } + + test("Additional system properties resolve jars and set cluster-mode confs.") { + val allJars = Seq("local:///opt/spark/jar1.jar", "hdfs:///opt/spark/jar2.jar") + val allFiles = Seq("https://localhost:9000/file1.txt", "local:///opt/spark/file2.txt") + val sparkConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") + .setJars(allJars) + .set("spark.files", allFiles.mkString(",")) + .set(CONTAINER_IMAGE, "spark-driver:latest") + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, + APP_NAME, + MAIN_CLASS, + APP_ARGS), + RESOURCE_NAME_PREFIX, + APP_ID, + DRIVER_LABELS, + DRIVER_ANNOTATIONS, + Map.empty, + Map.empty) + val step = new BasicDriverFeatureStep(kubernetesConf) + val additionalProperties = step.getAdditionalPodSystemProperties() + val expectedSparkConf = Map( + KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", + "spark.app.id" -> APP_ID, + KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, + "spark.kubernetes.submitInDriver" -> "true", + "spark.jars" -> "/opt/spark/jar1.jar,hdfs:///opt/spark/jar2.jar", + "spark.files" -> "https://localhost:9000/file1.txt,/opt/spark/file2.txt") + assert(additionalProperties === expectedSparkConf) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala new file mode 100644 index 0000000000000..a764f7630b5c8 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/BasicExecutorFeatureStepSuite.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.deploy.k8s.features + +import scala.collection.JavaConverters._ + +import io.fabric8.kubernetes.api.model._ +import org.mockito.MockitoAnnotations +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.rpc.RpcEndpointAddress +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend + +class BasicExecutorFeatureStepSuite + extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { + + private val APP_ID = "app-id" + private val DRIVER_HOSTNAME = "localhost" + private val DRIVER_PORT = 7098 + private val DRIVER_ADDRESS = RpcEndpointAddress( + DRIVER_HOSTNAME, + DRIVER_PORT.toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + private val DRIVER_POD_NAME = "driver-pod" + + private val DRIVER_POD_UID = "driver-uid" + private val RESOURCE_NAME_PREFIX = "base" + private val EXECUTOR_IMAGE = "executor-image" + private val LABELS = Map("label1key" -> "label1value") + private val ANNOTATIONS = Map("annotation1key" -> "annotation1value") + private val TEST_IMAGE_PULL_SECRETS = Seq("my-1secret-1", "my-secret-2") + private val TEST_IMAGE_PULL_SECRET_OBJECTS = + TEST_IMAGE_PULL_SECRETS.map { secret => + new LocalObjectReferenceBuilder().withName(secret).build() + } + private val DRIVER_POD = new PodBuilder() + .withNewMetadata() + .withName(DRIVER_POD_NAME) + .withUid(DRIVER_POD_UID) + .endMetadata() + .withNewSpec() + .withNodeName("some-node") + .endSpec() + .withNewStatus() + .withHostIP("192.168.99.100") + .endStatus() + .build() + private var baseConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + baseConf = new SparkConf() + .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME) + .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, RESOURCE_NAME_PREFIX) + .set(CONTAINER_IMAGE, EXECUTOR_IMAGE) + .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) + .set("spark.driver.host", DRIVER_HOSTNAME) + .set("spark.driver.port", DRIVER_PORT.toString) + .set(IMAGE_PULL_SECRETS, TEST_IMAGE_PULL_SECRETS.mkString(",")) + } + + test("basic executor pod has reasonable defaults") { + val step = new BasicExecutorFeatureStep( + KubernetesConf( + baseConf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + val executor = step.configurePod(SparkPod.initialPod()) + + // The executor pod name and default labels. + assert(executor.pod.getMetadata.getName === s"$RESOURCE_NAME_PREFIX-exec-1") + assert(executor.pod.getMetadata.getLabels.asScala === LABELS) + assert(executor.pod.getSpec.getImagePullSecrets.asScala === TEST_IMAGE_PULL_SECRET_OBJECTS) + + // There is exactly 1 container with no volume mounts and default memory limits. + // Default memory limit is 1024M + 384M (minimum overhead constant). + assert(executor.container.getImage === EXECUTOR_IMAGE) + assert(executor.container.getVolumeMounts.isEmpty) + assert(executor.container.getResources.getLimits.size() === 1) + assert(executor.container.getResources + .getLimits.get("memory").getAmount === "1408Mi") + + // The pod has no node selector, volumes. 
+ assert(executor.pod.getSpec.getNodeSelector.isEmpty) + assert(executor.pod.getSpec.getVolumes.isEmpty) + + checkEnv(executor, Map()) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + test("executor pod hostnames get truncated to 63 characters") { + val conf = baseConf.clone() + val longPodNamePrefix = "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple" + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + longPodNamePrefix, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map.empty)) + assert(step.configurePod(SparkPod.initialPod()).pod.getSpec.getHostname.length === 63) + } + + test("classpath and extra java options get translated into environment variables") { + val conf = baseConf.clone() + conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") + conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") + + val step = new BasicExecutorFeatureStep( + KubernetesConf( + conf, + KubernetesExecutorSpecificConf("1", DRIVER_POD), + RESOURCE_NAME_PREFIX, + APP_ID, + LABELS, + ANNOTATIONS, + Map.empty, + Map("qux" -> "quux"))) + val executor = step.configurePod(SparkPod.initialPod()) + + checkEnv(executor, + Map("SPARK_JAVA_OPT_0" -> "foo=bar", + ENV_CLASSPATH -> "bar=baz", + "qux" -> "quux")) + checkOwnerReferences(executor.pod, DRIVER_POD_UID) + } + + // There is always exactly one controller reference, and it points to the driver pod. + private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { + assert(executor.getMetadata.getOwnerReferences.size() === 1) + assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) + assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) + } + + // Check that the expected environment variables are present. + private def checkEnv(executorPod: SparkPod, additionalEnvVars: Map[String, String]): Unit = { + val defaultEnvs = Map( + ENV_EXECUTOR_ID -> "1", + ENV_DRIVER_URL -> DRIVER_ADDRESS.toString, + ENV_EXECUTOR_CORES -> "1", + ENV_EXECUTOR_MEMORY -> "1g", + ENV_APPLICATION_ID -> APP_ID, + ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, + ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars + + assert(executorPod.container.getEnv.size() === defaultEnvs.size) + val mapEnvs = executorPod.container.getEnv.asScala.map { + x => (x.getName, x.getValue) + }.toMap + assert(defaultEnvs === mapEnvs) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala similarity index 67% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala index 64553d25883bb..9f817d3bfc79a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverKubernetesCredentialsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverKubernetesCredentialsFeatureStepSuite.scala @@ -14,34 +14,35 @@ * See the License for the specific language governing permissions and * limitations under the License. 
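The rename below also changes what the credentials test asserts: instead of inspecting a mutated driverSparkConf, it reads the step's getAdditionalPodSystemProperties() and getAdditionalKubernetesResources(). The builder side is expected to fold those outputs together, roughly as sketched here; the real KubernetesDriverBuilder is outside this excerpt, so treat this as an illustration of the aggregation rather than its exact code:

    package org.apache.spark.deploy.k8s.features

    import io.fabric8.kubernetes.api.model.HasMetadata

    import org.apache.spark.deploy.k8s.SparkPod

    object FeatureAggregationSketch {
      // Apply every step to the pod in order and collect the extra system properties and
      // Kubernetes resources each step wants to contribute.
      def aggregate(steps: Seq[KubernetesFeatureConfigStep])
          : (SparkPod, Map[String, String], Seq[HasMetadata]) = {
        var pod = SparkPod.initialPod()
        var props = Map.empty[String, String]
        var resources = Seq.empty[HasMetadata]
        for (step <- steps) {
          pod = step.configurePod(pod)
          props = props ++ step.getAdditionalPodSystemProperties()
          resources = resources ++ step.getAdditionalKubernetesResources()
        }
        (pod, props, resources)
      }
    }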
*/ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features import java.io.File -import scala.collection.JavaConverters._ - import com.google.common.base.Charsets import com.google.common.io.{BaseEncoding, Files} import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder, Secret} +import org.mockito.{Mock, MockitoAnnotations} import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec import org.apache.spark.util.Utils -class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndAfter { +class DriverKubernetesCredentialsFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { private val KUBERNETES_RESOURCE_NAME_PREFIX = "spark" + private val APP_ID = "k8s-app" private var credentialsTempDirectory: File = _ - private val BASE_DRIVER_SPEC = new KubernetesDriverSpec( - driverPod = new PodBuilder().build(), - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) + private val BASE_DRIVER_POD = SparkPod.initialPod() + + @Mock + private var driverSpecificConf: KubernetesDriverSpecificConf = _ before { + MockitoAnnotations.initMocks(this) credentialsTempDirectory = Utils.createTempDir() } @@ -50,13 +51,19 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA } test("Don't set any credentials") { - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - new SparkConf(false), KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.isEmpty) + val kubernetesConf = KubernetesConf( + new SparkConf(false), + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalPodSystemProperties().isEmpty) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) } test("Only set credentials that are manually mounted.") { @@ -73,14 +80,23 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", "/mnt/secrets/my-ca.pem") + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver(BASE_DRIVER_SPEC) - assert(preparedDriverSpec.driverPod === BASE_DRIVER_SPEC.driverPod) - assert(preparedDriverSpec.driverContainer === BASE_DRIVER_SPEC.driverContainer) 
- assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === submissionSparkConf.getAll.toMap) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + assert(kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) === BASE_DRIVER_POD) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().isEmpty) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() + resolvedProperties.foreach { case (propKey, propValue) => + assert(submissionSparkConf.get(propKey) === propValue) + } } test("Mount credentials from the submission client as a secret.") { @@ -100,10 +116,17 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA .set( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX", caCertFile.getAbsolutePath) - val kubernetesCredentialsStep = new DriverKubernetesCredentialsStep( - submissionSparkConf, KUBERNETES_RESOURCE_NAME_PREFIX) - val preparedDriverSpec = kubernetesCredentialsStep.configureDriver( - BASE_DRIVER_SPEC.copy(driverSparkConf = submissionSparkConf)) + val kubernetesConf = KubernetesConf( + submissionSparkConf, + driverSpecificConf, + KUBERNETES_RESOURCE_NAME_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + val kubernetesCredentialsStep = new DriverKubernetesCredentialsFeatureStep(kubernetesConf) + val resolvedProperties = kubernetesCredentialsStep.getAdditionalPodSystemProperties() val expectedSparkConf = Map( s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$OAUTH_TOKEN_CONF_SUFFIX" -> "", s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$OAUTH_TOKEN_FILE_CONF_SUFFIX" -> @@ -113,16 +136,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> DRIVER_CREDENTIALS_CLIENT_CERT_PATH, s"$KUBERNETES_AUTH_DRIVER_MOUNTED_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - DRIVER_CREDENTIALS_CA_CERT_PATH, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_KEY_FILE_CONF_SUFFIX" -> - clientKeyFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CLIENT_CERT_FILE_CONF_SUFFIX" -> - clientCertFile.getAbsolutePath, - s"$KUBERNETES_AUTH_DRIVER_CONF_PREFIX.$CA_CERT_FILE_CONF_SUFFIX" -> - caCertFile.getAbsolutePath) - assert(preparedDriverSpec.driverSparkConf.getAll.toMap === expectedSparkConf) - assert(preparedDriverSpec.otherKubernetesResources.size === 1) - val credentialsSecret = preparedDriverSpec.otherKubernetesResources.head.asInstanceOf[Secret] + DRIVER_CREDENTIALS_CA_CERT_PATH) + assert(resolvedProperties === expectedSparkConf) + assert(kubernetesCredentialsStep.getAdditionalKubernetesResources().size === 1) + val credentialsSecret = kubernetesCredentialsStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Secret] assert(credentialsSecret.getMetadata.getName === s"$KUBERNETES_RESOURCE_NAME_PREFIX-kubernetes-credentials") val decodedSecretData = credentialsSecret.getData.asScala.map { data => @@ -134,12 +154,13 @@ class DriverKubernetesCredentialsStepSuite extends SparkFunSuite with BeforeAndA DRIVER_CREDENTIALS_CLIENT_KEY_SECRET_NAME -> "key", DRIVER_CREDENTIALS_CLIENT_CERT_SECRET_NAME -> "cert") assert(decodedSecretData === expectedSecretData) - val driverPodVolumes = preparedDriverSpec.driverPod.getSpec.getVolumes.asScala + val driverPod = kubernetesCredentialsStep.configurePod(BASE_DRIVER_POD) + val driverPodVolumes = driverPod.pod.getSpec.getVolumes.asScala 
assert(driverPodVolumes.size === 1) assert(driverPodVolumes.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverPodVolumes.head.getSecret != null) assert(driverPodVolumes.head.getSecret.getSecretName === credentialsSecret.getMetadata.getName) - val driverContainerVolumeMount = preparedDriverSpec.driverContainer.getVolumeMounts.asScala + val driverContainerVolumeMount = driverPod.container.getVolumeMounts.asScala assert(driverContainerVolumeMount.size === 1) assert(driverContainerVolumeMount.head.getName === DRIVER_CREDENTIALS_SECRET_VOLUME_NAME) assert(driverContainerVolumeMount.head.getMountPath === DRIVER_CREDENTIALS_SECRETS_BASE_DIR) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala new file mode 100644 index 0000000000000..c299d56865ec0 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/DriverServiceFeatureStepSuite.scala @@ -0,0 +1,227 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
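As the suite below spells out, the service step derives both the headless service name and the driver host it publishes back into the configuration. A small sketch of that naming rule (expectedDriverHost is a hypothetical helper; the postfix constant and the <service>.<namespace>.svc form come straight from the suite's assertions):

    package org.apache.spark.deploy.k8s.features

    object DriverServiceNamingSketch {
      // Short-prefix case: the service name is the resource prefix plus the driver-svc
      // postfix, and the advertised driver host is that service's namespaced DNS name.
      // When the prefix would overflow MAX_SERVICE_NAME_LENGTH, the step instead falls back
      // to a generated "spark-<timestamp>" name, as the long-prefix test shows.
      def expectedDriverHost(resourceNamePrefix: String, namespace: String): String = {
        val serviceName = resourceNamePrefix + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX
        s"$serviceName.$namespace.svc"
      }
    }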
+ */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.Service +import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Mockito.when +import org.scalatest.BeforeAndAfter +import scala.collection.JavaConverters._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.Config._ +import org.apache.spark.deploy.k8s.Constants._ +import org.apache.spark.util.Clock + +class DriverServiceFeatureStepSuite extends SparkFunSuite with BeforeAndAfter { + + private val SHORT_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length) + + private val LONG_RESOURCE_NAME_PREFIX = + "a" * (DriverServiceFeatureStep.MAX_SERVICE_NAME_LENGTH - + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX.length + 1) + private val DRIVER_LABELS = Map( + "label1key" -> "label1value", + "label2key" -> "label2value") + + @Mock + private var clock: Clock = _ + + private var sparkConf: SparkConf = _ + + before { + MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + } + + test("Headless service has a port for the driver RPC and the block manager.") { + sparkConf = sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + assert(configurationStep.configurePod(SparkPod.initialPod()) === SparkPod.initialPod()) + assert(configurationStep.getAdditionalKubernetesResources().size === 1) + assert(configurationStep.getAdditionalKubernetesResources().head.isInstanceOf[Service]) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + 9000, + 8080, + s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + driverService) + } + + test("Hostname and ports are set according to the service name.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf + .set("spark.driver.port", "9000") + .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) + .set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + + DriverServiceFeatureStep.DRIVER_SVC_POSTFIX + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Ports should resolve to defaults in SparkConf and in the service.") { + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + SHORT_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty)) + val resolvedService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + verifyService( + DEFAULT_DRIVER_PORT, + DEFAULT_BLOCKMANAGER_PORT, + 
s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}", + resolvedService) + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + assert(additionalProps("spark.driver.port") === DEFAULT_DRIVER_PORT.toString) + assert(additionalProps(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT.key) + === DEFAULT_BLOCKMANAGER_PORT.toString) + } + + test("Long prefixes should switch to using a generated name.") { + when(clock.getTimeMillis()).thenReturn(10000) + val configurationStep = new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + val driverService = configurationStep + .getAdditionalKubernetesResources() + .head + .asInstanceOf[Service] + val expectedServiceName = s"spark-10000${DriverServiceFeatureStep.DRIVER_SVC_POSTFIX}" + assert(driverService.getMetadata.getName === expectedServiceName) + val expectedHostName = s"$expectedServiceName.my-namespace.svc" + val additionalProps = configurationStep.getAdditionalPodSystemProperties() + verifySparkConfHostNames(additionalProps, expectedHostName) + } + + test("Disallow bind address and driver host to be set explicitly.") { + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver bind address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_BIND_ADDRESS_KEY} is" + + " not supported in Kubernetes mode, as the driver's bind address is managed" + + " and set to the driver pod's IP address.") + } + sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) + sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") + try { + new DriverServiceFeatureStep( + KubernetesConf( + sparkConf, + KubernetesDriverSpecificConf( + None, "main", "app", Seq.empty), + LONG_RESOURCE_NAME_PREFIX, + "app-id", + DRIVER_LABELS, + Map.empty, + Map.empty, + Map.empty), + clock) + fail("The driver host address should not be allowed.") + } catch { + case e: Throwable => + assert(e.getMessage === + s"requirement failed: ${DriverServiceFeatureStep.DRIVER_HOST_KEY} is" + + " not supported in Kubernetes mode, as the driver's hostname will be managed via" + + " a Kubernetes service.") + } + } + + private def verifyService( + driverPort: Int, + blockManagerPort: Int, + expectedServiceName: String, + service: Service): Unit = { + assert(service.getMetadata.getName === expectedServiceName) + assert(service.getSpec.getClusterIP === "None") + assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) + assert(service.getSpec.getPorts.size() === 2) + val driverServicePorts = service.getSpec.getPorts.asScala + assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) + assert(driverServicePorts.head.getPort.intValue() === driverPort) + assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) + assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) + assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) + assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) + 
} + + private def verifySparkConfHostNames( + driverSparkConf: Map[String, String], expectedHostName: String): Unit = { + assert(driverSparkConf( + org.apache.spark.internal.config.DRIVER_HOST_ADDRESS.key) === expectedHostName) + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala new file mode 100644 index 0000000000000..27bff74ce38af --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/KubernetesFeaturesTestUtils.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.{HasMetadata, PodBuilder, SecretBuilder} +import org.mockito.Matchers +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark.deploy.k8s.SparkPod + +object KubernetesFeaturesTestUtils { + + def getMockConfigStepForStepType[T <: KubernetesFeatureConfigStep]( + stepType: String, stepClass: Class[T]): T = { + val mockStep = mock(stepClass) + when(mockStep.getAdditionalKubernetesResources()).thenReturn( + getSecretsForStepType(stepType)) + + when(mockStep.getAdditionalPodSystemProperties()) + .thenReturn(Map(stepType -> stepType)) + when(mockStep.configurePod(Matchers.any(classOf[SparkPod]))) + .thenAnswer(new Answer[SparkPod]() { + override def answer(invocation: InvocationOnMock): SparkPod = { + val originalPod = invocation.getArgumentAt(0, classOf[SparkPod]) + val configuredPod = new PodBuilder(originalPod.pod) + .editOrNewMetadata() + .addToLabels(stepType, stepType) + .endMetadata() + .build() + SparkPod(configuredPod, originalPod.container) + } + }) + mockStep + } + + def getSecretsForStepType[T <: KubernetesFeatureConfigStep](stepType: String) + : Seq[HasMetadata] = { + Seq(new SecretBuilder() + .withNewMetadata() + .withName(stepType) + .endMetadata() + .build()) + } + +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala similarity index 64% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala index 960d0bda1d011..9d02f56cc206d 100644 --- 
a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/features/MountSecretsFeatureStepSuite.scala @@ -14,29 +14,38 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps +package org.apache.spark.deploy.k8s.features + +import io.fabric8.kubernetes.api.model.PodBuilder import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SecretVolumeUtils, SparkPod} -class DriverMountSecretsStepSuite extends SparkFunSuite { +class MountSecretsFeatureStepSuite extends SparkFunSuite { private val SECRET_FOO = "foo" private val SECRET_BAR = "bar" private val SECRET_MOUNT_PATH = "/etc/secrets/driver" test("mounts all given secrets") { - val baseDriverSpec = KubernetesDriverSpec.initialSpec(new SparkConf(false)) + val baseDriverPod = SparkPod.initialPod() val secretNamesToMountPaths = Map( SECRET_FOO -> SECRET_MOUNT_PATH, SECRET_BAR -> SECRET_MOUNT_PATH) + val sparkConf = new SparkConf(false) + val kubernetesConf = KubernetesConf( + sparkConf, + KubernetesExecutorSpecificConf("1", new PodBuilder().build()), + "resource-name-prefix", + "app-id", + Map.empty, + Map.empty, + secretNamesToMountPaths, + Map.empty) - val mountSecretsBootstrap = new MountSecretsBootstrap(secretNamesToMountPaths) - val mountSecretsStep = new DriverMountSecretsStep(mountSecretsBootstrap) - val configuredDriverSpec = mountSecretsStep.configureDriver(baseDriverSpec) - val driverPodWithSecretsMounted = configuredDriverSpec.driverPod - val driverContainerWithSecretsMounted = configuredDriverSpec.driverContainer + val step = new MountSecretsFeatureStep(kubernetesConf) + val driverPodWithSecretsMounted = step.configurePod(baseDriverPod).pod + val driverContainerWithSecretsMounted = step.configurePod(baseDriverPod).container Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach { volumeName => assert(SecretVolumeUtils.podHasVolume(driverPodWithSecretsMounted, volumeName)) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala index 6a501592f42a3..c1b203e03a357 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/ClientSuite.scala @@ -16,22 +16,17 @@ */ package org.apache.spark.deploy.k8s.submit -import scala.collection.JavaConverters._ - -import com.google.common.collect.Iterables import io.fabric8.kubernetes.api.model._ import io.fabric8.kubernetes.client.{KubernetesClient, Watch} import io.fabric8.kubernetes.client.dsl.{MixedOperation, NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable, PodResource} import org.mockito.{ArgumentCaptor, Mock, MockitoAnnotations} import org.mockito.Mockito.{doReturn, verify, when} -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter import org.scalatest.mockito.MockitoSugar._ import org.apache.spark.{SparkConf, SparkFunSuite} +import 
org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf, SparkPod} import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.steps.DriverConfigurationStep class ClientSuite extends SparkFunSuite with BeforeAndAfter { @@ -39,6 +34,74 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { private val DRIVER_POD_API_VERSION = "v1" private val DRIVER_POD_KIND = "pod" private val KUBERNETES_RESOURCE_PREFIX = "resource-example" + private val POD_NAME = "driver" + private val CONTAINER_NAME = "container" + private val APP_ID = "app-id" + private val APP_NAME = "app" + private val MAIN_CLASS = "main" + private val APP_ARGS = Seq("arg1", "arg2") + private val RESOLVED_JAVA_OPTIONS = Map( + "conf1key" -> "conf1value", + "conf2key" -> "conf2value") + private val BUILT_DRIVER_POD = + new PodBuilder() + .withNewMetadata() + .withName(POD_NAME) + .endMetadata() + .withNewSpec() + .withHostname("localhost") + .endSpec() + .build() + private val BUILT_DRIVER_CONTAINER = new ContainerBuilder().withName(CONTAINER_NAME).build() + private val ADDITIONAL_RESOURCES = Seq( + new SecretBuilder().withNewMetadata().withName("secret").endMetadata().build()) + + private val BUILT_KUBERNETES_SPEC = KubernetesDriverSpec( + SparkPod(BUILT_DRIVER_POD, BUILT_DRIVER_CONTAINER), + ADDITIONAL_RESOURCES, + RESOLVED_JAVA_OPTIONS) + + private val FULL_EXPECTED_CONTAINER = new ContainerBuilder(BUILT_DRIVER_CONTAINER) + .addNewEnv() + .withName(ENV_SPARK_CONF_DIR) + .withValue(SPARK_CONF_DIR_INTERNAL) + .endEnv() + .addNewVolumeMount() + .withName(SPARK_CONF_VOLUME) + .withMountPath(SPARK_CONF_DIR_INTERNAL) + .endVolumeMount() + .build() + private val FULL_EXPECTED_POD = new PodBuilder(BUILT_DRIVER_POD) + .editSpec() + .addToContainers(FULL_EXPECTED_CONTAINER) + .addNewVolume() + .withName(SPARK_CONF_VOLUME) + .withNewConfigMap().withName(s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map").endConfigMap() + .endVolume() + .endSpec() + .build() + + private val POD_WITH_OWNER_REFERENCE = new PodBuilder(FULL_EXPECTED_POD) + .editMetadata() + .withUid(DRIVER_POD_UID) + .endMetadata() + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .build() + + private val ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES = ADDITIONAL_RESOURCES.map { secret => + new SecretBuilder(secret) + .editMetadata() + .addNewOwnerReference() + .withName(POD_NAME) + .withApiVersion(DRIVER_POD_API_VERSION) + .withKind(DRIVER_POD_KIND) + .withController(true) + .withUid(DRIVER_POD_UID) + .endOwnerReference() + .endMetadata() + .build() + } private type ResourceList = NamespaceListVisitFromServerGetDeleteRecreateWaitApplicable[ HasMetadata, Boolean] @@ -56,113 +119,86 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { @Mock private var loggingPodStatusWatcher: LoggingPodStatusWatcher = _ + @Mock + private var driverBuilder: KubernetesDriverBuilder = _ + @Mock private var resourceList: ResourceList = _ - private val submissionSteps = Seq(FirstTestConfigurationStep, SecondTestConfigurationStep) + private var kubernetesConf: KubernetesConf[KubernetesDriverSpecificConf] = _ + + private var sparkConf: SparkConf = _ private var createdPodArgumentCaptor: ArgumentCaptor[Pod] = _ private var createdResourcesArgumentCaptor: ArgumentCaptor[HasMetadata] = _ - private var createdContainerArgumentCaptor: ArgumentCaptor[Container] = _ before { MockitoAnnotations.initMocks(this) + sparkConf = new SparkConf(false) + kubernetesConf = 
KubernetesConf[KubernetesDriverSpecificConf]( + sparkConf, + KubernetesDriverSpecificConf(None, MAIN_CLASS, APP_NAME, APP_ARGS), + KUBERNETES_RESOURCE_PREFIX, + APP_ID, + Map.empty, + Map.empty, + Map.empty, + Map.empty) + when(driverBuilder.buildFromFeatures(kubernetesConf)).thenReturn(BUILT_KUBERNETES_SPEC) when(kubernetesClient.pods()).thenReturn(podOperations) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.withName(POD_NAME)).thenReturn(namedPods) createdPodArgumentCaptor = ArgumentCaptor.forClass(classOf[Pod]) createdResourcesArgumentCaptor = ArgumentCaptor.forClass(classOf[HasMetadata]) - when(podOperations.create(createdPodArgumentCaptor.capture())).thenAnswer(new Answer[Pod] { - override def answer(invocation: InvocationOnMock): Pod = { - new PodBuilder(invocation.getArgumentAt(0, classOf[Pod])) - .editMetadata() - .withUid(DRIVER_POD_UID) - .endMetadata() - .withApiVersion(DRIVER_POD_API_VERSION) - .withKind(DRIVER_POD_KIND) - .build() - } - }) - when(podOperations.withName(FirstTestConfigurationStep.podName)).thenReturn(namedPods) + when(podOperations.create(FULL_EXPECTED_POD)).thenReturn(POD_WITH_OWNER_REFERENCE) when(namedPods.watch(loggingPodStatusWatcher)).thenReturn(mock[Watch]) doReturn(resourceList) .when(kubernetesClient) .resourceList(createdResourcesArgumentCaptor.capture()) } - test("The client should configure the pod with the submission steps.") { + test("The client should configure the pod using the builder.") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue - assert(createdPod.getMetadata.getName === FirstTestConfigurationStep.podName) - assert(createdPod.getMetadata.getLabels.asScala === - Map(FirstTestConfigurationStep.labelKey -> FirstTestConfigurationStep.labelValue)) - assert(createdPod.getMetadata.getAnnotations.asScala === - Map(SecondTestConfigurationStep.annotationKey -> - SecondTestConfigurationStep.annotationValue)) - assert(createdPod.getSpec.getContainers.size() === 1) - assert(createdPod.getSpec.getContainers.get(0).getName === - SecondTestConfigurationStep.containerName) + verify(podOperations).create(FULL_EXPECTED_POD) } test("The client should create Kubernetes resources") { - val EXAMPLE_JAVA_OPTS = "-XX:+HeapDumpOnOutOfMemoryError -XX:+PrintGCDetails" - val EXPECTED_JAVA_OPTS = "-XX\\:+HeapDumpOnOutOfMemoryError -XX\\:+PrintGCDetails" val submissionClient = new Client( - submissionSteps, - new SparkConf(false) - .set(org.apache.spark.internal.config.DRIVER_JAVA_OPTIONS, EXAMPLE_JAVA_OPTS), + driverBuilder, + kubernetesConf, kubernetesClient, false, "spark", loggingPodStatusWatcher, KUBERNETES_RESOURCE_PREFIX) submissionClient.run() - val createdPod = createdPodArgumentCaptor.getValue val otherCreatedResources = createdResourcesArgumentCaptor.getAllValues assert(otherCreatedResources.size === 2) - val secrets = otherCreatedResources.toArray - .filter(_.isInstanceOf[Secret]).map(_.asInstanceOf[Secret]) + val secrets = otherCreatedResources.toArray.filter(_.isInstanceOf[Secret]).toSeq + assert(secrets === ADDITIONAL_RESOURCES_WITH_OWNER_REFERENCES) val configMaps = otherCreatedResources.toArray .filter(_.isInstanceOf[ConfigMap]).map(_.asInstanceOf[ConfigMap]) assert(secrets.nonEmpty) - val secret = secrets.head - assert(secret.getMetadata.getName === 
FirstTestConfigurationStep.secretName) - assert(secret.getData.asScala === - Map(FirstTestConfigurationStep.secretKey -> FirstTestConfigurationStep.secretData)) - val ownerReference = Iterables.getOnlyElement(secret.getMetadata.getOwnerReferences) - assert(ownerReference.getName === createdPod.getMetadata.getName) - assert(ownerReference.getKind === DRIVER_POD_KIND) - assert(ownerReference.getUid === DRIVER_POD_UID) - assert(ownerReference.getApiVersion === DRIVER_POD_API_VERSION) assert(configMaps.nonEmpty) val configMap = configMaps.head assert(configMap.getMetadata.getName === s"$KUBERNETES_RESOURCE_PREFIX-driver-conf-map") assert(configMap.getData.containsKey(SPARK_CONF_FILE_NAME)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains(EXPECTED_JAVA_OPTS)) - assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains( - "spark.custom-conf=custom-conf-value")) - val driverContainer = Iterables.getOnlyElement(createdPod.getSpec.getContainers) - assert(driverContainer.getName === SecondTestConfigurationStep.containerName) - val driverEnv = driverContainer.getEnv.asScala.head - assert(driverEnv.getName === ENV_SPARK_CONF_DIR) - assert(driverEnv.getValue === SPARK_CONF_DIR_INTERNAL) - val driverMount = driverContainer.getVolumeMounts.asScala.head - assert(driverMount.getName === SPARK_CONF_VOLUME) - assert(driverMount.getMountPath === SPARK_CONF_DIR_INTERNAL) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf1key=conf1value")) + assert(configMap.getData.get(SPARK_CONF_FILE_NAME).contains("conf2key=conf2value")) } test("Waiting for app completion should stall on the watcher") { val submissionClient = new Client( - submissionSteps, - new SparkConf(false), + driverBuilder, + kubernetesConf, kubernetesClient, true, "spark", @@ -171,56 +207,4 @@ class ClientSuite extends SparkFunSuite with BeforeAndAfter { submissionClient.run() verify(loggingPodStatusWatcher).awaitCompletion() } - -} - -private object FirstTestConfigurationStep extends DriverConfigurationStep { - - val podName = "test-pod" - val secretName = "test-secret" - val labelKey = "first-submit" - val labelValue = "true" - val secretKey = "secretKey" - val secretData = "secretData" - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .withName(podName) - .addToLabels(labelKey, labelValue) - .endMetadata() - .build() - val additionalResource = new SecretBuilder() - .withNewMetadata() - .withName(secretName) - .endMetadata() - .addToData(secretKey, secretData) - .build() - driverSpec.copy( - driverPod = modifiedPod, - otherKubernetesResources = driverSpec.otherKubernetesResources ++ Seq(additionalResource)) - } -} - -private object SecondTestConfigurationStep extends DriverConfigurationStep { - val annotationKey = "second-submit" - val annotationValue = "submitted" - val sparkConfKey = "spark.custom-conf" - val sparkConfValue = "custom-conf-value" - val containerName = "driverContainer" - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val modifiedPod = new PodBuilder(driverSpec.driverPod) - .editMetadata() - .addToAnnotations(annotationKey, annotationValue) - .endMetadata() - .build() - val resolvedSparkConf = driverSpec.driverSparkConf.clone().set(sparkConfKey, sparkConfValue) - val modifiedContainer = new ContainerBuilder(driverSpec.driverContainer) - .withName(containerName) - .build() - driverSpec.copy( - driverPod = modifiedPod, - driverSparkConf = 
resolvedSparkConf, - driverContainer = modifiedContainer) - } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala deleted file mode 100644 index df34d2dbcb5be..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestratorSuite.scala +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit - -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.submit.steps._ - -class DriverConfigOrchestratorSuite extends SparkFunSuite { - - private val DRIVER_IMAGE = "driver-image" - private val IC_IMAGE = "init-container-image" - private val APP_ID = "spark-app-id" - private val KUBERNETES_RESOURCE_PREFIX = "example-prefix" - private val APP_NAME = "spark" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2") - private val SECRET_FOO = "foo" - private val SECRET_BAR = "bar" - private val SECRET_MOUNT_PATH = "/etc/secrets/driver" - - test("Base submission steps with a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep]) - } - - test("Base submission steps without a main app resource.") { - val sparkConf = new SparkConf(false).set(CONTAINER_IMAGE, DRIVER_IMAGE) - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Option.empty, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep]) - } - - test("Submission steps with driver secrets to mount") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_FOO", SECRET_MOUNT_PATH) - .set(s"$KUBERNETES_DRIVER_SECRETS_PREFIX$SECRET_BAR", SECRET_MOUNT_PATH) - val mainAppResource = JavaMainAppResource("local:///var/apps/jars/main.jar") - val orchestrator = new DriverConfigOrchestrator( - APP_ID, - 
KUBERNETES_RESOURCE_PREFIX, - Some(mainAppResource), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - validateStepTypes( - orchestrator, - classOf[BasicDriverConfigurationStep], - classOf[DriverServiceBootstrapStep], - classOf[DriverKubernetesCredentialsStep], - classOf[DependencyResolutionStep], - classOf[DriverMountSecretsStep]) - } - - test("Submission using client local dependencies") { - val sparkConf = new SparkConf(false) - .set(CONTAINER_IMAGE, DRIVER_IMAGE) - var orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("file:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - - sparkConf.set("spark.files", "/path/to/file1,/path/to/file2") - orchestrator = new DriverConfigOrchestrator( - APP_ID, - KUBERNETES_RESOURCE_PREFIX, - Some(JavaMainAppResource("local:///var/apps/jars/main.jar")), - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - assertThrows[SparkException] { - orchestrator.getAllConfigurationSteps - } - } - - private def validateStepTypes( - orchestrator: DriverConfigOrchestrator, - types: Class[_ <: DriverConfigurationStep]*): Unit = { - val steps = orchestrator.getAllConfigurationSteps - assert(steps.size === types.size) - assert(steps.map(_.getClass) === types) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala new file mode 100644 index 0000000000000..161f9afe7bba9 --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/KubernetesDriverBuilderSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
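Context for the removal above: the deleted DriverConfigOrchestratorSuite exercised the old submission-step API, in which the orchestrator returns an ordered list of configuration steps and each step rewrites the driver spec produced by the previous one. A minimal, self-contained sketch of that chaining pattern, using illustrative stand-in types rather than the real Spark classes:

```scala
// Illustrative stand-ins only; the real steps implement
// DriverConfigurationStep.configureDriver over a KubernetesDriverSpec.
case class DriverSpec(labels: Map[String, String] = Map.empty)

trait ConfigurationStep {
  def configureDriver(spec: DriverSpec): DriverSpec
}

class AddLabelStep(key: String, value: String) extends ConfigurationStep {
  override def configureDriver(spec: DriverSpec): DriverSpec =
    spec.copy(labels = spec.labels + (key -> value))
}

// The orchestrator decides which steps apply; running them is a simple fold.
def runSteps(steps: Seq[ConfigurationStep], initial: DriverSpec): DriverSpec =
  steps.foldLeft(initial)((spec, step) => step.configureDriver(spec))

val configured = runSteps(Seq(new AddLabelStep("labelkey", "labelvalue")), DriverSpec())
assert(configured.labels("labelkey") == "labelvalue")
```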
+ */ +package org.apache.spark.deploy.k8s.submit + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesDriverSpec, KubernetesDriverSpecificConf} +import org.apache.spark.deploy.k8s.features.{BasicDriverFeatureStep, DriverKubernetesCredentialsFeatureStep, DriverServiceFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesDriverBuilderSuite extends SparkFunSuite { + + private val BASIC_STEP_TYPE = "basic" + private val CREDENTIALS_STEP_TYPE = "credentials" + private val SERVICE_STEP_TYPE = "service" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicDriverFeatureStep]) + + private val credentialsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + CREDENTIALS_STEP_TYPE, classOf[DriverKubernetesCredentialsFeatureStep]) + + private val serviceStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SERVICE_STEP_TYPE, classOf[DriverServiceFeatureStep]) + + private val secretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest: KubernetesDriverBuilder = + new KubernetesDriverBuilder( + _ => basicFeatureStep, + _ => credentialsStep, + _ => serviceStep, + _ => secretsStep) + + test("Apply fundamental steps all the time.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesDriverSpecificConf( + None, + "test-app", + "main", + Seq.empty), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + CREDENTIALS_STEP_TYPE, + SERVICE_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedSpec: KubernetesDriverSpec, stepTypes: String*) + : Unit = { + assert(resolvedSpec.systemProperties.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedSpec.pod.pod.getMetadata.getLabels.get(stepType) === stepType) + assert(resolvedSpec.driverKubernetesResources.containsSlice( + KubernetesFeaturesTestUtils.getSecretsForStepType(stepType))) + assert(resolvedSpec.systemProperties(stepType) === stepType) + } + } +} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala deleted file mode 100644 index ee450fff8d376..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
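The new KubernetesDriverBuilderSuite above leans on KubernetesFeaturesTestUtils.getMockConfigStepForStepType, whose implementation is not part of this diff. The idea it needs is a step that tags whatever it configures with its own step type, so validateStepTypesApplied can later check exactly which steps ran. A self-contained sketch of that tagging idea with simplified stand-in types (the real helper builds Mockito mocks of Spark's feature-step classes):

```scala
// Simplified stand-ins, not Spark's real feature-step interfaces.
case class Pod(labels: Map[String, String])
case class DriverSpec(pod: Pod, systemProperties: Map[String, String])

trait FeatureStep { def configure(spec: DriverSpec): DriverSpec }

// A test step that leaves a marker label and system property named after its step type.
def stepForType(stepType: String): FeatureStep = new FeatureStep {
  override def configure(spec: DriverSpec): DriverSpec = spec.copy(
    pod = Pod(spec.pod.labels + (stepType -> stepType)),
    systemProperties = spec.systemProperties + (stepType -> stepType))
}

val resolved = Seq("basic", "credentials", "service")
  .map(stepForType)
  .foldLeft(DriverSpec(Pod(Map.empty), Map.empty))((spec, step) => step.configure(spec))

assert(resolved.pod.labels.keySet == Set("basic", "credentials", "service"))
assert(resolved.systemProperties.size == 3)
```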
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class BasicDriverConfigurationStepSuite extends SparkFunSuite { - - private val APP_ID = "spark-app-id" - private val RESOURCE_NAME_PREFIX = "spark" - private val DRIVER_LABELS = Map("labelkey" -> "labelvalue") - private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" - private val APP_NAME = "spark-test" - private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") - private val CUSTOM_ANNOTATION_KEY = "customAnnotation" - private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" - private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" - private val DRIVER_CUSTOM_ENV_KEY2 = "customDriverEnv2" - - test("Set all possible configurations from the user.") { - val sparkConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, "spark-driver-pod") - .set(org.apache.spark.internal.config.DRIVER_CLASS_PATH, "/opt/spark/spark-examples.jar") - .set("spark.driver.cores", "2") - .set(KUBERNETES_DRIVER_LIMIT_CORES, "4") - .set(org.apache.spark.internal.config.DRIVER_MEMORY.key, "256M") - .set(org.apache.spark.internal.config.DRIVER_MEMORY_OVERHEAD, 200L) - .set(CONTAINER_IMAGE, "spark-driver:latest") - .set(s"$KUBERNETES_DRIVER_ANNOTATION_PREFIX$CUSTOM_ANNOTATION_KEY", CUSTOM_ANNOTATION_VALUE) - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY1", "customDriverEnv1") - .set(s"$KUBERNETES_DRIVER_ENV_KEY$DRIVER_CUSTOM_ENV_KEY2", "customDriverEnv2") - .set(IMAGE_PULL_SECRETS, "imagePullSecret1, imagePullSecret2") - - val submissionStep = new BasicDriverConfigurationStep( - APP_ID, - RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - CONTAINER_IMAGE_PULL_POLICY, - APP_NAME, - MAIN_CLASS, - APP_ARGS, - sparkConf) - val basePod = new PodBuilder().withNewMetadata().endMetadata().withNewSpec().endSpec().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = basePod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = submissionStep.configureDriver(baseDriverSpec) - - assert(preparedDriverSpec.driverContainer.getName === DRIVER_CONTAINER_NAME) - assert(preparedDriverSpec.driverContainer.getImage === "spark-driver:latest") - assert(preparedDriverSpec.driverContainer.getImagePullPolicy === CONTAINER_IMAGE_PULL_POLICY) - - assert(preparedDriverSpec.driverContainer.getEnv.size === 4) - val envs = preparedDriverSpec.driverContainer - .getEnv - .asScala - .map(env => (env.getName, env.getValue)) - .toMap - assert(envs(ENV_CLASSPATH) === 
"/opt/spark/spark-examples.jar") - assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") - assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") - - assert(preparedDriverSpec.driverContainer.getEnv.asScala.exists(envVar => - envVar.getName.equals(ENV_DRIVER_BIND_ADDRESS) && - envVar.getValueFrom.getFieldRef.getApiVersion.equals("v1") && - envVar.getValueFrom.getFieldRef.getFieldPath.equals("status.podIP"))) - - val resourceRequirements = preparedDriverSpec.driverContainer.getResources - val requests = resourceRequirements.getRequests.asScala - assert(requests("cpu").getAmount === "2") - assert(requests("memory").getAmount === "456Mi") - val limits = resourceRequirements.getLimits.asScala - assert(limits("memory").getAmount === "456Mi") - assert(limits("cpu").getAmount === "4") - - val driverPodMetadata = preparedDriverSpec.driverPod.getMetadata - assert(driverPodMetadata.getName === "spark-driver-pod") - assert(driverPodMetadata.getLabels.asScala === DRIVER_LABELS) - val expectedAnnotations = Map( - CUSTOM_ANNOTATION_KEY -> CUSTOM_ANNOTATION_VALUE, - SPARK_APP_NAME_ANNOTATION -> APP_NAME) - assert(driverPodMetadata.getAnnotations.asScala === expectedAnnotations) - - val driverPodSpec = preparedDriverSpec.driverPod.getSpec - assert(driverPodSpec.getRestartPolicy === "Never") - assert(driverPodSpec.getImagePullSecrets.size() === 2) - assert(driverPodSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(driverPodSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - val resolvedSparkConf = preparedDriverSpec.driverSparkConf.getAll.toMap - val expectedSparkConf = Map( - KUBERNETES_DRIVER_POD_NAME.key -> "spark-driver-pod", - "spark.app.id" -> APP_ID, - KUBERNETES_EXECUTOR_POD_NAME_PREFIX.key -> RESOURCE_NAME_PREFIX, - "spark.kubernetes.submitInDriver" -> "true") - assert(resolvedSparkConf === expectedSparkConf) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala deleted file mode 100644 index ca43fc97dc991..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DependencyResolutionStepSuite.scala +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.deploy.k8s.submit.steps - -import java.io.File - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.{ContainerBuilder, HasMetadata, PodBuilder} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec - -class DependencyResolutionStepSuite extends SparkFunSuite { - - private val SPARK_JARS = Seq( - "apps/jars/jar1.jar", - "local:///var/apps/jars/jar2.jar") - - private val SPARK_FILES = Seq( - "apps/files/file1.txt", - "local:///var/apps/files/file2.txt") - - test("Added dependencies should be resolved in Spark configuration and environment") { - val dependencyResolutionStep = new DependencyResolutionStep( - SPARK_JARS, - SPARK_FILES) - val driverPod = new PodBuilder().build() - val baseDriverSpec = KubernetesDriverSpec( - driverPod = driverPod, - driverContainer = new ContainerBuilder().build(), - driverSparkConf = new SparkConf(false), - otherKubernetesResources = Seq.empty[HasMetadata]) - val preparedDriverSpec = dependencyResolutionStep.configureDriver(baseDriverSpec) - assert(preparedDriverSpec.driverPod === driverPod) - assert(preparedDriverSpec.otherKubernetesResources.isEmpty) - val resolvedSparkJars = preparedDriverSpec.driverSparkConf.get("spark.jars").split(",").toSet - val expectedResolvedSparkJars = Set( - "apps/jars/jar1.jar", - "/var/apps/jars/jar2.jar") - assert(resolvedSparkJars === expectedResolvedSparkJars) - val resolvedSparkFiles = preparedDriverSpec.driverSparkConf.get("spark.files").split(",").toSet - val expectedResolvedSparkFiles = Set( - "apps/files/file1.txt", - "/var/apps/files/file2.txt") - assert(resolvedSparkFiles === expectedResolvedSparkFiles) - val driverEnv = preparedDriverSpec.driverContainer.getEnv.asScala - assert(driverEnv.size === 1) - assert(driverEnv.head.getName === ENV_MOUNTED_CLASSPATH) - val resolvedDriverClasspath = driverEnv.head.getValue.split(File.pathSeparator).toSet - val expectedResolvedDriverClasspath = expectedResolvedSparkJars - assert(resolvedDriverClasspath === expectedResolvedDriverClasspath) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala deleted file mode 100644 index 78c8c3ba1afbd..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverServiceBootstrapStepSuite.scala +++ /dev/null @@ -1,180 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
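The DependencyResolutionStepSuite removed above pins down how submitted URIs are resolved: local:// URIs become plain container-local paths, while other paths are left untouched. A small standalone sketch of that resolution rule (the helper name is hypothetical; the real step is not shown in this hunk):

```scala
import java.net.URI

// "local:///var/apps/jars/jar2.jar" -> "/var/apps/jars/jar2.jar";
// a plain path such as "apps/files/file1.txt" passes through unchanged.
def resolveFileUri(uri: String): String = {
  val parsed = URI.create(uri)
  if (parsed.getScheme == "local") parsed.getPath else uri
}

assert(resolveFileUri("local:///var/apps/jars/jar2.jar") == "/var/apps/jars/jar2.jar")
assert(resolveFileUri("apps/files/file1.txt") == "apps/files/file1.txt")
```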
- */ -package org.apache.spark.deploy.k8s.submit.steps - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model.Service -import org.mockito.{Mock, MockitoAnnotations} -import org.mockito.Mockito.when -import org.scalatest.BeforeAndAfter - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec -import org.apache.spark.util.Clock - -class DriverServiceBootstrapStepSuite extends SparkFunSuite with BeforeAndAfter { - - private val SHORT_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length) - - private val LONG_RESOURCE_NAME_PREFIX = - "a" * (DriverServiceBootstrapStep.MAX_SERVICE_NAME_LENGTH - - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX.length + 1) - private val DRIVER_LABELS = Map( - "label1key" -> "label1value", - "label2key" -> "label2value") - - @Mock - private var clock: Clock = _ - - private var sparkConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - sparkConf = new SparkConf(false) - } - - test("Headless service has a port for the driver RPC and the block manager.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - assert(resolvedDriverSpec.otherKubernetesResources.size === 1) - assert(resolvedDriverSpec.otherKubernetesResources.head.isInstanceOf[Service]) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - verifyService( - 9000, - 8080, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - driverService) - } - - test("Hostname and ports are set according to the service name.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf - .set("spark.driver.port", "9000") - .set(org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT, 8080) - .set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val expectedServiceName = SHORT_RESOURCE_NAME_PREFIX + - DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Ports should resolve to defaults in SparkConf and in the service.") { - val configurationStep = new DriverServiceBootstrapStep( - SHORT_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf, - clock) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - verifyService( - DEFAULT_DRIVER_PORT, - DEFAULT_BLOCKMANAGER_PORT, - s"$SHORT_RESOURCE_NAME_PREFIX${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}", - resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service]) - assert(resolvedDriverSpec.driverSparkConf.get("spark.driver.port") === - DEFAULT_DRIVER_PORT.toString) - assert(resolvedDriverSpec.driverSparkConf.get( - 
org.apache.spark.internal.config.DRIVER_BLOCK_MANAGER_PORT) === DEFAULT_BLOCKMANAGER_PORT) - } - - test("Long prefixes should switch to using a generated name.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(KUBERNETES_NAMESPACE, "my-namespace"), - clock) - when(clock.getTimeMillis()).thenReturn(10000) - val baseDriverSpec = KubernetesDriverSpec.initialSpec(sparkConf.clone()) - val resolvedDriverSpec = configurationStep.configureDriver(baseDriverSpec) - val driverService = resolvedDriverSpec.otherKubernetesResources.head.asInstanceOf[Service] - val expectedServiceName = s"spark-10000${DriverServiceBootstrapStep.DRIVER_SVC_POSTFIX}" - assert(driverService.getMetadata.getName === expectedServiceName) - val expectedHostName = s"$expectedServiceName.my-namespace.svc" - verifySparkConfHostNames(resolvedDriverSpec.driverSparkConf, expectedHostName) - } - - test("Disallow bind address and driver host to be set explicitly.") { - val configurationStep = new DriverServiceBootstrapStep( - LONG_RESOURCE_NAME_PREFIX, - DRIVER_LABELS, - sparkConf.set(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS, "host"), - clock) - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver bind address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_BIND_ADDRESS_KEY} is" + - " not supported in Kubernetes mode, as the driver's bind address is managed" + - " and set to the driver pod's IP address.") - } - sparkConf.remove(org.apache.spark.internal.config.DRIVER_BIND_ADDRESS) - sparkConf.set(org.apache.spark.internal.config.DRIVER_HOST_ADDRESS, "host") - try { - configurationStep.configureDriver(KubernetesDriverSpec.initialSpec(sparkConf)) - fail("The driver host address should not be allowed.") - } catch { - case e: Throwable => - assert(e.getMessage === - s"requirement failed: ${DriverServiceBootstrapStep.DRIVER_HOST_KEY} is" + - " not supported in Kubernetes mode, as the driver's hostname will be managed via" + - " a Kubernetes service.") - } - } - - private def verifyService( - driverPort: Int, - blockManagerPort: Int, - expectedServiceName: String, - service: Service): Unit = { - assert(service.getMetadata.getName === expectedServiceName) - assert(service.getSpec.getClusterIP === "None") - assert(service.getSpec.getSelector.asScala === DRIVER_LABELS) - assert(service.getSpec.getPorts.size() === 2) - val driverServicePorts = service.getSpec.getPorts.asScala - assert(driverServicePorts.head.getName === DRIVER_PORT_NAME) - assert(driverServicePorts.head.getPort.intValue() === driverPort) - assert(driverServicePorts.head.getTargetPort.getIntVal === driverPort) - assert(driverServicePorts(1).getName === BLOCK_MANAGER_PORT_NAME) - assert(driverServicePorts(1).getPort.intValue() === blockManagerPort) - assert(driverServicePorts(1).getTargetPort.getIntVal === blockManagerPort) - } - - private def verifySparkConfHostNames( - driverSparkConf: SparkConf, expectedHostName: String): Unit = { - assert(driverSparkConf.get( - org.apache.spark.internal.config.DRIVER_HOST_ADDRESS) === expectedHostName) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala deleted file mode 100644 index 
d73df20f0f956..0000000000000 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.scheduler.cluster.k8s - -import scala.collection.JavaConverters._ - -import io.fabric8.kubernetes.api.model._ -import org.mockito.MockitoAnnotations -import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} - -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.Config._ -import org.apache.spark.deploy.k8s.Constants._ -import org.apache.spark.deploy.k8s.MountSecretsBootstrap - -class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterEach { - - private val driverPodName: String = "driver-pod" - private val driverPodUid: String = "driver-uid" - private val executorPrefix: String = "base" - private val executorImage: String = "executor-image" - private val imagePullSecrets: String = "imagePullSecret1, imagePullSecret2" - private val driverPod = new PodBuilder() - .withNewMetadata() - .withName(driverPodName) - .withUid(driverPodUid) - .endMetadata() - .withNewSpec() - .withNodeName("some-node") - .endSpec() - .withNewStatus() - .withHostIP("192.168.99.100") - .endStatus() - .build() - private var baseConf: SparkConf = _ - - before { - MockitoAnnotations.initMocks(this) - baseConf = new SparkConf() - .set(KUBERNETES_DRIVER_POD_NAME, driverPodName) - .set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, executorPrefix) - .set(CONTAINER_IMAGE, executorImage) - .set(KUBERNETES_DRIVER_SUBMIT_CHECK, true) - .set(IMAGE_PULL_SECRETS, imagePullSecrets) - } - - test("basic executor pod has reasonable defaults") { - val factory = new ExecutorPodFactory(baseConf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - // The executor pod name and default labels. - assert(executor.getMetadata.getName === s"$executorPrefix-exec-1") - assert(executor.getMetadata.getLabels.size() === 3) - assert(executor.getMetadata.getLabels.get(SPARK_EXECUTOR_ID_LABEL) === "1") - - // There is exactly 1 container with no volume mounts and default memory limits and requests. - // Default memory limit/request is 1024M + 384M (minimum overhead constant). 
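The deleted ExecutorPodFactorySuite's expected "1408Mi" values, asserted just below, are the comment above made concrete: 1024M of default executor memory plus the default overhead of max(10% of memory, 384M), which bottoms out at the 384M minimum here. A quick sanity check (the formula is assumed from the comment, not restated elsewhere in this diff):

```scala
// 1024 MiB executor memory + max(0.10 * 1024, 384) MiB overhead = 1408 MiB.
val executorMemoryMiB = 1024L
val overheadMiB = math.max((0.10 * executorMemoryMiB).toLong, 384L)
assert(s"${executorMemoryMiB + overheadMiB}Mi" == "1408Mi")
```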
- assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getImage === executorImage) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.isEmpty) - assert(executor.getSpec.getContainers.get(0).getResources.getLimits.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources - .getRequests.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getContainers.get(0).getResources - .getLimits.get("memory").getAmount === "1408Mi") - assert(executor.getSpec.getImagePullSecrets.size() === 2) - assert(executor.getSpec.getImagePullSecrets.get(0).getName === "imagePullSecret1") - assert(executor.getSpec.getImagePullSecrets.get(1).getName === "imagePullSecret2") - - // The pod has no node selector, volumes. - assert(executor.getSpec.getNodeSelector.isEmpty) - assert(executor.getSpec.getVolumes.isEmpty) - - checkEnv(executor, Map()) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor core request specification") { - var factory = new ExecutorPodFactory(baseConf, None) - var executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "1") - - val conf = baseConf.clone() - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "0.1") - factory = new ExecutorPodFactory(conf, None) - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "0.1") - - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - factory = new ExecutorPodFactory(conf, None) - conf.set(KUBERNETES_EXECUTOR_REQUEST_CORES, "100m") - executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - assert(executor.getSpec.getContainers.get(0).getResources.getRequests.get("cpu").getAmount - === "100m") - } - - test("executor pod hostnames get truncated to 63 characters") { - val conf = baseConf.clone() - conf.set(KUBERNETES_EXECUTOR_POD_NAME_PREFIX, - "loremipsumdolorsitametvimatelitrefficiendisuscipianturvixlegeresple") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getHostname.length === 63) - } - - test("classpath and extra java options get translated into environment variables") { - val conf = baseConf.clone() - conf.set(org.apache.spark.internal.config.EXECUTOR_JAVA_OPTIONS, "foo=bar") - conf.set(org.apache.spark.internal.config.EXECUTOR_CLASS_PATH, "bar=baz") - - val factory = new ExecutorPodFactory(conf, None) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, String)]("qux" -> "quux"), driverPod, Map[String, Int]()) - - checkEnv(executor, - Map("SPARK_JAVA_OPT_0" -> "foo=bar", - ENV_CLASSPATH -> "bar=baz", - "qux" -> "quux")) - checkOwnerReferences(executor, driverPodUid) - } - - test("executor secrets get mounted") { - val conf = baseConf.clone() - - val secretsBootstrap = new MountSecretsBootstrap(Map("secret1" -> "/var/secret1")) - val factory = new ExecutorPodFactory(conf, Some(secretsBootstrap)) - val executor = factory.createExecutorPod( - "1", "dummy", "dummy", Seq[(String, 
String)](), driverPod, Map[String, Int]()) - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.size() === 1) - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") - - // check volume mounted. - assert(executor.getSpec.getVolumes.size() === 1) - assert(executor.getSpec.getVolumes.get(0).getSecret.getSecretName === "secret1") - - checkOwnerReferences(executor, driverPodUid) - } - - // There is always exactly one controller reference, and it points to the driver pod. - private def checkOwnerReferences(executor: Pod, driverPodUid: String): Unit = { - assert(executor.getMetadata.getOwnerReferences.size() === 1) - assert(executor.getMetadata.getOwnerReferences.get(0).getUid === driverPodUid) - assert(executor.getMetadata.getOwnerReferences.get(0).getController === true) - } - - // Check that the expected environment variables are present. - private def checkEnv(executor: Pod, additionalEnvVars: Map[String, String]): Unit = { - val defaultEnvs = Map( - ENV_EXECUTOR_ID -> "1", - ENV_DRIVER_URL -> "dummy", - ENV_EXECUTOR_CORES -> "1", - ENV_EXECUTOR_MEMORY -> "1g", - ENV_APPLICATION_ID -> "dummy", - ENV_SPARK_CONF_DIR -> SPARK_CONF_DIR_INTERNAL, - ENV_EXECUTOR_POD_IP -> null) ++ additionalEnvVars - - assert(executor.getSpec.getContainers.size() === 1) - assert(executor.getSpec.getContainers.get(0).getEnv.size() === defaultEnvs.size) - val mapEnvs = executor.getSpec.getContainers.get(0).getEnv.asScala.map { - x => (x.getName, x.getValue) - }.toMap - assert(defaultEnvs === mapEnvs) - } -} diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala index b2f26f205a329..96065e83f069c 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesClusterSchedulerBackendSuite.scala @@ -18,11 +18,12 @@ package org.apache.spark.scheduler.cluster.k8s import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit} -import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList} +import io.fabric8.kubernetes.api.model.{ContainerBuilder, DoneablePod, Pod, PodBuilder, PodList} import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher} import io.fabric8.kubernetes.client.Watcher.Action import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource} -import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations} +import org.hamcrest.{BaseMatcher, Description, Matcher} +import org.mockito.{AdditionalAnswers, ArgumentCaptor, Matchers, Mock, MockitoAnnotations} import org.mockito.Matchers.{any, eq => mockitoEq} import org.mockito.Mockito.{doNothing, never, times, verify, when} import org.scalatest.BeforeAndAfter @@ -31,6 +32,7 @@ import scala.collection.JavaConverters._ import scala.concurrent.Future import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} import 
org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ import org.apache.spark.rpc._ @@ -47,8 +49,6 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private val SPARK_DRIVER_HOST = "localhost" private val SPARK_DRIVER_PORT = 7077 private val POD_ALLOCATION_INTERVAL = "1m" - private val DRIVER_URL = RpcEndpointAddress( - SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString private val FIRST_EXECUTOR_POD = new PodBuilder() .withNewMetadata() .withName("pod1") @@ -94,7 +94,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn private var requestExecutorsService: ExecutorService = _ @Mock - private var executorPodFactory: ExecutorPodFactory = _ + private var executorBuilder: KubernetesExecutorBuilder = _ @Mock private var kubernetesClient: KubernetesClient = _ @@ -399,7 +399,7 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn new KubernetesClusterSchedulerBackend( taskSchedulerImpl, rpcEnv, - executorPodFactory, + executorBuilder, kubernetesClient, allocatorExecutor, requestExecutorsService) { @@ -428,13 +428,22 @@ class KubernetesClusterSchedulerBackendSuite extends SparkFunSuite with BeforeAn .addToLabels(SPARK_EXECUTOR_ID_LABEL, executorId.toString) .endMetadata() .build() - when(executorPodFactory.createExecutorPod( - executorId.toString, - APP_ID, - DRIVER_URL, - sparkConf.getExecutorEnv, - driverPod, - Map.empty)).thenReturn(resolvedPod) - resolvedPod + val resolvedContainer = new ContainerBuilder().build() + when(executorBuilder.buildFromFeatures(Matchers.argThat( + new BaseMatcher[KubernetesConf[KubernetesExecutorSpecificConf]] { + override def matches(argument: scala.Any) + : Boolean = { + argument.isInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] && + argument.asInstanceOf[KubernetesConf[KubernetesExecutorSpecificConf]] + .roleSpecificConf.executorId == executorId.toString + } + + override def describeTo(description: Description): Unit = {} + }))).thenReturn(SparkPod(resolvedPod, resolvedContainer)) + new PodBuilder(resolvedPod) + .editSpec() + .addToContainers(resolvedContainer) + .endSpec() + .build() } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala new file mode 100644 index 0000000000000..f5270623f8acc --- /dev/null +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/KubernetesExecutorBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
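The scheduler backend test above switches from stubbing a concrete factory call to matching the builder's KubernetesConf argument with a hamcrest BaseMatcher passed through Mockito's argThat. The same pattern in isolation, on a throwaway trait so it stays self-contained (stand-in types; only the Matchers/BaseMatcher usage mirrors the diff, and it assumes the Mockito 1.x and Hamcrest dependencies this module already uses):

```scala
import org.hamcrest.{BaseMatcher, Description}
import org.mockito.Matchers
import org.mockito.Mockito.{mock, when}

trait ExecutorBuilder { def build(executorId: String): String }

val builder = mock(classOf[ExecutorBuilder])
// Stub only calls whose argument satisfies the predicate, as done for KubernetesConf above.
when(builder.build(Matchers.argThat(new BaseMatcher[String] {
  override def matches(argument: scala.Any): Boolean =
    argument.isInstanceOf[String] && argument.asInstanceOf[String] == "1"
  override def describeTo(description: Description): Unit = {}
}))).thenReturn("executor-pod-1")

assert(builder.build("1") == "executor-pod-1")
```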
+ */ +package org.apache.spark.scheduler.cluster.k8s + +import io.fabric8.kubernetes.api.model.PodBuilder + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.deploy.k8s.{KubernetesConf, KubernetesExecutorSpecificConf, SparkPod} +import org.apache.spark.deploy.k8s.features.{BasicExecutorFeatureStep, KubernetesFeaturesTestUtils, MountSecretsFeatureStep} + +class KubernetesExecutorBuilderSuite extends SparkFunSuite { + private val BASIC_STEP_TYPE = "basic" + private val SECRETS_STEP_TYPE = "mount-secrets" + + private val basicFeatureStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + BASIC_STEP_TYPE, classOf[BasicExecutorFeatureStep]) + private val mountSecretsStep = KubernetesFeaturesTestUtils.getMockConfigStepForStepType( + SECRETS_STEP_TYPE, classOf[MountSecretsFeatureStep]) + + private val builderUnderTest = new KubernetesExecutorBuilder( + _ => basicFeatureStep, + _ => mountSecretsStep) + + test("Basic steps are consistently applied.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map.empty, + Map.empty) + validateStepTypesApplied(builderUnderTest.buildFromFeatures(conf), BASIC_STEP_TYPE) + } + + test("Apply secrets step if secrets are present.") { + val conf = KubernetesConf( + new SparkConf(false), + KubernetesExecutorSpecificConf( + "executor-id", new PodBuilder().build()), + "prefix", + "appId", + Map.empty, + Map.empty, + Map("secret" -> "secretMountPath"), + Map.empty) + validateStepTypesApplied( + builderUnderTest.buildFromFeatures(conf), + BASIC_STEP_TYPE, + SECRETS_STEP_TYPE) + } + + private def validateStepTypesApplied(resolvedPod: SparkPod, stepTypes: String*): Unit = { + assert(resolvedPod.pod.getMetadata.getLabels.size === stepTypes.size) + stepTypes.foreach { stepType => + assert(resolvedPod.pod.getMetadata.getLabels.get(stepType) === stepType) + } + } +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala index c1ae12aabb8cc..17234b120ae13 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala @@ -29,7 +29,6 @@ import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.config._ import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.Utils /** * Handles registering and unregistering the application with the YARN ResourceManager. 
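Mirroring the driver-side suite earlier in this diff, the KubernetesExecutorBuilderSuite above expects only the basic step by default and the mount-secrets step once secret mounts are configured, with the label count doubling as an "and nothing else ran" check. A minimal illustration of that selection rule (simplified; not the real builder):

```scala
// Illustrative step selection matching what the suite asserts.
def executorStepTypes(secretNamesToMountPaths: Map[String, String]): Seq[String] =
  if (secretNamesToMountPaths.nonEmpty) Seq("basic", "mount-secrets") else Seq("basic")

assert(executorStepTypes(Map.empty) == Seq("basic"))
assert(executorStepTypes(Map("secret" -> "secretMountPath")) == Seq("basic", "mount-secrets"))
```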
@@ -71,7 +70,8 @@ private[spark] class YarnRMClient extends Logging { logInfo("Registering the ApplicationMaster") synchronized { - amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl) + amClient.registerApplicationMaster(driverRef.address.host, driverRef.address.port, + trackingUrl) registered = true } new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 81bb2513ac82e..dacb990a438a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -394,6 +394,7 @@ object FunctionRegistry { expression[TruncTimestamp]("date_trunc"), expression[UnixTimestamp]("unix_timestamp"), expression[DayOfWeek]("dayofweek"), + expression[WeekDay]("weekday"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), expression[TimeWindow]("window"), @@ -407,6 +408,7 @@ object FunctionRegistry { expression[MapValues]("map_values"), expression[Size]("size"), expression[SortArray]("sort_array"), + expression[ArrayMax]("array_max"), expression[Reverse]("reverse"), CreateStruct.registryEntry, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index b55043c270644..ff9d6d7a7dded 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -345,7 +345,7 @@ object UnsupportedOperationChecker { plan.foreachUp { implicit subPlan => subPlan match { case (_: Project | _: Filter | _: MapElements | _: MapPartitions | - _: DeserializeToObject | _: SerializeFromObject) => + _: DeserializeToObject | _: SerializeFromObject | _: SubqueryAlias) => case node if node.nodeName == "StreamingRelationV2" => case node => throwError(s"Continuous processing does not support ${node.nodeName} operations.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index a65f58fa61ff4..71e23175168e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.parser.ParserUtils import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode} import org.apache.spark.sql.catalyst.trees.TreeNode @@ -335,7 +335,7 @@ case class UnresolvedRegex(regexPattern: String, table: Option[String], caseSens * @param names the names to be associated with each output of computing 
[[child]]. */ case class MultiAlias(child: Expression, names: Seq[String]) - extends UnaryExpression with NamedExpression with CodegenFallback { + extends UnaryExpression with NamedExpression with Unevaluable { override def name: String = throw new UnresolvedException(this, "name") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 9212c3de1f814..942dfd4292610 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -674,11 +674,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { val evals = evalChildren.map(eval => s""" |${eval.code} - |if (!${eval.isNull} && (${ev.isNull} || - | ${ctx.genGreater(dataType, eval.value, ev.value)})) { - | ${ev.isNull} = false; - | ${ev.value} = ${eval.value}; - |} + |${ctx.reassignIfGreater(dataType, ev, eval)} """.stripMargin ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 0abfc9fa4c465..c86c5beded9d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -699,6 +699,23 @@ class CodegenContext { case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code for updating `partialResult` if `item` is greater than it. + * + * @param dataType data type of the expressions + * @param partialResult `ExprCode` representing the partial result which has to be updated + * @param item `ExprCode` representing the new expression to evaluate for the result + */ + def reassignIfGreater(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = { + s""" + |if (!${item.isNull} && (${partialResult.isNull} || + | ${genGreater(dataType, item.value, partialResult.value)})) { + | ${partialResult.isNull} = false; + | ${partialResult.value} = ${item.value}; + |} + """.stripMargin + } + /** * Generates code to do null safe execution, i.e. only execute the code when the input is not * null by adding null check if necessary. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index a6ba7393ecf60..a62ab96a3cc66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -21,7 +21,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -375,3 +375,69 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Returns the maximum value in the array. 
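The Greatest change above and the ArrayMax codegen introduced just below both delegate to the new CodegenContext.reassignIfGreater helper. Its runtime effect is equivalent to this small, self-contained fold (illustrative semantics only, not the generated Java):

```scala
// Null-aware "keep the larger value" update: nulls are skipped and the first
// non-null item seeds the running maximum.
def reassignIfGreater[T](partial: Option[T], item: Option[T])(implicit ord: Ordering[T]): Option[T] =
  item match {
    case Some(v) if partial.forall(ord.lt(_, v)) => Some(v)
    case _ => partial
  }

val max = Seq(Option(1), Option.empty[Int], Option(20), Option(3))
  .foldLeft(Option.empty[Int])((acc, x) => reassignIfGreater(acc, x))
assert(max.contains(20))
```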
+ */ +@ExpressionDescription( + usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.", + examples = """ + Examples: + > SELECT _FUNC_(array(1, 20, null, 3)); + 20 + """, since = "2.4.0") +case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType) + + private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType) + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName") + } else { + typeCheckResult + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childGen = child.genCode(ctx) + val javaType = CodeGenerator.javaType(dataType) + val i = ctx.freshName("i") + val item = ExprCode("", + isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"), + value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType)) + ev.copy(code = + s""" + |${childGen.code} + |boolean ${ev.isNull} = true; + |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + |if (!${childGen.isNull}) { + | for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) { + | ${ctx.reassignIfGreater(dataType, ev, item)} + | } + |} + """.stripMargin) + } + + override protected def nullSafeEval(input: Any): Any = { + var max: Any = null + input.asInstanceOf[ArrayData].foreach(dataType, (_, item) => + if (item != null && (max == null || ordering.gt(item, max))) { + max = item + } + ) + max + } + + override def dataType: DataType = child.dataType match { + case ArrayType(dt, _) => dt + case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.") + } + + override def prettyName: String = "array_max" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 32fdb13afbbfa..b9b2cd5bdb9f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -426,36 +426,71 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa """, since = "2.3.0") // scalastyle:on line.size.limit -case class DayOfWeek(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { +case class DayOfWeek(child: Expression) extends DayWeek { - override def inputTypes: Seq[AbstractDataType] = Seq(DateType) - - override def dataType: DataType = IntegerType + override protected def nullSafeEval(date: Any): Any = { + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + cal.get(Calendar.DAY_OF_WEEK) + } - @transient private lazy val c = { - Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + nullSafeCodeGen(ctx, ev, time => { + val cal = classOf[Calendar].getName + val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") + val c = "calDayOfWeek" + ctx.addImmutableStateIfNotExists(cal, c, + v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") + s""" + $c.setTimeInMillis($time * 1000L * 3600L * 24L); + ${ev.value} = $c.get($cal.DAY_OF_WEEK); + """ + }) } +} + +// 
scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the day of the week for date/timestamp (0 = Monday, 1 = Tuesday, ..., 6 = Sunday).", + examples = """ + Examples: + > SELECT _FUNC_('2009-07-30'); + 3 + """, + since = "2.4.0") +// scalastyle:on line.size.limit +case class WeekDay(child: Expression) extends DayWeek { override protected def nullSafeEval(date: Any): Any = { - c.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) - c.get(Calendar.DAY_OF_WEEK) + cal.setTimeInMillis(date.asInstanceOf[Int] * 1000L * 3600L * 24L) + (cal.get(Calendar.DAY_OF_WEEK) + 5 ) % 7 } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val c = "calDayOfWeek" + val c = "calWeekDay" ctx.addImmutableStateIfNotExists(cal, c, v => s"""$v = $cal.getInstance($dtu.getTimeZone("UTC"));""") s""" $c.setTimeInMillis($time * 1000L * 3600L * 24L); - ${ev.value} = $c.get($cal.DAY_OF_WEEK); + ${ev.value} = ($c.get($cal.DAY_OF_WEEK) + 5) % 7; """ }) } } +abstract class DayWeek extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(DateType) + + override def dataType: DataType = IntegerType + + @transient protected lazy val cal: Calendar = { + Calendar.getInstance(DateTimeUtils.getTimeZone("UTC")) + } +} + // scalastyle:off line.size.limit @ExpressionDescription( usage = "_FUNC_(date) - Returns the week of the year of the given date. A week is considered to start on a Monday and week 1 is the first week with >3 days.", diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9a1bbc675e397..5fb59ef350b8b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -138,6 +138,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) operatorOptimizationBatch) :+ Batch("Join Reorder", Once, CostBasedJoinReorder) :+ + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :+ Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :+ Batch("Object Expressions Optimization", fixedPoint, @@ -733,6 +735,16 @@ object EliminateSorts extends Rule[LogicalPlan] { } } +/** + * Removes Sort operation if the child is already sorted + */ +object RemoveRedundantSorts extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case Sort(orders, true, child) if SortOrder.orderingSatisfies(child.outputOrdering, orders) => + child + } +} + /** * Removes filters that can be evaluated trivially. This can be done through the following ways: * 1) by eliding the filter for cases where it will always evaluate to `true`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index c8ccd9bd03994..42034403d6d03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -219,6 +219,11 @@ abstract class LogicalPlan * Refreshes (or invalidates) any metadata/data cached in the plan recursively. 
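The refactored DayOfWeek/WeekDay pair above differ only in how they remap java.util.Calendar's DAY_OF_WEEK, which runs from 1 (Sunday) to 7 (Saturday): WeekDay shifts it to the documented 0 (Monday) through 6 (Sunday) encoding with (value + 5) % 7. A quick standalone check of that remapping:

```scala
import java.util.Calendar

// Calendar.SUNDAY .. Calendar.SATURDAY are 1..7; (d + 5) % 7 yields Monday = 0 .. Sunday = 6.
val names = Seq("Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday")
val remapped = (Calendar.SUNDAY to Calendar.SATURDAY).map(d => names(d - 1) -> (d + 5) % 7).toMap
assert(remapped == Map(
  "Monday" -> 0, "Tuesday" -> 1, "Wednesday" -> 2, "Thursday" -> 3,
  "Friday" -> 4, "Saturday" -> 5, "Sunday" -> 6))
```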
*/ def refresh(): Unit = children.foreach(_.refresh()) + + /** + * Returns the output ordering that this plan generates. + */ + def outputOrdering: Seq[SortOrder] = Nil } /** @@ -274,3 +279,7 @@ abstract class BinaryNode extends LogicalPlan { override final def children: Seq[LogicalPlan] = Seq(left, right) } + +abstract class OrderPreservingUnaryNode extends UnaryNode { + override final def outputOrdering: Seq[SortOrder] = child.outputOrdering +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index a4fca790dd086..10df504795430 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -43,11 +43,12 @@ case class ReturnAnswer(child: LogicalPlan) extends UnaryNode { * This node is inserted at the top of a subquery when it is optimized. This makes sure we can * recognize a subquery as such, and it allows us to write subquery aware transformations. */ -case class Subquery(child: LogicalPlan) extends UnaryNode { +case class Subquery(child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output } -case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { +case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) + extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = projectList.map(_.toAttribute) override def maxRows: Option[Long] = child.maxRows @@ -125,7 +126,7 @@ case class Generate( } case class Filter(condition: Expression, child: LogicalPlan) - extends UnaryNode with PredicateHelper { + extends OrderPreservingUnaryNode with PredicateHelper { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows @@ -469,6 +470,7 @@ case class Sort( child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = child.maxRows + override def outputOrdering: Seq[SortOrder] = order } /** Factory for constructing new `Range` nodes. */ @@ -522,6 +524,15 @@ case class Range( override def computeStats(): Statistics = { Statistics(sizeInBytes = LongType.defaultSize * numElements) } + + override def outputOrdering: Seq[SortOrder] = { + val order = if (step > 0) { + Ascending + } else { + Descending + } + output.map(a => SortOrder(a, order)) + } } case class Aggregate( @@ -728,7 +739,7 @@ object Limit { * * See [[Limit]] for more information. */ -case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { limitExpr match { @@ -744,7 +755,7 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN * * See [[Limit]] for more information. 
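The pieces above fit together: logical operators now report an outputOrdering, order-preserving unary nodes forward their child's, and the RemoveRedundantSorts rule added to the optimizer earlier in this diff drops a Sort whose requested ordering is already satisfied. A toy, self-contained model of that interaction (string orderings and plain prefix matching stand in for SortOrder.orderingSatisfies):

```scala
// Toy model only: Catalyst compares SortOrder expressions semantically; here a
// requested ordering is "satisfied" when it is a prefix of the child's ordering.
sealed trait Plan { def outputOrdering: Seq[String] }
case class Relation(outputOrdering: Seq[String] = Nil) extends Plan
case class Sort(order: Seq[String], child: Plan) extends Plan {
  override def outputOrdering: Seq[String] = order
}

def removeRedundantSorts(plan: Plan): Plan = plan match {
  case Sort(order, child) if child.outputOrdering.startsWith(order) => removeRedundantSorts(child)
  case Sort(order, child) => Sort(order, removeRedundantSorts(child))
  case other => other
}

// The outer sort asks for a prefix of what the inner sort already guarantees, so it is dropped.
val plan = Sort(Seq("a ASC"), Sort(Seq("a ASC", "b DESC"), Relation()))
assert(removeRedundantSorts(plan) == Sort(Seq("a ASC", "b DESC"), Relation()))
```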
*/ -case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { +case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends OrderPreservingUnaryNode { override def output: Seq[Attribute] = child.output override def maxRowsPerPartition: Option[Long] = { @@ -764,7 +775,7 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo case class SubqueryAlias( alias: String, child: LogicalPlan) - extends UnaryNode { + extends OrderPreservingUnaryNode { override def doCanonicalize(): LogicalPlan = child.canonicalized diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 1c8ab9c62623e..0dc47bfe075d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -930,6 +930,13 @@ object SQLConf { .intConf .createWithDefault(100) + val STREAMING_CHECKPOINT_FILE_MANAGER_CLASS = + buildConf("spark.sql.streaming.checkpointFileManagerClass") + .doc("The class used to write checkpoint files atomically. This class must be a subclass " + + "of the interface CheckpointFileManager.") + .internal() + .stringConf + val NDV_MAX_ERROR = buildConf("spark.sql.statistics.ndv.maxError") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index ac1c18ecf625a..a64172a53e17b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -106,6 +106,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + test("Array max") { + checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10) + checkEvaluation( + ArrayMax(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "abc") + checkEvaluation(ArrayMax(Literal.create(Seq(null), ArrayType(LongType))), null) + checkEvaluation(ArrayMax(Literal.create(null, ArrayType(StringType))), null) + checkEvaluation( + ArrayMax(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 1.123) + } + test("Reverse") { // Primitive-type elements val ai0 = Literal.create(Seq(2, 1, 4, 3), ArrayType(IntegerType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 786266a2c13c0..080ec487cfa6a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -211,6 +211,17 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(DayOfWeek, DateType) } + test("WeekDay") { + checkEvaluation(WeekDay(Literal.create(null, DateType)), null) + checkEvaluation(WeekDay(Literal(d)), 2) + checkEvaluation(WeekDay(Cast(Literal(sdfDate.format(d)), DateType, gmtId)), 2) + checkEvaluation(WeekDay(Cast(Literal(ts), DateType, gmtId)), 4) + 
checkEvaluation(WeekDay(Cast(Literal("2011-05-06"), DateType, gmtId)), 4) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("2017-05-27 13:10:15").getTime))), 5) + checkEvaluation(WeekDay(Literal(new Date(sdf.parse("1582-10-15 13:10:15").getTime))), 4) + checkConsistencyBetweenInterpretedAndCodegen(WeekDay, DateType) + } + test("WeekOfYear") { checkEvaluation(WeekOfYear(Literal.create(null, DateType)), null) checkEvaluation(WeekOfYear(Literal(d)), 15) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala new file mode 100644 index 0000000000000..2319ab8046e56 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RemoveRedundantSortsSuite.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.{Analyzer, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{CASE_SENSITIVE, ORDER_BY_ORDINAL} + +class RemoveRedundantSortsSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Remove Redundant Sorts", Once, + RemoveRedundantSorts) :: + Batch("Collapse Project", Once, + CollapseProject) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + + test("remove redundant order by") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val unnecessaryReordered = orderedPlan.select('a).orderBy('a.asc, 'b.desc_nullsFirst) + val optimized = Optimize.execute(unnecessaryReordered.analyze) + val correctAnswer = orderedPlan.select('a).analyze + comparePlans(Optimize.execute(optimized), correctAnswer) + } + + test("do not remove sort if the order is different") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc_nullsFirst) + val reorderedDifferently = orderedPlan.select('a).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(reorderedDifferently.analyze) + val correctAnswer = reorderedDifferently.analyze + comparePlans(optimized, correctAnswer) + } + + test("filters don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val 
filteredAndReordered = orderedPlan.where('a > Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.where('a > Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("limits don't affect order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc, 'b.desc) + val filteredAndReordered = orderedPlan.limit(Literal(10)).orderBy('a.asc, 'b.desc) + val optimized = Optimize.execute(filteredAndReordered.analyze) + val correctAnswer = orderedPlan.limit(Literal(10)).analyze + comparePlans(optimized, correctAnswer) + } + + test("range is already sorted") { + val inputPlan = Range(1L, 1000L, 1, 10) + val orderedPlan = inputPlan.orderBy('id.asc) + val optimized = Optimize.execute(orderedPlan.analyze) + val correctAnswer = inputPlan.analyze + comparePlans(optimized, correctAnswer) + + val reversedPlan = inputPlan.orderBy('id.desc) + val reversedOptimized = Optimize.execute(reversedPlan.analyze) + val reversedCorrectAnswer = reversedPlan.analyze + comparePlans(reversedOptimized, reversedCorrectAnswer) + + val negativeStepInputPlan = Range(10L, 1L, -1, 10) + val negativeStepOrderedPlan = negativeStepInputPlan.orderBy('id.desc) + val negativeStepOptimized = Optimize.execute(negativeStepOrderedPlan.analyze) + val negativeStepCorrectAnswer = negativeStepInputPlan.analyze + comparePlans(negativeStepOptimized, negativeStepCorrectAnswer) + } + + test("sort should not be removed when there is a node which doesn't guarantee any order") { + val orderedPlan = testRelation.select('a, 'b).orderBy('a.asc) + val groupedAndResorted = orderedPlan.groupBy('a)(sum('a)).orderBy('a.asc) + val optimized = Optimize.execute(groupedAndResorted.analyze) + val correctAnswer = groupedAndResorted.analyze + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 0aee1d7be5788..917168162b236 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3189,10 +3189,10 @@ class Dataset[T] private[sql]( private[sql] def collectToPython(): Int = { EvaluatePython.registerPicklers() - withNewExecutionId { + withAction("collectToPython", queryExecution) { plan => val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) - val iter = new SerDeUtil.AutoBatchedPickler( - queryExecution.executedPlan.executeCollect().iterator.map(toJava)) + val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler( + plan.executeCollect().iterator.map(toJava)) PythonRDD.serveIterator(iter, "serve-DataFrame") } } @@ -3201,8 +3201,9 @@ class Dataset[T] private[sql]( * Collect a Dataset as ArrowPayload byte arrays and serve to PySpark. 
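The suite above drives the rule through the Catalyst DSL; the following hedged sketch (again assuming a SparkSession `spark` with this patch) shows the user-visible effect: projections, filters and limits now report their child's ordering, so an outer sort is only kept when it asks for a genuinely different ordering.

import spark.implicits._

val df = spark.range(0L, 100L).select($"id".as("a"), ($"id" % 10).as("b"))

// Same ordering on both sorts with only order-preserving operators in between:
// the outer Sort should be removed by RemoveRedundantSorts.
df.orderBy($"a".asc).where($"a" > 10).limit(50).orderBy($"a".asc).explain(true)

// Different ordering: the second Sort has to stay.
df.orderBy($"a".asc).orderBy($"b".desc).explain(true)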
*/ private[sql] def collectAsArrowToPython(): Int = { - withNewExecutionId { - val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable) + withAction("collectAsArrowToPython", queryExecution) { plan => + val iter: Iterator[Array[Byte]] = + toArrowPayload(plan).collect().iterator.map(_.asPythonSerializable) PythonRDD.serveIterator(iter, "serve-Arrow") } } @@ -3311,14 +3312,19 @@ class Dataset[T] private[sql]( } /** Convert to an RDD of ArrowPayload byte arrays */ - private[sql] def toArrowPayload: RDD[ArrowPayload] = { + private[sql] def toArrowPayload(plan: SparkPlan): RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone - queryExecution.toRdd.mapPartitionsInternal { iter => + plan.execute().mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toPayloadIterator( iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } + + // This is only used in tests, for now. + private[sql] def toArrowPayload: RDD[ArrowPayload] = { + toArrowPayload(queryExecution.executedPlan) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index b107492fbb330..c502e583a55c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -44,7 +44,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.Utils +import org.apache.spark.util.{CallSite, Utils} /** @@ -81,6 +81,9 @@ class SparkSession private( @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => + // The call site where this SparkSession was constructed. + private val creationSite: CallSite = Utils.getCallSite() + private[sql] def this(sc: SparkContext) { this(sc, None, None, new SparkSessionExtensions) } @@ -763,7 +766,7 @@ class SparkSession private( @InterfaceStability.Stable -object SparkSession { +object SparkSession extends Logging { /** * Builder for [[SparkSession]]. @@ -1090,4 +1093,20 @@ object SparkSession { } } + private[spark] def cleanupAnyExistingSession(): Unit = { + val session = getActiveSession.orElse(getDefaultSession) + if (session.isDefined) { + logWarning( + s"""An existing Spark session exists as the active or default session. + |This probably means another suite leaked it. Attempting to stop it before continuing. 
+ |This existing Spark session was created at: + | + |${session.get.creationSite.longForm} + | + """.stripMargin) + session.get.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index d68aeb275afda..a8794be7280c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -99,7 +99,7 @@ class CacheManager extends Logging { sparkSession.sessionState.conf.columnBatchSize, storageLevel, sparkSession.sessionState.executePlan(planToCache).executedPlan, tableName, - planToCache.stats) + planToCache) cachedData.add(CachedData(planToCache, inMemoryRelation)) } } @@ -148,7 +148,7 @@ class CacheManager extends Logging { storageLevel = cd.cachedRepresentation.storageLevel, child = spark.sessionState.executePlan(cd.plan).executedPlan, tableName = cd.cachedRepresentation.tableName, - statsOfPlanToCache = cd.plan.stats) + logicalPlan = cd.plan) needToRecache += cd.copy(cachedRepresentation = newCache) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index f3555508185fe..be50a1571a2ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -125,7 +125,7 @@ case class LogicalRDD( output: Seq[Attribute], rdd: RDD[InternalRow], outputPartitioning: Partitioning = UnknownPartitioning(0), - outputOrdering: Seq[SortOrder] = Nil, + override val outputOrdering: Seq[SortOrder] = Nil, override val isStreaming: Boolean = false)(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 2579046e30708..a7ba9b86a176f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, Statistics} +import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, LogicalPlan, Statistics} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.storage.StorageLevel import org.apache.spark.util.LongAccumulator @@ -39,9 +39,9 @@ object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String], - statsOfPlanToCache: Statistics): InMemoryRelation = + logicalPlan: LogicalPlan): InMemoryRelation = new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)( - statsOfPlanToCache = statsOfPlanToCache) + statsOfPlanToCache = logicalPlan.stats, outputOrdering = logicalPlan.outputOrdering) } @@ -64,7 +64,8 @@ case class InMemoryRelation( tableName: Option[String])( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val sizeInBytesStats: LongAccumulator = 
child.sqlContext.sparkContext.longAccumulator, - statsOfPlanToCache: Statistics) + statsOfPlanToCache: Statistics, + override val outputOrdering: Seq[SortOrder]) extends logical.LeafNode with MultiInstanceRelation { override protected def innerChildren: Seq[SparkPlan] = Seq(child) @@ -76,7 +77,8 @@ case class InMemoryRelation( tableName = None)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache) + statsOfPlanToCache, + outputOrdering) override def producedAttributes: AttributeSet = outputSet @@ -159,7 +161,7 @@ case class InMemoryRelation( def withOutput(newOutput: Seq[Attribute]): InMemoryRelation = { InMemoryRelation( newOutput, useCompression, batchSize, storageLevel, child, tableName)( - _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache) + _cachedColumnBuffers, sizeInBytesStats, statsOfPlanToCache, outputOrdering) } override def newInstance(): this.type = { @@ -172,7 +174,8 @@ case class InMemoryRelation( tableName)( _cachedColumnBuffers, sizeInBytesStats, - statsOfPlanToCache).asInstanceOf[this.type] + statsOfPlanToCache, + outputOrdering).asInstanceOf[this.type] } def cachedColumnBuffers: RDD[CachedBatch] = _cachedColumnBuffers diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 4046396d0e614..a66a07673e25f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -85,7 +85,7 @@ class CatalogFileIndex( sparkSession, new Path(baseLocation.get), fileStatusCache, partitionSpec, Option(timeNs)) } else { new InMemoryFileIndex( - sparkSession, rootPaths, table.storage.properties, partitionSchema = None) + sparkSession, rootPaths, table.storage.properties, userSpecifiedSchema = None) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b84ea769808f9..f16d824201e77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -23,7 +23,6 @@ import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} import scala.util.{Failure, Success, Try} -import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil @@ -103,24 +102,6 @@ case class DataSource( bucket.sortColumnNames, "in the sort definition", equality) } - /** - * In the read path, only managed tables by Hive provide the partition columns properly when - * initializing this class. All other file based data sources will try to infer the partitioning, - * and then cast the inferred types to user specified dataTypes if the partition columns exist - * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or - * inconsistent data types as reported in SPARK-21463. 
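The CacheManager/InMemoryRelation change above passes the whole logical plan rather than just its statistics, so the cached relation can also report the plan's outputOrdering. A hedged sketch of the intended effect, assuming a SparkSession `spark` with this patch:

import spark.implicits._

// Cache a sorted Dataset; the InMemoryRelation built for it now carries the child's ordering.
val sorted = spark.range(0L, 100L).orderBy($"id".desc)
sorted.cache().count()                    // materialize the cached relation

// Re-sorting the cached data with the same ordering should now be recognized as redundant.
sorted.orderBy($"id".desc).explain(true)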
- * @param fileIndex A FileIndex that will perform partition inference - * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` - */ - private def combineInferredAndUserSpecifiedPartitionSchema(fileIndex: FileIndex): StructType = { - val resolved = fileIndex.partitionSchema.map { partitionField => - // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred - userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( - partitionField) - } - StructType(resolved) - } - /** * Get the schema of the given FileFormat, if provided by `userSpecifiedSchema`, or try to infer * it. In the read path, only managed tables by Hive provide the partition columns properly when @@ -140,31 +121,26 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource - * @param fileStatusCache the shared cache for file statuses to speed up listing + * @param fileIndex optional [[InMemoryFileIndex]] for getting partition schema and file list * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. */ private def getOrInferFileFormatSchema( format: FileFormat, - fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { - // the operations below are expensive therefore try not to do them if we don't need to, e.g., + fileIndex: Option[InMemoryFileIndex] = None): (StructType, StructType) = { + // The operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below - lazy val tempFileIndex = { - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.toSeq.flatMap { path => - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) + lazy val tempFileIndex = fileIndex.getOrElse { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = false, checkFilesExist = false) + createInMemoryFileIndex(globbedPaths) } + val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning // columns properly unless it is a Hive DataSource - combineInferredAndUserSpecifiedPartitionSchema(tempFileIndex) + tempFileIndex.partitionSchema } else { // maintain old behavior before SPARK-18510. 
If userSpecifiedSchema is empty used inferred // partitioning @@ -356,13 +332,7 @@ case class DataSource( caseInsensitiveOptions.get("path").toSeq ++ paths, sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) - val tempFileCatalog = new MetadataLogFileIndex(sparkSession, basePath, None) - val fileCatalog = if (userSpecifiedSchema.nonEmpty) { - val partitionSchema = combineInferredAndUserSpecifiedPartitionSchema(tempFileCatalog) - new MetadataLogFileIndex(sparkSession, basePath, Option(partitionSchema)) - } else { - tempFileCatalog - } + val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema) val dataSchema = userSpecifiedSchema.orElse { format.inferSchema( sparkSession, @@ -384,24 +354,23 @@ case class DataSource( // This is a non-streaming file based datasource. case (format: FileFormat, _) => - val allPaths = caseInsensitiveOptions.get("path") ++ paths - val hadoopConf = sparkSession.sessionState.newHadoopConf() - val globbedPaths = allPaths.flatMap( - DataSource.checkAndGlobPathIfNecessary(hadoopConf, _, checkFilesExist)).toArray - - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) - - val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && - catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { + val globbedPaths = + checkAndGlobPathIfNecessary(checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) + val useCatalogFileIndex = sparkSession.sqlContext.conf.manageFilesourcePartitions && + catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog && + catalogTable.get.partitionColumnNames.nonEmpty + val (fileCatalog, dataSchema, partitionSchema) = if (useCatalogFileIndex) { val defaultTableSize = sparkSession.sessionState.conf.defaultSizeInBytes - new CatalogFileIndex( + val index = new CatalogFileIndex( sparkSession, catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) + (index, catalogTable.get.dataSchema, catalogTable.get.partitionSchema) } else { - new InMemoryFileIndex( - sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache) + val index = createInMemoryFileIndex(globbedPaths) + val (resultDataSchema, resultPartitionSchema) = + getOrInferFileFormatSchema(format, Some(index)) + (index, resultDataSchema, resultPartitionSchema) } HadoopFsRelation( @@ -552,6 +521,40 @@ case class DataSource( sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.") } } + + /** Returns an [[InMemoryFileIndex]] that can be used to get partition schema and file list. */ + private def createInMemoryFileIndex(globbedPaths: Seq[Path]): InMemoryFileIndex = { + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex( + sparkSession, globbedPaths, options, userSpecifiedSchema, fileStatusCache) + } + + /** + * Checks and returns files in all the paths. 
+ */ + private def checkAndGlobPathIfNecessary( + checkEmptyGlobPath: Boolean, + checkFilesExist: Boolean): Seq[Path] = { + val allPaths = caseInsensitiveOptions.get("path") ++ paths + val hadoopConf = sparkSession.sessionState.newHadoopConf() + allPaths.flatMap { path => + val hdfsPath = new Path(path) + val fs = hdfsPath.getFileSystem(hadoopConf) + val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) + + if (checkEmptyGlobPath && globPath.isEmpty) { + throw new AnalysisException(s"Path does not exist: $qualified") + } + + // Sufficient to check head of the globPath seq for non-glob scenario + // Don't need to check once again if files exist in streaming mode + if (checkFilesExist && !fs.exists(globPath.head)) { + throw new AnalysisException(s"Path does not exist: ${globPath.head}") + } + globPath + }.toSeq + } } object DataSource extends Logging { @@ -699,30 +702,6 @@ object DataSource extends Logging { locationUri = path.map(CatalogUtils.stringToURI), properties = optionsWithoutPath) } - /** - * If `path` is a file pattern, return all the files that match it. Otherwise, return itself. - * If `checkFilesExist` is `true`, also check the file existence. - */ - private def checkAndGlobPathIfNecessary( - hadoopConf: Configuration, - path: String, - checkFilesExist: Boolean): Seq[Path] = { - val hdfsPath = new Path(path) - val fs = hdfsPath.getFileSystem(hadoopConf) - val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - val globPath = SparkHadoopUtil.get.globPathIfNecessary(fs, qualified) - - if (globPath.isEmpty) { - throw new AnalysisException(s"Path does not exist: $qualified") - } - // Sufficient to check head of the globPath seq for non-glob scenario - // Don't need to check once again if files exist in streaming mode - if (checkFilesExist && !fs.exists(globPath.head)) { - throw new AnalysisException(s"Path does not exist: ${globPath.head}") - } - globPath - } - /** * Called before writing into a FileFormat based data source to make sure the * supplied schema is not empty. 
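The consolidated checkAndGlobPathIfNecessary above qualifies each input path against its FileSystem, expands globs, and fails fast when nothing matches. Below is a standalone sketch of that qualify-then-glob pattern using plain Hadoop APIs (it is not Spark's helper itself, and the path value is hypothetical).

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, Path}

val hadoopConf = new Configuration()
val raw = new Path("/tmp/data/part-*")                  // hypothetical user-supplied path
val fs = raw.getFileSystem(hadoopConf)
val qualified = raw.makeQualified(fs.getUri, fs.getWorkingDirectory)

// globStatus returns null when the pattern matches nothing, so wrap it in Option first.
val matched: Seq[Path] =
  Option(fs.globStatus(qualified)).map(_.toSeq.map((s: FileStatus) => s.getPath)).getOrElse(Seq.empty)
if (matched.isEmpty) {
  sys.error(s"Path does not exist: $qualified")         // mirrors the AnalysisException above
}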
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 318ada0ceefc5..739d1f456e3ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -41,17 +41,17 @@ import org.apache.spark.util.SerializableConfiguration * @param rootPathsSpecified the list of root table paths to scan (some of which might be * filtered out later) * @param parameters as set of options to control discovery - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions + * @param userSpecifiedSchema an optional user-specified schema that will be used to provide + * types for the discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, rootPathsSpecified: Seq[Path], parameters: Map[String, String], - partitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( - sparkSession, parameters, partitionSchema, fileStatusCache) { + sparkSession, parameters, userSpecifiedSchema, fileStatusCache) { // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 6b6f6388d54e8..cc8af7b92c454 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -34,13 +34,13 @@ import org.apache.spark.sql.types.{StringType, StructType} * It provides the necessary methods to parse partition data based on a set of files.
* * @param parameters as set of options to control partition discovery - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ abstract class PartitioningAwareFileIndex( sparkSession: SparkSession, parameters: Map[String, String], - userPartitionSchema: Option[StructType], + userSpecifiedSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends FileIndex with Logging { import PartitioningAwareFileIndex.BASE_PATH_PARAM @@ -126,35 +126,32 @@ abstract class PartitioningAwareFileIndex( val caseInsensitiveOptions = CaseInsensitiveMap(parameters) val timeZoneId = caseInsensitiveOptions.get(DateTimeUtils.TIMEZONE_OPTION) .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) - - userPartitionSchema match { + val inferredPartitionSpec = PartitioningUtils.parsePartitions( + leafDirs, + typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, + basePaths = basePaths, + timeZoneId = timeZoneId) + userSpecifiedSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - typeInference = false, - basePaths = basePaths, - timeZoneId = timeZoneId) + val userPartitionSchema = + combineInferredAndUserSpecifiedPartitionSchema(inferredPartitionSpec) - // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => + val dt = inferredPartitionSpec.partitionColumns.fields(i).dataType Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType, + Literal.create(row.get(i, dt), dt), + userPartitionSchema.fields(i).dataType, Option(timeZoneId)).eval() }: _*) } - PartitionSpec(userProvidedSchema, spec.partitions.map { part => + PartitionSpec(userPartitionSchema, inferredPartitionSpec.partitions.map { part => part.copy(values = castPartitionValuesToUserSchema(part.values)) }) case _ => - PartitioningUtils.parsePartitions( - leafDirs, - typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths, - timeZoneId = timeZoneId) + inferredPartitionSpec } } @@ -236,6 +233,25 @@ abstract class PartitioningAwareFileIndex( val name = path.getName !((name.startsWith("_") && !name.contains("=")) || name.startsWith(".")) } + + /** + * In the read path, only managed tables by Hive provide the partition columns properly when + * initializing this class. All other file based data sources will try to infer the partitioning, + * and then cast the inferred types to user specified dataTypes if the partition columns exist + * inside `userSpecifiedSchema`, otherwise we can hit data corruption bugs like SPARK-18510, or + * inconsistent data types as reported in SPARK-21463. 
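A hedged illustration of the behavior change in inferPartitioning above: partition values are now inferred with type inference enabled and then cast to the user-specified type, instead of always being parsed as plain strings. The path and column names are made up, and `spark` is an assumed SparkSession.

import spark.implicits._
import org.apache.spark.sql.types.StructType

// Write a table partitioned by an integer-looking column (hypothetical location).
spark.range(10L).withColumn("part", $"id" % 2)
  .write.mode("overwrite").partitionBy("part").parquet("/tmp/partition_cast_demo")

// Ask for the partition column as STRING: the inferred partition values are cast to the
// user-specified type, per combineInferredAndUserSpecifiedPartitionSchema below.
val userSchema = new StructType().add("id", "long").add("part", "string")
spark.read.schema(userSchema).parquet("/tmp/partition_cast_demo").printSchema()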
+ * @param spec A partition inference result + * @return The PartitionSchema resolved from inference and cast according to `userSpecifiedSchema` + */ + private def combineInferredAndUserSpecifiedPartitionSchema(spec: PartitionSpec): StructType = { + val equality = sparkSession.sessionState.conf.resolver + val resolved = spec.partitionColumns.map { partitionField => + // SPARK-18510: try to get schema from userSpecifiedSchema, otherwise fallback to inferred + userSpecifiedSchema.flatMap(_.find(f => equality(f.name, partitionField.name))).getOrElse( + partitionField) + } + StructType(resolved) + } } object PartitioningAwareFileIndex { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala new file mode 100644 index 0000000000000..606ba250ad9d2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManager.scala @@ -0,0 +1,349 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import java.io.{FileNotFoundException, IOException, OutputStream} +import java.util.{EnumSet, UUID} + +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ +import org.apache.hadoop.fs.local.{LocalFs, RawLocalFs} +import org.apache.hadoop.fs.permission.FsPermission + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.RenameHelperMethods +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.util.Utils + +/** + * An interface to abstract out all operation related to streaming checkpoints. Most importantly, + * the key operation this interface provides is `createAtomic(path, overwrite)` which returns a + * `CancellableFSDataOutputStream`. This method is used by [[HDFSMetadataLog]] and + * [[org.apache.spark.sql.execution.streaming.state.StateStore StateStore]] implementations + * to write a complete checkpoint file atomically (i.e. no partial file will be visible), with or + * without overwrite. + * + * This higher-level interface above the Hadoop FileSystem is necessary because + * different implementation of FileSystem/FileContext may have different combination of operations + * to provide the desired atomic guarantees (e.g. write-to-temp-file-and-rename, + * direct-write-and-cancel-on-failure) and this abstraction allow different implementations while + * keeping the usage simple (`createAtomic` -> `close` or `cancel`). 
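A hedged usage sketch of the contract just described (createAtomic, then close on success or cancel on failure); the checkpoint directory and file name below are hypothetical.

import java.nio.charset.StandardCharsets
import scala.util.control.NonFatal
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.execution.streaming.CheckpointFileManager

val checkpointDir = new Path("/tmp/streaming-checkpoint/offsets")   // hypothetical directory
val fm = CheckpointFileManager.create(checkpointDir, new Configuration())
fm.mkdirs(checkpointDir)

val batchFile = new Path(checkpointDir, "42")
val out = fm.createAtomic(batchFile, overwriteIfPossible = false)
try {
  out.write("v1".getBytes(StandardCharsets.UTF_8))
  out.close()          // the contents become visible only here
} catch {
  case NonFatal(e) =>
    out.cancel()       // no partial or temporary file should be left behind
    throw e
}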
+ */ +trait CheckpointFileManager { + + import org.apache.spark.sql.execution.streaming.CheckpointFileManager._ + + /** + * Create a file and make its contents available atomically after the output stream is closed. + * + * @param path Path to create + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file already exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def createAtomic(path: Path, overwriteIfPossible: Boolean): CancellableFSDataOutputStream + + /** Open a file for reading, or throw exception if it does not exist. */ + def open(path: Path): FSDataInputStream + + /** List the files in a path that match a filter. */ + def list(path: Path, filter: PathFilter): Array[FileStatus] + + /** List all the files in a path. */ + def list(path: Path): Array[FileStatus] = { + list(path, new PathFilter { override def accept(path: Path): Boolean = true }) + } + + /** Make directory at the given path and all its parent directories as needed. */ + def mkdirs(path: Path): Unit + + /** Whether path exists. */ + def exists(path: Path): Boolean + + /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ + def delete(path: Path): Unit + + /** Whether the default file system this implementation operates on is the local file system. */ + def isLocal: Boolean +} + +object CheckpointFileManager extends Logging { + + /** + * Additional methods in CheckpointFileManager implementations that allow + * [[RenameBasedFSDataOutputStream]] to get atomicity by write-to-temp-file-and-rename + */ + sealed trait RenameHelperMethods { self: CheckpointFileManager => + /** Create a file with overwrite. */ + def createTempFile(path: Path): FSDataOutputStream + + /** + * Rename a file. + * + * @param srcPath Source path to rename + * @param dstPath Destination path to rename to + * @param overwriteIfPossible If true, then the implementations must do a best-effort attempt to + * overwrite the file if it already exists. It should not throw + * any exception if the file exists. However, if false, then the + * implementation must not overwrite if the file already exists and + * must throw `FileAlreadyExistsException` in that case. + */ + def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit + } + + /** + * An interface to add the cancel() operation to [[FSDataOutputStream]]. This is used + * mainly by `CheckpointFileManager.createAtomic` to write a file atomically. + * + * @see [[CheckpointFileManager]]. + */ + abstract class CancellableFSDataOutputStream(protected val underlyingStream: OutputStream) + extends FSDataOutputStream(underlyingStream, null) { + /** Cancel the `underlyingStream` and ensure that the output file is not generated. */ + def cancel(): Unit + } + + /** + * An implementation of [[CancellableFSDataOutputStream]] that writes a file atomically by writing + * to a temporary file and then renames.
+ */ + sealed class RenameBasedFSDataOutputStream( + fm: CheckpointFileManager with RenameHelperMethods, + finalPath: Path, + tempPath: Path, + overwriteIfPossible: Boolean) + extends CancellableFSDataOutputStream(fm.createTempFile(tempPath)) { + + def this(fm: CheckpointFileManager with RenameHelperMethods, path: Path, overwrite: Boolean) = { + this(fm, path, generateTempPath(path), overwrite) + } + + logInfo(s"Writing atomically to $finalPath using temp file $tempPath") + @volatile private var terminated = false + + override def close(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + try { + fm.renameTempFile(tempPath, finalPath, overwriteIfPossible) + } catch { + case fe: FileAlreadyExistsException => + logWarning( + s"Failed to rename temp file $tempPath to $finalPath because file exists", fe) + if (!overwriteIfPossible) throw fe + } + logInfo(s"Renamed temp file $tempPath to $finalPath") + } finally { + terminated = true + } + } + + override def cancel(): Unit = synchronized { + try { + if (terminated) return + underlyingStream.close() + fm.delete(tempPath) + } catch { + case NonFatal(e) => + logWarning(s"Error cancelling write to $finalPath", e) + } finally { + terminated = true + } + } + } + + + /** Create an instance of [[CheckpointFileManager]] based on the path and configuration. */ + def create(path: Path, hadoopConf: Configuration): CheckpointFileManager = { + val fileManagerClass = hadoopConf.get( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key) + if (fileManagerClass != null) { + return Utils.classForName(fileManagerClass) + .getConstructor(classOf[Path], classOf[Configuration]) + .newInstance(path, hadoopConf) + .asInstanceOf[CheckpointFileManager] + } + try { + // Try to create a manager based on `FileContext` because HDFS's `FileContext.rename() + // gives atomic renames, which is what we rely on for the default implementation + // `CheckpointFileManager.createAtomic`. + new FileContextBasedCheckpointFileManager(path, hadoopConf) + } catch { + case e: UnsupportedFileSystemException => + logWarning( + "Could not use FileContext API for managing Structured Streaming checkpoint files at " + + s"$path. Using FileSystem API instead for managing log files. If the implementation " + + s"of FileSystem.rename() is not atomic, then the correctness and fault-tolerance of" + + s"your Structured Streaming is not guaranteed.") + new FileSystemBasedCheckpointFileManager(path, hadoopConf) + } + } + + private def generateTempPath(path: Path): Path = { + val tc = org.apache.spark.TaskContext.get + val tid = if (tc != null) ".TID" + tc.taskAttemptId else "" + new Path(path.getParent, s".${path.getName}.${UUID.randomUUID}${tid}.tmp") + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileSystem]] API. 
*/ +class FileSystemBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + protected val fs = path.getFileSystem(hadoopConf) + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fs.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fs.mkdirs(path, FsPermission.getDirDefault) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + fs.create(path, true) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fs.open(path) + } + + override def exists(path: Path): Boolean = { + try + return fs.getFileStatus(path) != null + catch { + case e: FileNotFoundException => + return false + } + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + if (!overwriteIfPossible && fs.exists(dstPath)) { + throw new FileAlreadyExistsException( + s"Failed to rename $srcPath to $dstPath as destination already exists") + } + + if (!fs.rename(srcPath, dstPath)) { + // FileSystem.rename() returning false is very ambiguous as it can be for many reasons. + // This tries to make a best effort attempt to return the most appropriate exception. + if (fs.exists(dstPath)) { + if (!overwriteIfPossible) { + throw new FileAlreadyExistsException(s"Failed to rename as $dstPath already exists") + } + } else if (!fs.exists(srcPath)) { + throw new FileNotFoundException(s"Failed to rename as $srcPath was not found") + } else { + val msg = s"Failed to rename temp file $srcPath to $dstPath as rename returned false" + logWarning(msg) + throw new IOException(msg) + } + } + } + + override def delete(path: Path): Unit = { + try { + fs.delete(path, true) + } catch { + case e: FileNotFoundException => + logInfo(s"Failed to delete $path as it does not exist") + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fs match { + case _: LocalFileSystem | _: RawLocalFileSystem => true + case _ => false + } +} + + +/** An implementation of [[CheckpointFileManager]] using Hadoop's [[FileContext]] API. 
*/ +class FileContextBasedCheckpointFileManager(path: Path, hadoopConf: Configuration) + extends CheckpointFileManager with RenameHelperMethods with Logging { + + import CheckpointFileManager._ + + private val fc = if (path.toUri.getScheme == null) { + FileContext.getFileContext(hadoopConf) + } else { + FileContext.getFileContext(path.toUri, hadoopConf) + } + + override def list(path: Path, filter: PathFilter): Array[FileStatus] = { + fc.util.listStatus(path, filter) + } + + override def mkdirs(path: Path): Unit = { + fc.mkdir(path, FsPermission.getDirDefault, true) + } + + override def createTempFile(path: Path): FSDataOutputStream = { + import CreateFlag._ + import Options._ + fc.create( + path, EnumSet.of(CREATE, OVERWRITE), CreateOpts.checksumParam(ChecksumOpt.createDisabled())) + } + + override def createAtomic( + path: Path, + overwriteIfPossible: Boolean): CancellableFSDataOutputStream = { + new RenameBasedFSDataOutputStream(this, path, overwriteIfPossible) + } + + override def open(path: Path): FSDataInputStream = { + fc.open(path) + } + + override def exists(path: Path): Boolean = { + fc.util.exists(path) + } + + override def renameTempFile(srcPath: Path, dstPath: Path, overwriteIfPossible: Boolean): Unit = { + import Options.Rename._ + fc.rename(srcPath, dstPath, if (overwriteIfPossible) OVERWRITE else NONE) + } + + + override def delete(path: Path): Unit = { + try { + fc.delete(path, true) + } catch { + case e: FileNotFoundException => + // ignore if file has already been deleted + } + } + + override def isLocal: Boolean = fc.getDefaultFileSystem match { + case _: LocalFs | _: RawLocalFs => true // LocalFs = RawLocalFs + ChecksumFs + case _ => false + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 00bc215a5dc8c..bd0a46115ceb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -57,10 +57,10 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(implicitly[ClassTag[T]].runtimeClass != classOf[Seq[_]], "Should not create a log with type Seq, use Arrays instead - see SPARK-17372") - import HDFSMetadataLog._ - val metadataPath = new Path(path) - protected val fileManager = createFileManager() + + protected val fileManager = + CheckpointFileManager.create(metadataPath, sparkSession.sessionState.newHadoopConf) if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) @@ -109,84 +109,31 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: require(metadata != null, "'null' metadata cannot written to a metadata log") get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written - writeBatch(batchId, metadata) + writeBatchToFile(metadata, batchIdToPath(batchId)) true } } - private def writeTempBatch(metadata: T): Option[Path] = { - while (true) { - val tempPath = new Path(metadataPath, s".${UUID.randomUUID.toString}.tmp") - try { - val output = fileManager.create(tempPath) - try { - serialize(metadata, output) - return Some(tempPath) - } finally { - output.close() - } - } catch { - case e: FileAlreadyExistsException => - // Failed to create "tempPath". There are two cases: - // 1. Someone is creating "tempPath" too. - // 2. This is a restart. 
"tempPath" has already been created but not moved to the final - // batch file (not committed). - // - // For both cases, the batch has not yet been committed. So we can retry it. - // - // Note: there is a potential risk here: if HDFSMetadataLog A is running, people can use - // the same metadata path to create "HDFSMetadataLog" and fail A. However, this is not a - // big problem because it requires the attacker must have the permission to write the - // metadata path. In addition, the old Streaming also have this issue, people can create - // malicious checkpoint files to crash a Streaming application too. - } - } - None - } - - /** - * Write a batch to a temp file then rename it to the batch file. + /** Write a batch to a temp file then rename it to the batch file. * * There may be multiple [[HDFSMetadataLog]] using the same metadata path. Although it is not a * valid behavior, we still need to prevent it from destroying the files. */ - private def writeBatch(batchId: Long, metadata: T): Unit = { - val tempPath = writeTempBatch(metadata).getOrElse( - throw new IllegalStateException(s"Unable to create temp batch file $batchId")) + private def writeBatchToFile(metadata: T, path: Path): Unit = { + val output = fileManager.createAtomic(path, overwriteIfPossible = false) try { - // Try to commit the batch - // It will fail if there is an existing file (someone has committed the batch) - logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") - fileManager.rename(tempPath, batchIdToPath(batchId)) - - // SPARK-17475: HDFSMetadataLog should not leak CRC files - // If the underlying filesystem didn't rename the CRC file, delete it. - val crcPath = new Path(tempPath.getParent(), s".${tempPath.getName()}.crc") - if (fileManager.exists(crcPath)) fileManager.delete(crcPath) + serialize(metadata, output) + output.close() } catch { case e: FileAlreadyExistsException => - // If "rename" fails, it means some other "HDFSMetadataLog" has committed the batch. - // So throw an exception to tell the user this is not a valid behavior. + output.cancel() + // If next batch file already exists, then another concurrently running query has + // written it. throw new ConcurrentModificationException( - s"Multiple HDFSMetadataLog are using $path", e) - } finally { - fileManager.delete(tempPath) - } - } - - /** - * @return the deserialized metadata in a batch file, or None if file not exist. - * @throws IllegalArgumentException when path does not point to a batch file. 
- */ - def get(batchFile: Path): Option[T] = { - if (fileManager.exists(batchFile)) { - if (isBatchFile(batchFile)) { - get(pathToBatchId(batchFile)) - } else { - throw new IllegalArgumentException(s"File ${batchFile} is not a batch file!") - } - } else { - None + s"Multiple streaming queries are concurrently using $path", e) + case e: Throwable => + output.cancel() + throw e } } @@ -219,7 +166,7 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) }.sorted - verifyBatchIds(batchIds, startId, endId) + HDFSMetadataLog.verifyBatchIds(batchIds, startId, endId) batchIds.map(batchId => (batchId, get(batchId))).filter(_._2.isDefined).map { case (batchId, metadataOption) => @@ -280,19 +227,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: } } - private def createFileManager(): FileManager = { - val hadoopConf = sparkSession.sessionState.newHadoopConf() - try { - new FileContextManager(metadataPath, hadoopConf) - } catch { - case e: UnsupportedFileSystemException => - logWarning("Could not use FileContext API for managing metadata log files at path " + - s"$metadataPath. Using FileSystem API instead for managing log files. The log may be " + - s"inconsistent under failures.") - new FileSystemManager(metadataPath, hadoopConf) - } - } - /** * Parse the log version from the given `text` -- will throw exception when the parsed version * exceeds `maxSupportedVersion`, or when `text` is malformed (such as "xyz", "v", "v-1", @@ -327,135 +261,6 @@ class HDFSMetadataLog[T <: AnyRef : ClassTag](sparkSession: SparkSession, path: object HDFSMetadataLog { - /** A simple trait to abstract out the file management operations needed by HDFSMetadataLog. */ - trait FileManager { - - /** List the files in a path that match a filter. */ - def list(path: Path, filter: PathFilter): Array[FileStatus] - - /** Make directory at the give path and all its parent directories as needed. */ - def mkdirs(path: Path): Unit - - /** Whether path exists */ - def exists(path: Path): Boolean - - /** Open a file for reading, or throw exception if it does not exist. */ - def open(path: Path): FSDataInputStream - - /** Create path, or throw exception if it already exists */ - def create(path: Path): FSDataOutputStream - - /** - * Atomically rename path, or throw exception if it cannot be done. - * Should throw FileNotFoundException if srcPath does not exist. - * Should throw FileAlreadyExistsException if destPath already exists. - */ - def rename(srcPath: Path, destPath: Path): Unit - - /** Recursively delete a path if it exists. Should not throw exception if file doesn't exist. */ - def delete(path: Path): Unit - } - - /** - * Default implementation of FileManager using newer FileContext API. 
- */ - class FileContextManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fc = if (path.toUri.getScheme == null) { - FileContext.getFileContext(hadoopConf) - } else { - FileContext.getFileContext(path.toUri, hadoopConf) - } - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fc.util.listStatus(path, filter) - } - - override def rename(srcPath: Path, destPath: Path): Unit = { - fc.rename(srcPath, destPath) - } - - override def mkdirs(path: Path): Unit = { - fc.mkdir(path, FsPermission.getDirDefault, true) - } - - override def open(path: Path): FSDataInputStream = { - fc.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fc.create(path, EnumSet.of(CreateFlag.CREATE)) - } - - override def exists(path: Path): Boolean = { - fc.util().exists(path) - } - - override def delete(path: Path): Unit = { - try { - fc.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - - /** - * Implementation of FileManager using older FileSystem API. Note that this implementation - * cannot provide atomic renaming of paths, hence can lead to consistency issues. This - * should be used only as a backup option, when FileContextManager cannot be used. - */ - class FileSystemManager(path: Path, hadoopConf: Configuration) extends FileManager { - private val fs = path.getFileSystem(hadoopConf) - - override def list(path: Path, filter: PathFilter): Array[FileStatus] = { - fs.listStatus(path, filter) - } - - /** - * Rename a path. Note that this implementation is not atomic. - * @throws FileNotFoundException if source path does not exist. - * @throws FileAlreadyExistsException if destination path already exists. - * @throws IOException if renaming fails for some unknown reason. - */ - override def rename(srcPath: Path, destPath: Path): Unit = { - if (!fs.exists(srcPath)) { - throw new FileNotFoundException(s"Source path does not exist: $srcPath") - } - if (fs.exists(destPath)) { - throw new FileAlreadyExistsException(s"Destination path already exists: $destPath") - } - if (!fs.rename(srcPath, destPath)) { - throw new IOException(s"Failed to rename $srcPath to $destPath") - } - } - - override def mkdirs(path: Path): Unit = { - fs.mkdirs(path, FsPermission.getDirDefault) - } - - override def open(path: Path): FSDataInputStream = { - fs.open(path) - } - - override def create(path: Path): FSDataOutputStream = { - fs.create(path, false) - } - - override def exists(path: Path): Boolean = { - fs.exists(path) - } - - override def delete(path: Path): Unit = { - try { - fs.delete(path, true) - } catch { - case e: FileNotFoundException => - // ignore if file has already been deleted - } - } - } - /** * Verify if batchIds are continuous and between `startId` and `endId`. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index 1da703cefd8ea..5cacdd070b735 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -30,14 +30,14 @@ import org.apache.spark.sql.types.StructType * A [[FileIndex]] that generates the list of files to processing by reading them from the * metadata log files generated by the [[FileStreamSink]]. 
* - * @param userPartitionSchema an optional partition schema that will be use to provide types for - * the discovered partitions + * @param userSpecifiedSchema an optional user specified schema that will be use to provide + * types for the discovered partitions */ class MetadataLogFileIndex( sparkSession: SparkSession, path: Path, - userPartitionSchema: Option[StructType]) - extends PartitioningAwareFileIndex(sparkSession, Map.empty, userPartitionSchema) { + userSpecifiedSchema: Option[StructType]) + extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) { private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") @@ -51,7 +51,7 @@ class MetadataLogFileIndex( } override protected val leafDirToChildrenFiles: Map[Path, Array[FileStatus]] = { - allFilesFromLog.toArray.groupBy(_.getPath.getParent) + allFilesFromLog.groupBy(_.getPath.getParent) } override def rootPaths: Seq[Path] = path :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 3f5002a4e6937..df722b953228b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.streaming.state -import java.io.{DataInputStream, DataOutputStream, FileNotFoundException, IOException} +import java.io._ import java.nio.channels.ClosedChannelException import java.util.Locale @@ -27,13 +27,16 @@ import scala.util.Random import scala.util.control.NonFatal import com.google.common.io.ByteStreams +import org.apache.commons.io.IOUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs._ import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.io.LZ4CompressionCodec import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.streaming.CheckpointFileManager +import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream import org.apache.spark.sql.types.StructType import org.apache.spark.util.{SizeEstimator, Utils} @@ -87,10 +90,10 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case object ABORTED extends STATE private val newVersion = version + 1 - private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") - private lazy val tempDeltaFileStream = compressStream(fs.create(tempDeltaFile, true)) @volatile private var state: STATE = UPDATING - @volatile private var finalDeltaFile: Path = null + private val finalDeltaFile: Path = deltaFile(newVersion) + private lazy val deltaFileStream = fm.createAtomic(finalDeltaFile, overwriteIfPossible = true) + private lazy val compressedStream = compressStream(deltaFileStream) override def id: StateStoreId = HDFSBackedStateStoreProvider.this.stateStoreId @@ -103,14 +106,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val keyCopy = key.copy() val valueCopy = value.copy() mapToUpdate.put(keyCopy, valueCopy) - writeUpdateToDeltaFile(tempDeltaFileStream, keyCopy, valueCopy) + writeUpdateToDeltaFile(compressedStream, keyCopy, 
valueCopy) } override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") val prevValue = mapToUpdate.remove(key) if (prevValue != null) { - writeRemoveToDeltaFile(tempDeltaFileStream, key) + writeRemoveToDeltaFile(compressedStream, key) } } @@ -126,8 +129,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit verify(state == UPDATING, "Cannot commit after already committed or aborted") try { - finalizeDeltaFile(tempDeltaFileStream) - finalDeltaFile = commitUpdates(newVersion, mapToUpdate, tempDeltaFile) + commitUpdates(newVersion, mapToUpdate, compressedStream) state = COMMITTED logInfo(s"Committed version $newVersion for $this to file $finalDeltaFile") newVersion @@ -140,23 +142,14 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Abort all the updates made on this store. This store will not be usable any more. */ override def abort(): Unit = { - verify(state == UPDATING || state == ABORTED, "Cannot abort after already committed") - try { + // This if statement is to ensure that files are deleted only if there are changes to the + // StateStore. We have two StateStores for each task, one which is used only for reading, and + // the other used for read+write. We don't want the read-only to delete state files. + if (state == UPDATING) { + state = ABORTED + cancelDeltaFile(compressedStream, deltaFileStream) + } else { state = ABORTED - if (tempDeltaFileStream != null) { - tempDeltaFileStream.close() - } - if (tempDeltaFile != null) { - fs.delete(tempDeltaFile, true) - } - } catch { - case c: ClosedChannelException => - // This can happen when underlying file output stream has been closed before the - // compression stream. - logDebug(s"Error aborting version $newVersion into $this", c) - - case e: Exception => - logWarning(s"Error aborting version $newVersion into $this", e) } logInfo(s"Aborted version $newVersion for $this") } @@ -212,7 +205,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit this.valueSchema = valueSchema this.storeConf = storeConf this.hadoopConf = hadoopConf - fs.mkdirs(baseDir) + fm.mkdirs(baseDir) } override def stateStoreId: StateStoreId = stateStoreId_ @@ -251,31 +244,15 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit private lazy val loadedMaps = new mutable.HashMap[Long, MapType] private lazy val baseDir = stateStoreId.storeCheckpointLocation() - private lazy val fs = baseDir.getFileSystem(hadoopConf) + private lazy val fm = CheckpointFileManager.create(baseDir, hadoopConf) private lazy val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf) private case class StoreFile(version: Long, path: Path, isSnapshot: Boolean) - /** Commit a set of updates to the store with the given new version */ - private def commitUpdates(newVersion: Long, map: MapType, tempDeltaFile: Path): Path = { + private def commitUpdates(newVersion: Long, map: MapType, output: DataOutputStream): Unit = { synchronized { - val finalDeltaFile = deltaFile(newVersion) - - // scalastyle:off - // Renaming a file atop an existing one fails on HDFS - // (http://hadoop.apache.org/docs/stable/hadoop-project-dist/hadoop-common/filesystem/filesystem.html). - // Hence we should either skip the rename step or delete the target file. Because deleting the - // target file will break speculation, skipping the rename step is the only choice. 
It's still - semantically correct because Structured Streaming requires rerunning a batch should - generate the same output. (SPARK-19677) - // scalastyle:on - if (fs.exists(finalDeltaFile)) { - fs.delete(tempDeltaFile, true) - } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { - throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") - } + finalizeDeltaFile(output) loadedMaps.put(newVersion, map) - finalDeltaFile } } @@ -365,7 +342,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit val fileToRead = deltaFile(version) var input: DataInputStream = null val sourceStream = try { - fs.open(fileToRead) + fm.open(fileToRead) } catch { case f: FileNotFoundException => throw new IllegalStateException( @@ -412,12 +389,12 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } private def writeSnapshotFile(version: Long, map: MapType): Unit = { - val fileToWrite = snapshotFile(version) - val tempFile = - new Path(fileToWrite.getParent, s"${fileToWrite.getName}.temp-${Random.nextLong}") + val targetFile = snapshotFile(version) + var rawOutput: CancellableFSDataOutputStream = null var output: DataOutputStream = null - Utils.tryWithSafeFinally { - output = compressStream(fs.create(tempFile, false)) + try { + rawOutput = fm.createAtomic(targetFile, overwriteIfPossible = true) + output = compressStream(rawOutput) val iter = map.entrySet().iterator() while(iter.hasNext) { val entry = iter.next() @@ -429,16 +406,34 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit output.write(valueBytes) } output.writeInt(-1) - } { - if (output != null) output.close() + output.close() + } catch { + case e: Throwable => + cancelDeltaFile(compressedStream = output, rawStream = rawOutput) + throw e } - if (fs.exists(fileToWrite)) { - // Skip rename if the file is alreayd created. - fs.delete(tempFile, true) - } else if (!fs.rename(tempFile, fileToWrite)) { - throw new IOException(s"Failed to rename $tempFile to $fileToWrite") + logInfo(s"Written snapshot file for version $version of $this at $targetFile") + } + + /** + * Try to cancel the underlying stream and safely close the compressed stream. + * + * @param compressedStream the compressed stream. + * @param rawStream the underlying stream which needs to be cancelled. + */ + private def cancelDeltaFile( + compressedStream: DataOutputStream, + rawStream: CancellableFSDataOutputStream): Unit = { + try { + if (rawStream != null) rawStream.cancel() + IOUtils.closeQuietly(compressedStream) + } catch { + case e: FSError if e.getCause.isInstanceOf[IOException] => + // Closing the compressedStream causes the stream to write/flush data into the + // rawStream. Since the rawStream is already closed, there may be errors. + // Usually it's an IOException. However, Hadoop's RawLocalFileSystem wraps + // IOException into FSError.
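Note: the cancel-on-error handling above is also the pattern callers of the new CheckpointFileManager API are expected to follow: write through the stream returned by createAtomic(), close() it on success, and cancel() it on any failure. A minimal caller-side sketch, assuming only the createAtomic/cancel/close API introduced in this patch (the helper name and byte payload are illustrative):

  import org.apache.hadoop.conf.Configuration
  import org.apache.hadoop.fs.Path

  import org.apache.spark.sql.execution.streaming.CheckpointFileManager

  object AtomicCheckpointWriteExample {
    // Atomically writes `bytes` to `file`: the data only becomes visible on close(),
    // and a failure cancels the partially written output instead of leaving garbage behind.
    def write(file: Path, bytes: Array[Byte], hadoopConf: Configuration): Unit = {
      val fm = CheckpointFileManager.create(file.getParent, hadoopConf)
      val out = fm.createAtomic(file, overwriteIfPossible = false)
      try {
        out.write(bytes)
        out.close()
      } catch {
        case e: Throwable =>
          out.cancel()
          throw e
      }
    }
  }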
} - logInfo(s"Written snapshot file for version $version of $this at $fileToWrite") } private def readSnapshotFile(version: Long): Option[MapType] = { @@ -447,7 +442,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit var input: DataInputStream = null try { - input = decompressStream(fs.open(fileToRead)) + input = decompressStream(fm.open(fileToRead)) var eof = false while (!eof) { @@ -508,7 +503,6 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit case None => // The last map is not loaded, probably some other instance is in charge } - } } catch { case NonFatal(e) => @@ -534,7 +528,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit } val filesToDelete = files.filter(_.version < earliestFileToRetain.version) filesToDelete.foreach { f => - fs.delete(f.path, true) + fm.delete(f.path) } logInfo(s"Deleted files older than ${earliestFileToRetain.version} for $this: " + filesToDelete.mkString(", ")) @@ -576,7 +570,7 @@ private[state] class HDFSBackedStateStoreProvider extends StateStoreProvider wit /** Fetch all the files that back the store */ private def fetchFiles(): Seq[StoreFile] = { val files: Seq[FileStatus] = try { - fs.listStatus(baseDir) + fm.list(baseDir) } catch { case _: java.io.FileNotFoundException => Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d1d9f95cb0977..7eb68c21569ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -459,7 +459,6 @@ object StateStore extends Logging { private def coordinatorRef: Option[StateStoreCoordinatorRef] = loadedProviders.synchronized { val env = SparkEnv.get if (env != null) { - logInfo("Env is not null") val isDriver = env.executorId == SparkContext.DRIVER_IDENTIFIER || env.executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER @@ -467,13 +466,12 @@ object StateStore extends Logging { // as SparkContext + SparkEnv may have been restarted. Hence, when running in driver, // always recreate the reference. if (isDriver || _coordRef == null) { - logInfo("Getting StateStoreCoordinatorRef") + logDebug("Getting StateStoreCoordinatorRef") _coordRef = StateStoreCoordinatorRef.forExecutor(env) } logInfo(s"Retrieved reference to StateStoreCoordinator: ${_coordRef}") Some(_coordRef) } else { - logInfo("Env is null") _coordRef = null None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 63951893a0025..8c1b5f66014f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -3292,6 +3292,14 @@ object functions { */ def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) } + /** + * Returns the maximum value in the array. + * + * @group collection_funcs + * @since 2.4.0 + */ + def array_max(e: Column): Column = withExpr { ArrayMax(e.expr) } + /** * Returns a reversed string or an array with reverse order of elements. 
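As exercised by the new DataFrameFunctionsSuite test later in this patch, array_max ignores null elements but returns null for an empty array or an array containing only nulls. A small usage sketch (assumes an active SparkSession named `spark`; the column name is illustrative):

  import org.apache.spark.sql.functions.array_max

  import spark.implicits._

  val df = Seq(
    Seq[Option[Int]](Some(1), Some(3), Some(2)),
    Seq.empty[Option[Int]]
  ).toDF("a")

  df.select(array_max($"a")).collect()   // Array(Row(3), Row(null))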
* @group collection_funcs diff --git a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql index adea2bfa82cd3..547c2bef02b24 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/datetime.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/datetime.sql @@ -25,3 +25,5 @@ create temporary view ttf2 as select * from values select current_date = current_date(), current_timestamp = current_timestamp(), a, b from ttf2; select a, b from ttf2 order by a, current_date; + +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15'); diff --git a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out index bbb6851e69c7e..4e1cfa6e48c1c 100644 --- a/sql/core/src/test/resources/sql-tests/results/datetime.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/datetime.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 9 +-- Number of queries: 10 -- !query 0 @@ -81,3 +81,10 @@ struct -- !query 8 output 1 2 2 3 + +-- !query 9 +select weekday('2007-02-03'), weekday('2009-07-30'), weekday('2017-05-27'), weekday(null), weekday('1582-10-15 13:10:15') +-- !query 9 schema +struct +-- !query 9 output +5 3 5 NULL 4 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala index cee85ec8af04d..949505e449fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -39,7 +39,7 @@ class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { def computeChiSquareTest(): Double = { val n = 10000 // Trigger a sort - val data = spark.range(0, n, 1, 1).sort('id) + val data = spark.range(0, n, 1, 1).sort('id.desc) .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() // Compute histogram for the number of records per partition post sort diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 2103e12b566a0..cbddbf88a60f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { ) } + test("array_max function") { + val df = Seq( + Seq[Option[Int]](Some(1), Some(3), Some(2)), + Seq.empty[Option[Int]], + Seq[Option[Int]](None), + Seq[Option[Int]](None, Some(1), Some(-100)) + ).toDF("a") + + val answer = Seq(Row(3), Row(null), Row(null), Row(1)) + + checkAnswer(df.select(array_max(df("a"))), answer) + checkAnswer(df.selectExpr("array_max(a)"), answer) + } + test("reverse function") { val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codegen on diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala index 4efae4c46c2e1..7d1366092d1e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SessionStateSuite.scala @@ -44,6 +44,8 @@ class SessionStateSuite extends SparkFunSuite { if (activeSession != null) { activeSession.stop() activeSession = null + 
SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } super.afterAll() } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala similarity index 50% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala rename to sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala index 91e9a9f211335..d2a6358ee822b 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestQueryExecutionListener.scala @@ -14,25 +14,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit.steps -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec +package org.apache.spark.sql + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.util.QueryExecutionListener + + +class TestQueryExecutionListener extends QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + OnSuccessCall.isOnSuccessCalled.set(true) + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { } +} /** - * A driver configuration step for mounting user-specified secrets onto user-specified paths. - * - * @param bootstrap a utility actually handling mounting of the secrets. + * This has a variable to check if `onSuccess` is actually called or not. Currently, this is for + * the test case in PySpark. See SPARK-23942. 
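A Scala-side sketch of how the OnSuccessCall flag defined below can be consulted from a test, assuming TestQueryExecutionListener was registered as a query execution listener when the session was created (the helper name is illustrative):

  import org.apache.spark.sql.{OnSuccessCall, SparkSession}

  object ListenerCheckExample {
    def assertOnSuccessFired(spark: SparkSession): Unit = {
      OnSuccessCall.clear()
      spark.range(10).collect()   // a Dataset action, which should invoke onSuccess
      assert(OnSuccessCall.isCalled())
      OnSuccessCall.clear()
    }
  }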
*/ -private[spark] class DriverMountSecretsStep( - bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { - - override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) - val container = bootstrap.mountSecrets(driverSpec.driverContainer) - driverSpec.copy( - driverPod = pod, - driverContainer = container - ) - } +object OnSuccessCall { + val isOnSuccessCalled = new AtomicBoolean(false) + + def isCalled(): Boolean = isOnSuccessCalled.get() + + def clear(): Unit = isOnSuccessCalled.set(false) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f8b26f5b28cc7..40915a102bab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.{execution, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition, Sort} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} @@ -197,6 +197,19 @@ class PlannerSuite extends SharedSQLContext { assert(planned.child.isInstanceOf[CollectLimitExec]) } + test("SPARK-23375: Cached sorted data doesn't need to be re-sorted") { + val query = testData.select('key, 'value).sort('key.desc).cache() + assert(query.queryExecution.optimizedPlan.isInstanceOf[InMemoryRelation]) + val resorted = query.sort('key.desc) + assert(resorted.queryExecution.optimizedPlan.collect { case s: Sort => s}.isEmpty) + assert(resorted.select('key).collect().map(_.getInt(0)).toSeq == + (1 to 100).reverse) + // with a different order, the sort is needed + val sortedAsc = query.sort('key) + assert(sortedAsc.queryExecution.optimizedPlan.collect { case s: Sort => s}.size == 1) + assert(sortedAsc.select('key).collect().map(_.getInt(0)).toSeq == (1 to 100)) + } + test("PartitioningCollection") { withTempView("normal", "small", "tiny") { testData.createOrReplaceTempView("normal") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 26b63e8e8490f..9b7b316211d30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ @@ -42,7 +43,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with 
SharedSQLContext { val storageLevel = MEMORY_ONLY val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan val inMemoryRelation = InMemoryRelation(useCompression = true, 5, storageLevel, plan, None, - data.logicalPlan.stats) + data.logicalPlan) assert(inMemoryRelation.cachedColumnBuffers.getStorageLevel == storageLevel) inMemoryRelation.cachedColumnBuffers.collect().head match { @@ -119,7 +120,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("simple columnar query") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) } @@ -138,7 +139,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { val logicalPlan = testData.select('value, 'key).logicalPlan val plan = spark.sessionState.executePlan(logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - logicalPlan.stats) + logicalPlan) checkAnswer(scan, testData.collect().map { case Row(key: Int, value: String) => value -> key @@ -155,7 +156,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-1436 regression: in-memory columns must be able to be accessed multiple times") { val plan = spark.sessionState.executePlan(testData.logicalPlan).sparkPlan val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None, - testData.logicalPlan.stats) + testData.logicalPlan) checkAnswer(scan, testData.collect().toSeq) checkAnswer(scan, testData.collect().toSeq) @@ -329,7 +330,7 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-17549: cached table size should be correctly calculated") { val data = spark.sparkContext.parallelize(1 to 10, 5).toDF() val plan = spark.sessionState.executePlan(data.logicalPlan).sparkPlan - val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan.stats) + val cached = InMemoryRelation(true, 5, MEMORY_ONLY, plan, None, data.logicalPlan) // Materialize the data. 
val expectedAnswer = data.collect() @@ -455,7 +456,8 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { val attribute = AttributeReference("a", IntegerType)() val localTableScanExec = LocalTableScanExec(Seq(attribute), Nil) - val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, null) + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, localTableScanExec, None, + LocalRelation(Seq(attribute), Nil)) val tableScanExec = InMemoryTableScanExec(Seq(attribute), Seq(In(attribute, Nil)), testRelation) assert(tableScanExec.partitionFilters.isEmpty) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index c1d61b843d899..8764f0c42cf9f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -401,7 +401,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi sparkSession = spark, rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], - partitionSchema = None) + userSpecifiedSchema = None) // This should not fail. fileCatalog.listLeafFiles(Seq(new Path(tempDir))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala new file mode 100644 index 0000000000000..fe59cb25d5005 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CheckpointFileManagerSuite.scala @@ -0,0 +1,192 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.execution.streaming + +import java.io._ +import java.net.URI + +import scala.util.Random + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.Utils + +abstract class CheckpointFileManagerTests extends SparkFunSuite { + + def createManager(path: Path): CheckpointFileManager + + test("mkdirs, list, createAtomic, open, delete, exists") { + withTempPath { p => + val basePath = new Path(p.getAbsolutePath) + val fm = createManager(basePath) + // Mkdirs + val dir = new Path(s"$basePath/dir/subdir/subsubdir") + assert(!fm.exists(dir)) + fm.mkdirs(dir) + assert(fm.exists(dir)) + fm.mkdirs(dir) + + // List + val acceptAllFilter = new PathFilter { + override def accept(path: Path): Boolean = true + } + val rejectAllFilter = new PathFilter { + override def accept(path: Path): Boolean = false + } + assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) + assert(fm.list(basePath, rejectAllFilter).length === 0) + + // Create atomic without overwrite + var path = new Path(s"$dir/file") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = false).close() + assert(fm.exists(path)) + quietly { + intercept[IOException] { + // should throw exception since file exists and overwrite is false + fm.createAtomic(path, overwriteIfPossible = false).close() + } + } + + // Create atomic with overwrite if possible + path = new Path(s"$dir/file2") + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).cancel() + assert(!fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() + assert(fm.exists(path)) + fm.createAtomic(path, overwriteIfPossible = true).close() // should not throw exception + + // Open and delete + fm.open(path).close() + fm.delete(path) + assert(!fm.exists(path)) + intercept[IOException] { + fm.open(path) + } + fm.delete(path) // should not throw exception + } + } + + protected def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } +} + +class CheckpointFileManagerSuite extends SparkFunSuite with SharedSparkSession { + + test("CheckpointFileManager.create() should pick up user-specified class from conf") { + withSQLConf( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key -> + classOf[CreateAtomicTestManager].getName) { + val fileManager = + CheckpointFileManager.create(new Path("/"), spark.sessionState.newHadoopConf) + assert(fileManager.isInstanceOf[CreateAtomicTestManager]) + } + } + + test("CheckpointFileManager.create() should fallback from FileContext to FileSystem") { + import CheckpointFileManagerSuiteFileSystem.scheme + spark.conf.set(s"fs.$scheme.impl", classOf[CheckpointFileManagerSuiteFileSystem].getName) + quietly { + withTempDir { temp => + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + assert(metadataLog.add(0, "batch0")) + assert(metadataLog.getLatest() === Some(0 -> "batch0")) + assert(metadataLog.get(0) === Some("batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) + + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") + 
assert(metadataLog2.get(0) === Some("batch0")) + assert(metadataLog2.getLatest() === Some(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) + } + } + } +} + +class FileContextBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileContextBasedCheckpointFileManager(path, new Configuration()) + } +} + +class FileSystemBasedCheckpointFileManagerSuite extends CheckpointFileManagerTests { + override def createManager(path: Path): CheckpointFileManager = { + new FileSystemBasedCheckpointFileManager(path, new Configuration()) + } +} + + +/** A fake implementation to test different characteristics of CheckpointFileManager interface */ +class CreateAtomicTestManager(path: Path, hadoopConf: Configuration) + extends FileSystemBasedCheckpointFileManager(path, hadoopConf) { + + import CheckpointFileManager._ + + override def createAtomic(path: Path, overwrite: Boolean): CancellableFSDataOutputStream = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + } + val originalOut = super.createAtomic(path, overwrite) + + new CancellableFSDataOutputStream(originalOut) { + override def close(): Unit = { + if (CreateAtomicTestManager.shouldFailInCreateAtomic) { + throw new IOException("Copy failed intentionally") + } + super.close() + } + + override def cancel(): Unit = { + CreateAtomicTestManager.cancelCalledInCreateAtomic = true + originalOut.cancel() + } + } + } +} + +object CreateAtomicTestManager { + @volatile var shouldFailInCreateAtomic = false + @volatile var cancelCalledInCreateAtomic = false +} + + +/** + * CheckpointFileManagerSuiteFileSystem to test fallback of the CheckpointFileManager + * from FileContext to FileSystem API. 
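A user-supplied manager can be plugged in the same way these test managers are: provide the (Path, Configuration) constructor shape both managers in this patch use, and point the checkpoint file manager class configuration (SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS, as the suites above do via its key) at the class. A hypothetical sketch in the same style as CreateAtomicTestManager:

  import java.util.concurrent.atomic.AtomicInteger

  import org.apache.hadoop.conf.Configuration
  import org.apache.hadoop.fs.Path

  import org.apache.spark.sql.execution.streaming.CheckpointFileManager.CancellableFSDataOutputStream
  import org.apache.spark.sql.execution.streaming.FileSystemBasedCheckpointFileManager

  /** Hypothetical manager that counts atomic writes and delegates the actual work. */
  class CountingCheckpointFileManager(path: Path, hadoopConf: Configuration)
    extends FileSystemBasedCheckpointFileManager(path, hadoopConf) {

    override def createAtomic(
        path: Path,
        overwriteIfPossible: Boolean): CancellableFSDataOutputStream = {
      CountingCheckpointFileManager.atomicWrites.incrementAndGet()
      super.createAtomic(path, overwriteIfPossible)
    }
  }

  object CountingCheckpointFileManager {
    val atomicWrites = new AtomicInteger(0)
  }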
+ */ +private class CheckpointFileManagerSuiteFileSystem extends RawLocalFileSystem { + import CheckpointFileManagerSuiteFileSystem.scheme + + override def getUri: URI = { + URI.create(s"$scheme:///") + } +} + +private object CheckpointFileManagerSuiteFileSystem { + val scheme = s"CheckpointFileManagerSuiteFileSystem${math.abs(Random.nextInt)}" +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 12eaf63415081..ec961a9ecb592 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -22,15 +22,10 @@ import java.nio.charset.StandardCharsets._ import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - import CompactibleFileStreamLog._ /** -- testing of `object CompactibleFileStreamLog` begins -- */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 4677769c12a35..9268306ce4275 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -17,46 +17,22 @@ package org.apache.spark.sql.execution.streaming -import java.io.{File, FileNotFoundException, IOException} -import java.net.URI +import java.io.File import java.util.ConcurrentModificationException import scala.language.implicitConversions -import scala.util.Random -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs._ import org.scalatest.concurrent.Waiters._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.streaming.FakeFileSystem._ -import org.apache.spark.sql.execution.streaming.HDFSMetadataLog.{FileContextManager, FileManager, FileSystemManager} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { - /** To avoid caching of FS objects */ - override protected def sparkConf = - super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") - private implicit def toOption[A](a: A): Option[A] = Option(a) - test("FileManager: FileContextManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileContextManager(path, new Configuration)) - } - } - - test("FileManager: FileSystemManager") { - withTempDir { temp => - val path = new Path(temp.getAbsolutePath) - testFileManager(path, new FileSystemManager(path, new Configuration)) - } - } - test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir @@ -82,26 +58,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { 
} } - testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - spark.conf.set( - s"fs.$scheme.impl", - classOf[FakeFileSystem].getName) - withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog.add(0, "batch0")) - assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - - - val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://${temp.toURI.getPath}") - assert(metadataLog2.get(0) === Some("batch0")) - assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) - - } - } - test("HDFSMetadataLog: purge") { withTempDir { temp => val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) @@ -121,7 +77,8 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { // There should be exactly one file, called "2", in the metadata directory. // This check also tests for regressions of SPARK-17475 - val allFiles = new File(metadataLog.metadataPath.toString).listFiles().toSeq + val allFiles = new File(metadataLog.metadataPath.toString).listFiles() + .filter(!_.getName.startsWith(".")).toSeq assert(allFiles.size == 1) assert(allFiles(0).getName() == "2") } @@ -172,7 +129,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - test("HDFSMetadataLog: metadata directory collision") { + testQuietly("HDFSMetadataLog: metadata directory collision") { withTempDir { temp => val waiter = new Waiter val maxBatchId = 100 @@ -206,60 +163,6 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } } - /** Basic test case for [[FileManager]] implementation. 
*/ - private def testFileManager(basePath: Path, fm: FileManager): Unit = { - // Mkdirs - val dir = new Path(s"$basePath/dir/subdir/subsubdir") - assert(!fm.exists(dir)) - fm.mkdirs(dir) - assert(fm.exists(dir)) - fm.mkdirs(dir) - - // List - val acceptAllFilter = new PathFilter { - override def accept(path: Path): Boolean = true - } - val rejectAllFilter = new PathFilter { - override def accept(path: Path): Boolean = false - } - assert(fm.list(basePath, acceptAllFilter).exists(_.getPath.getName == "dir")) - assert(fm.list(basePath, rejectAllFilter).length === 0) - - // Create - val path = new Path(s"$dir/file") - assert(!fm.exists(path)) - fm.create(path).close() - assert(fm.exists(path)) - intercept[IOException] { - fm.create(path) - } - - // Open and delete - fm.open(path).close() - fm.delete(path) - assert(!fm.exists(path)) - intercept[IOException] { - fm.open(path) - } - fm.delete(path) // should not throw exception - - // Rename - val path1 = new Path(s"$dir/file1") - val path2 = new Path(s"$dir/file2") - fm.create(path1).close() - assert(fm.exists(path1)) - fm.rename(path1, path2) - intercept[FileNotFoundException] { - fm.rename(path1, path2) - } - val path3 = new Path(s"$dir/file3") - fm.create(path3).close() - assert(fm.exists(path3)) - intercept[FileAlreadyExistsException] { - fm.rename(path2, path3) - } - } - test("verifyBatchIds") { import HDFSMetadataLog.verifyBatchIds verifyBatchIds(Seq(1L, 2L, 3L), Some(1L), Some(3L)) @@ -277,14 +180,3 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { intercept[IllegalStateException](verifyBatchIds(Seq(1, 2, 4, 5), Some(1L), Some(5L))) } } - -/** FakeFileSystem to test fallback of the HDFSMetadataLog from FileContext to FileSystem API */ -class FakeFileSystem extends RawLocalFileSystem { - override def getUri: URI = { - URI.create(s"$scheme:///") - } -} - -object FakeFileSystem { - val scheme = s"HDFSMetadataLogSuite${math.abs(Random.nextInt)}" -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index c843b65020d8c..73f8705060402 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI import java.util.UUID -import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable @@ -28,17 +27,17 @@ import scala.util.Random import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} +import org.apache.hadoop.fs._ import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkFunSuite} +import org.apache.spark._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.count 
import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -138,7 +137,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] assert(getData(provider, 19) === Set("a" -> 19)) } - test("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { + testQuietly("SPARK-19677: Committing a delta file atop an existing one should not fail on HDFS") { val conf = new Configuration() conf.set("fs.fake.impl", classOf[RenameLikeHDFSFileSystem].getName) conf.set("fs.defaultFS", "fake:///") @@ -344,7 +343,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } - test("SPARK-18342: commit fails when rename fails") { + testQuietly("SPARK-18342: commit fails when rename fails") { import RenameReturnsFalseFileSystem._ val dir = scheme + "://" + newDir() val conf = new Configuration() @@ -366,7 +365,7 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] def numTempFiles: Int = { if (deltaFileDir.exists) { - deltaFileDir.listFiles.map(_.getName).count(n => n.contains("temp") && !n.startsWith(".")) + deltaFileDir.listFiles.map(_.getName).count(n => n.endsWith(".tmp")) } else 0 } @@ -471,6 +470,43 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] } } + test("error writing [version].delta cancels the output stream") { + + val hadoopConf = new Configuration() + hadoopConf.set( + SQLConf.STREAMING_CHECKPOINT_FILE_MANAGER_CLASS.parent.key, + classOf[CreateAtomicTestManager].getName) + val remoteDir = Utils.createTempDir().getAbsolutePath + + val provider = newStoreProvider( + opId = Random.nextInt, partition = 0, dir = remoteDir, hadoopConf = hadoopConf) + + // Disable failure of output stream and generate versions + CreateAtomicTestManager.shouldFailInCreateAtomic = false + for (version <- 1 to 10) { + val store = provider.getStore(version - 1) + put(store, version.toString, version) // update "1" -> 1, "2" -> 2, ... 
+ store.commit() + } + val version10Data = (1L to 10).map(_.toString).map(x => x -> x).toSet + + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store = provider.getStore(10) + // Fail commit for next version and verify that reloading resets the files + CreateAtomicTestManager.shouldFailInCreateAtomic = true + put(store, "11", 11) + val e = intercept[IllegalStateException] { quietly { store.commit() } } + assert(e.getCause.isInstanceOf[IOException]) + CreateAtomicTestManager.shouldFailInCreateAtomic = false + + // Abort commit for next version and verify that reloading resets the files + CreateAtomicTestManager.cancelCalledInCreateAtomic = false + val store2 = provider.getStore(10) + put(store2, "11", 11) + store2.abort() + assert(CreateAtomicTestManager.cancelCalledInCreateAtomic) + } + override def newStoreProvider(): HDFSBackedStateStoreProvider = { newStoreProvider(opId = Random.nextInt(), partition = 0) } @@ -720,6 +756,14 @@ abstract class StateStoreSuiteBase[ProviderClass <: StateStoreProvider] * this provider */ def getData(storeProvider: ProviderClass, version: Int): Set[(String, Int)] + + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } } object StateStoreTestsHelper { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index f5884b9c8de12..ef74efef156d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -171,6 +171,25 @@ class ContinuousSuite extends ContinuousSuiteBase { "Continuous processing does not support current time operations.")) } + test("subquery alias") { + val df = spark.readStream + .format("rate") + .option("numPartitions", "5") + .option("rowsPerSecond", "5") + .load() + .createOrReplaceTempView("rate") + val test = spark.sql("select value from rate where value > 5") + + testStream(test, useV2Sink = true)( + StartStream(longContinuousTrigger), + AwaitEpoch(0), + Execute(waitForRateSourceTriggers(_, 2)), + IncrementEpoch(), + Execute(waitForRateSourceTriggers(_, 4)), + IncrementEpoch(), + CheckAnswerRowsContains(scala.Range(6, 20).map(Row(_)))) + } + test("repeatedly restart") { val df = spark.readStream .format("rate") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index e758c865b908f..8968dbf36d507 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -60,6 +60,7 @@ trait SharedSparkSession protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { + SparkSession.cleanupAnyExistingSession() new TestSparkSession(sparkConf) } @@ -92,11 +93,22 @@ trait SharedSparkSession * Stop the underlying [[org.apache.spark.SparkContext]], if any. 
*/ protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index cc8907a0bbc93..b5444a4217924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -381,7 +381,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { case (attr, ordinal) => soi.getStructFieldRef(attr.name) -> ordinal - }.unzip + }.toArray.unzip /** * Builds specific unwrappers ahead of time according to object inspector diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index 1a86c604d5da3..3af163af0968c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -419,7 +419,7 @@ class PartitionedTablePerfStatsSuite HiveCatalogMetrics.reset() spark.read.load(dir.getAbsolutePath) assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) - assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 0) } } }
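The teardown changes in SessionStateSuite and SharedSparkSession above share one shape: nested try/finally blocks so that the active and default session references are always cleared, even when stopping the session or resetting the catalog throws. A condensed sketch of that pattern (helper and object names are illustrative):

  import org.apache.spark.sql.SparkSession

  object SessionTeardownExample {
    // Whatever happens while stopping the session, the thread-local active session
    // and the global default session are cleared so later suites start from a clean state.
    def stopSessionSafely(session: SparkSession): Unit = {
      try {
        if (session != null) {
          session.stop()
        }
      } finally {
        SparkSession.clearActiveSession()
        SparkSession.clearDefaultSession()
      }
    }
  }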