diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index ccd8504ba0d7b..56435a706bf56 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -61,6 +61,15 @@ jackson-annotations + + org.apache.hadoop + hadoop-client + + + org.apache.hadoop + hadoop-yarn-common + + org.slf4j diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 8f354ad78bbaa..d8697287d2859 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -75,6 +75,7 @@ public class TransportClient implements Closeable { private final Channel channel; private final TransportResponseHandler handler; @Nullable private String clientId; + @Nullable private String clientUser; private volatile boolean timedOut; public TransportClient(Channel channel, TransportResponseHandler handler) { @@ -114,6 +115,25 @@ public void setClientId(String id) { this.clientId = id; } + /** + * Returns the user name used by the client to authenticate itself when authentication is enabled. + * + * @return The client User Name, or null if authentication is disabled. + */ + public String getClientUser() { + return clientUser; + } + + /** + * Sets the authenticated client's user name. This is meant to be used by the authentication layer. + * + * Trying to set a different client User Name after it's been set will result in an exception. + */ + public void setClientUser(String user) { + Preconditions.checkState(clientUser == null, "Client User Name has already been set."); + this.clientUser = user; + } + /** * Requests a single chunk from the remote side, from the pre-negotiated streamId. 
* diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java index 3c263783a6104..e80f84f30135e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthClientBootstrap.java @@ -95,8 +95,9 @@ public void doBootstrap(TransportClient client, Channel channel) { private void doSparkAuth(TransportClient client, Channel channel) throws GeneralSecurityException, IOException { + String user = secretKeyHolder.getSaslUser(appId); String secretKey = secretKeyHolder.getSecretKey(appId); - try (AuthEngine engine = new AuthEngine(appId, secretKey, conf)) { + try (AuthEngine engine = new AuthEngine(appId, user, secretKey, conf)) { ClientChallenge challenge = engine.challenge(); ByteBuf challengeData = Unpooled.buffer(challenge.encodedLength()); challenge.encode(challengeData); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java index b769ebeba36cc..10e8e8570a2f9 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthEngine.java @@ -39,11 +39,17 @@ import org.apache.commons.crypto.cipher.CryptoCipherFactory; import org.apache.commons.crypto.random.CryptoRandom; import org.apache.commons.crypto.random.CryptoRandomFactory; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenIdentifier; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.apache.spark.network.util.TransportConf; +import static org.apache.spark.network.util.HadoopSecurityUtils.decodeMasterKey; +import static org.apache.spark.network.util.HadoopSecurityUtils.getClientToAMSecretKey; +import static org.apache.spark.network.util.HadoopSecurityUtils.getIdentifier; + /** * A helper class for abstracting authentication and key negotiation details. This is used by * both client and server sides, since the operations are basically the same. @@ -54,11 +60,13 @@ class AuthEngine implements Closeable { private static final BigInteger ONE = new BigInteger(new byte[] { 0x1 }); private final byte[] appId; - private final char[] secret; + private final byte[] user; + private char[] secret; private final TransportConf conf; private final Properties cryptoConf; private final CryptoRandom random; + private String clientUser; private byte[] authNonce; @VisibleForTesting @@ -69,13 +77,25 @@ class AuthEngine implements Closeable { private CryptoCipher decryptor; AuthEngine(String appId, String secret, TransportConf conf) throws GeneralSecurityException { + this(appId, "",secret, conf); + } + + AuthEngine(String appId, String user, String secret, TransportConf conf) throws GeneralSecurityException { this.appId = appId.getBytes(UTF_8); + this.user = user.getBytes(UTF_8); this.conf = conf; this.cryptoConf = conf.cryptoConf(); this.secret = secret.toCharArray(); this.random = CryptoRandomFactory.getCryptoRandom(cryptoConf); } + /** + * Returns the user name of the client. + */ + public String getClientUserName() { + return clientUser; + } + /** * Create the client challenge. 
* @@ -89,6 +109,7 @@ ClientChallenge challenge() throws GeneralSecurityException, IOException { this.challenge = randomBytes(conf.encryptionKeyLength() / Byte.SIZE); return new ClientChallenge(new String(appId, UTF_8), + new String(user, UTF_8), conf.keyFactoryAlgorithm(), conf.keyFactoryIterations(), conf.cipherTransformation(), @@ -106,9 +127,22 @@ ClientChallenge challenge() throws GeneralSecurityException, IOException { */ ServerResponse respond(ClientChallenge clientChallenge) throws GeneralSecurityException, IOException { + SecretKeySpec authKey; + if (conf.isConnectionUsingTokens()) { + // Create a secret from client's token identifier and AM's master key. + ClientToAMTokenSecretManager secretManager = new ClientToAMTokenSecretManager(null, + decodeMasterKey(new String(secret))); + ClientToAMTokenIdentifier identifier = getIdentifier(clientChallenge.user); + secret = getClientToAMSecretKey(identifier, secretManager); + + clientUser = identifier.getUser().getShortUserName(); + } else { + clientUser = clientChallenge.user; + } + + authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, clientChallenge.nonce, + clientChallenge.keyLength); - SecretKeySpec authKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, - clientChallenge.nonce, clientChallenge.keyLength); initializeForAuth(clientChallenge.cipher, clientChallenge.nonce, authKey); byte[] challenge = validateChallenge(clientChallenge.nonce, clientChallenge.challenge); @@ -119,6 +153,7 @@ ServerResponse respond(ClientChallenge clientChallenge) SecretKeySpec sessionKey = generateKey(clientChallenge.kdf, clientChallenge.iterations, sessionNonce, clientChallenge.keyLength); + this.sessionCipher = new TransportCipher(cryptoConf, clientChallenge.cipher, sessionKey, inputIv, outputIv); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8a6e3858081bf..b50e9ff1a10ac 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -114,12 +114,14 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // Here we have the client challenge, so perform the new auth protocol and set up the channel. 
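// --- Illustrative sketch (not part of the patch) ------------------------------
// The AuthEngine/AuthRpcHandler hunks above make the server, when
// spark.rpc.connectionUsingTokens is set, derive the shared secret from the
// client's ClientToAMTokenIdentifier plus the AM master key instead of a
// pre-shared Spark secret. The Scala snippet below shows that derivation with
// plain Hadoop/Netty APIs; the object and method names are invented for the
// example and only mirror the semantics of the HadoopSecurityUtils helper that
// appears later in this diff.
import java.io.{ByteArrayInputStream, DataInputStream}
import java.nio.charset.StandardCharsets.UTF_8

import io.netty.buffer.Unpooled
import io.netty.handler.codec.base64.Base64
import org.apache.hadoop.yarn.security.client.{ClientToAMTokenIdentifier, ClientToAMTokenSecretManager}

object ClientToAMSecretSketch {
  // The client sends its token identifier Base64-encoded in the "user" field of the challenge.
  def decodeIdentifier(encodedId: String): ClientToAMTokenIdentifier = {
    val buf = Base64.decode(Unpooled.wrappedBuffer(encodedId.getBytes(UTF_8)))
    val bytes = new Array[Byte](buf.readableBytes())
    buf.readBytes(bytes)
    val id = new ClientToAMTokenIdentifier()
    id.readFields(new DataInputStream(new ByteArrayInputStream(bytes)))
    id
  }

  // The AM recomputes the client's token password from its master key, so both sides
  // end up with the same secret; the identifier also yields the authenticated user name.
  def serverSideSecret(encodedId: String, masterKey: Array[Byte]): (String, Array[Byte]) = {
    val identifier = decodeIdentifier(encodedId)
    val secretManager = new ClientToAMTokenSecretManager(null, masterKey)
    (identifier.getUser.getShortUserName, secretManager.retrievePassword(identifier))
  }
}
// ------------------------------------------------------------------------------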
AuthEngine engine = null; try { + String user = secretKeyHolder.getSaslUser(challenge.appId); String secret = secretKeyHolder.getSecretKey(challenge.appId); Preconditions.checkState(secret != null, "Trying to authenticate non-registered app %s.", challenge.appId); LOG.debug("Authenticating challenge for app {}.", challenge.appId); - engine = new AuthEngine(challenge.appId, secret, conf); + engine = new AuthEngine(challenge.appId, user, secret, conf); ServerResponse response = engine.respond(challenge); + client.setClientUser(engine.getClientUserName()); ByteBuf responseData = Unpooled.buffer(response.encodedLength()); response.encode(responseData); callback.onSuccess(responseData.nioBuffer()); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java index 819b8a7efbdba..a6fff5e4b2a7c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/ClientChallenge.java @@ -35,6 +35,7 @@ public class ClientChallenge implements Encodable { private static final byte TAG_BYTE = (byte) 0xFA; public final String appId; + public final String user; public final String kdf; public final int iterations; public final String cipher; @@ -42,8 +43,19 @@ public class ClientChallenge implements Encodable { public final byte[] nonce; public final byte[] challenge; + public ClientChallenge( + String appId, + String kdf, + int iterations, + String cipher, + int keyLength, + byte[] nonce, + byte[] challenge) { + this(appId, "", kdf, iterations, cipher, keyLength, nonce, challenge); + } public ClientChallenge( String appId, + String user, String kdf, int iterations, String cipher, @@ -51,6 +63,7 @@ public ClientChallenge( byte[] nonce, byte[] challenge) { this.appId = appId; + this.user = user; this.kdf = kdf; this.iterations = iterations; this.cipher = cipher; @@ -63,6 +76,7 @@ public ClientChallenge( public int encodedLength() { return 1 + 4 + 4 + Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(user) + Encoders.Strings.encodedLength(kdf) + Encoders.Strings.encodedLength(cipher) + Encoders.ByteArrays.encodedLength(nonce) + @@ -73,6 +87,7 @@ public int encodedLength() { public void encode(ByteBuf buf) { buf.writeByte(TAG_BYTE); Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, user); Encoders.Strings.encode(buf, kdf); buf.writeInt(iterations); Encoders.Strings.encode(buf, cipher); @@ -89,6 +104,7 @@ public static ClientChallenge decodeMessage(ByteBuffer buffer) { } return new ClientChallenge( + Encoders.Strings.decode(buf), Encoders.Strings.decode(buf), Encoders.Strings.decode(buf), buf.readInt(), diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 647813772294e..5b95001b0f07d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -57,7 +57,7 @@ public SaslClientBootstrap(TransportConf conf, String appId, SecretKeyHolder sec */ @Override public void doBootstrap(TransportClient client, Channel channel) { - SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption()); + SparkSaslClient saslClient = 
new SparkSaslClient(appId, secretKeyHolder, conf.saslEncryption(), conf); try { byte[] payload = saslClient.firstToken(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java index 0231428318add..721bd3089b511 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java @@ -95,7 +95,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // First message in the handshake, setup the necessary state. client.setClientId(saslMessage.appId); saslServer = new SparkSaslServer(saslMessage.appId, secretKeyHolder, - conf.saslServerAlwaysEncrypt()); + conf.saslServerAlwaysEncrypt(), conf); } byte[] response; @@ -114,6 +114,7 @@ public void receive(TransportClient client, ByteBuffer message, RpcResponseCallb // method returns. This assumes that the code ensures, through other means, that no outbound // messages are being written to the channel while negotiation is still going on. if (saslServer.isComplete()) { + client.setClientUser(saslServer.getUserName()); if (!SparkSaslServer.QOP_AUTH_CONF.equals(saslServer.getNegotiatedProperty(Sasl.QOP))) { logger.debug("SASL authentication successful for channel {}", client); complete(true); diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java index b6256debb8e3e..5d984a77885d7 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslClient.java @@ -35,8 +35,11 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.spark.network.util.TransportConf; + import static org.apache.spark.network.sasl.SparkSaslServer.*; + /** * A SASL Client for Spark which simply keeps track of the state of a single SASL session, from the * initial state to the "authenticated" state. This client initializes the protocol via a @@ -48,12 +51,25 @@ public class SparkSaslClient implements SaslEncryptionBackend { private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; private final String expectedQop; + private TransportConf conf; private SaslClient saslClient; - public SparkSaslClient(String secretKeyId, SecretKeyHolder secretKeyHolder, boolean encrypt) { + public SparkSaslClient( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean alwaysEncrypt) { + this(secretKeyId,secretKeyHolder,alwaysEncrypt, null); + } + + public SparkSaslClient( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean encrypt, + TransportConf conf) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; this.expectedQop = encrypt ? 
QOP_AUTH_CONF : QOP_AUTH; + this.conf = conf; Map saslProps = ImmutableMap.builder() .put(Sasl.QOP, expectedQop) @@ -131,11 +147,23 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback if (callback instanceof NameCallback) { logger.trace("SASL client callback: setting username"); NameCallback nc = (NameCallback) callback; - nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + if (conf != null && conf.isConnectionUsingTokens()) { + // Token Identifier is already encoded + nc.setName(secretKeyHolder.getSaslUser(secretKeyId)); + } else { + nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); + } + } else if (callback instanceof PasswordCallback) { logger.trace("SASL client callback: setting password"); PasswordCallback pc = (PasswordCallback) callback; - pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + if (conf != null && conf.isConnectionUsingTokens()) { + // Token Identifier is already encoded + pc.setPassword(secretKeyHolder.getSecretKey(secretKeyId).toCharArray()); + } else { + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + + } } else if (callback instanceof RealmCallback) { logger.trace("SASL client callback: setting realm"); RealmCallback rc = (RealmCallback) callback; diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java index 00f3e83dbc8b3..f836bfbf9e913 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SparkSaslServer.java @@ -40,6 +40,15 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenIdentifier; + +import static org.apache.spark.network.util.HadoopSecurityUtils.decodeMasterKey; +import static org.apache.spark.network.util.HadoopSecurityUtils.getClientToAMSecretKey; +import static org.apache.spark.network.util.HadoopSecurityUtils.getIdentifier; + +import org.apache.spark.network.util.TransportConf; + /** * A SASL Server for Spark which simply keeps track of the state of a single SASL session, from the * initial state to the "authenticated" state. (It is not a server in the sense of accepting @@ -73,14 +82,25 @@ public class SparkSaslServer implements SaslEncryptionBackend { /** Identifier for a certain secret key within the secretKeyHolder. */ private final String secretKeyId; private final SecretKeyHolder secretKeyHolder; + private TransportConf conf; + private String clientUser; private SaslServer saslServer; public SparkSaslServer( String secretKeyId, SecretKeyHolder secretKeyHolder, boolean alwaysEncrypt) { + this(secretKeyId, secretKeyHolder, alwaysEncrypt, null); + } + + public SparkSaslServer( + String secretKeyId, + SecretKeyHolder secretKeyHolder, + boolean alwaysEncrypt, + TransportConf conf) { this.secretKeyId = secretKeyId; this.secretKeyHolder = secretKeyHolder; + this.conf = conf; // Sasl.QOP is a comma-separated list of supported values. The value that allows encryption // is listed first since it's preferred over the non-encrypted one (if the client also @@ -98,6 +118,13 @@ public SparkSaslServer( } } + /** + * Returns the user name of the client. 
+ */ + public String getUserName() { + return clientUser; + } + /** * Determines whether the authentication exchange has completed successfully. */ @@ -156,15 +183,16 @@ public byte[] unwrap(byte[] data, int offset, int len) throws SaslException { private class DigestCallbackHandler implements CallbackHandler { @Override public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { + NameCallback nc = null; + PasswordCallback pc = null; for (Callback callback : callbacks) { if (callback instanceof NameCallback) { logger.trace("SASL server callback: setting username"); - NameCallback nc = (NameCallback) callback; + nc = (NameCallback) callback; nc.setName(encodeIdentifier(secretKeyHolder.getSaslUser(secretKeyId))); } else if (callback instanceof PasswordCallback) { logger.trace("SASL server callback: setting password"); - PasswordCallback pc = (PasswordCallback) callback; - pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + pc = (PasswordCallback) callback; } else if (callback instanceof RealmCallback) { logger.trace("SASL server callback: setting realm"); RealmCallback rc = (RealmCallback) callback; @@ -182,10 +210,21 @@ public void handle(Callback[] callbacks) throws IOException, UnsupportedCallback throw new UnsupportedCallbackException(callback, "Unrecognized SASL DIGEST-MD5 Callback"); } } + if (pc != null) { + if (conf != null && conf.isConnectionUsingTokens()) { + ClientToAMTokenSecretManager secretManager = new ClientToAMTokenSecretManager(null, + decodeMasterKey(secretKeyHolder.getSecretKey(secretKeyId))); + ClientToAMTokenIdentifier identifier = getIdentifier(nc.getDefaultName()); + clientUser = identifier.getUser().getShortUserName(); + pc.setPassword(getClientToAMSecretKey(identifier, secretManager)); + } else { + clientUser = nc.getDefaultName(); + pc.setPassword(encodePassword(secretKeyHolder.getSecretKey(secretKeyId))); + } + } } } - - /* Encode a byte[] identifier as a Base64-encoded string. */ + /** Encode a String identifier as a Base64-encoded string. */ public static String encodeIdentifier(String identifier) { Preconditions.checkNotNull(identifier, "User cannot be null if SASL is enabled"); return getBase64EncodedString(identifier); diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/HadoopSecurityUtils.java b/common/network-common/src/main/java/org/apache/spark/network/util/HadoopSecurityUtils.java new file mode 100644 index 0000000000000..cffceb53a7361 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/util/HadoopSecurityUtils.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.util; + +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; + +import org.apache.hadoop.security.token.SecretManager.InvalidToken; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenSecretManager; +import org.apache.hadoop.yarn.security.client.ClientToAMTokenIdentifier; + +/** + * Utility methods related to the hadoop security + */ +public class HadoopSecurityUtils { + + /** Creates an ClientToAMTokenIdentifier from the encoded Base-64 String */ + public static ClientToAMTokenIdentifier getIdentifier(String id) throws InvalidToken { + byte[] tokenId = byteBufToByte(Base64.decode( + Unpooled.wrappedBuffer(id.getBytes(StandardCharsets.UTF_8)))); + + ClientToAMTokenIdentifier tokenIdentifier = new ClientToAMTokenIdentifier(); + try { + tokenIdentifier.readFields(new DataInputStream(new ByteArrayInputStream(tokenId))); + } catch (IOException e) { + throw (InvalidToken) new InvalidToken( + "Can't de-serialize tokenIdentifier").initCause(e); + } + return tokenIdentifier; + } + + /** Returns an Base64-encoded secretKey created from the Identifier and the secretmanager */ + public static char[] getClientToAMSecretKey(ClientToAMTokenIdentifier tokenid, + ClientToAMTokenSecretManager secretManager) throws InvalidToken { + byte[] password = secretManager.retrievePassword(tokenid); + return Base64.encode(Unpooled.wrappedBuffer(password)).toString(StandardCharsets.UTF_8) + .toCharArray(); + } + + /** Decode a base64-encoded MasterKey as a byte[] array. */ + public static byte[] decodeMasterKey(String masterKey) { + ByteBuf masterKeyByteBuf = Base64.decode(Unpooled.wrappedBuffer(masterKey.getBytes(StandardCharsets.UTF_8))); + return byteBufToByte(masterKeyByteBuf); + } + + /** Convert an ByteBuf to a byte[] array. */ + private static byte[] byteBufToByte(ByteBuf byteBuf) { + byte[] byteArray = new byte[byteBuf.readableBytes()]; + byteBuf.readBytes(byteArray); + return byteArray; + } +} diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 88256b810bf04..d0b975bfdde6e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -108,6 +108,9 @@ public int numConnectionsPerPeer() { /** Number of threads used in the client thread pool. Default to 0, which is 2x#cores. */ public int clientThreads() { return conf.getInt(SPARK_NETWORK_IO_CLIENTTHREADS_KEY, 0); } + /** If true, the current RPC connection is a Client to AM connection */ + public boolean isConnectionUsingTokens() { return conf.getBoolean("spark.rpc.connectionUsingTokens", false); } + /** * Receive buffer size (SO_RCVBUF). 
* Note: the optimal size for receive buffer and send buffer should be diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2480e56b72ccf..9a9518837cb69 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -222,10 +222,11 @@ private[spark] class SecurityManager( setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", "")) setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", "")) - setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")); - setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")); + setViewAclsGroups(sparkConf.get("spark.ui.view.acls.groups", "")) + setModifyAclsGroups(sparkConf.get("spark.modify.acls.groups", "")) - private val secretKey = generateSecretKey() + private var identifier = "sparkSaslUser" + private var secretKey = generateSecretKey() logInfo("SecurityManager: authentication " + (if (authOn) "enabled" else "disabled") + "; ui acls " + (if (aclsOn) "enabled" else "disabled") + "; users with view permissions: " + viewAcls.toString() + @@ -533,11 +534,23 @@ private[spark] class SecurityManager( /** * Gets the user used for authenticating SASL connections. - * For now use a single hardcoded user. * @return the SASL user as a String */ - def getSaslUser(): String = "sparkSaslUser" + def getSaslUser(): String = identifier + + /** + * This can be a user name or unique identifier + */ + def setSaslUser(ident: String) { + identifier = ident + } + /** + * set the secret key + */ + def setSecretKey(secret: String) { + secretKey = secret + } /** * Gets the secret key. * @return the secret key as a String if authentication is enabled, otherwise returns null diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 0ea14361b2f77..27c19cc0ba232 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -60,7 +60,7 @@ import org.apache.spark.util._ */ private[deploy] object SparkSubmitAction extends Enumeration { type SparkSubmitAction = Value - val SUBMIT, KILL, REQUEST_STATUS = Value + val SUBMIT, KILL, REQUEST_STATUS, UPLOAD_CREDENTIALS = Value } /** @@ -123,15 +123,87 @@ object SparkSubmit extends CommandLineUtils { case SparkSubmitAction.SUBMIT => submit(appArgs) case SparkSubmitAction.KILL => kill(appArgs) case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) + case SparkSubmitAction.UPLOAD_CREDENTIALS => uploadCredentials(appArgs) } } /** - * Kill an existing submission using the REST protocol. Standalone and Mesos cluster mode only. + * Kill an existing submission */ private def kill(args: SparkSubmitArguments): Unit = { - new RestSubmissionClient(args.master) - .killSubmission(args.submissionToKill) + if (args.master.startsWith("yarn")) { + // Use RPC protocol. YARN mode only. 
+ val (_, _, sysProps, _) = prepareSubmitEnvironment(args) + val applicationID = Seq("--arg", args.submissionToKill) + if (args.proxyUser != null) { + val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser, + UserGroupInformation.getCurrentUser()) + try { + proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + runMethod(applicationID, ArrayBuffer(), sysProps, + "org.apache.spark.deploy.yarn.Client", "yarnKillSubmission", args.verbose) + } + }) + } catch { + case e: Exception => + // Hadoop's AuthorizationException suppresses the exception's stack trace, which + // makes the message printed to the output by the JVM not very helpful. Instead, + // detect exceptions with empty stack traces here, and treat them differently. + if (e.getStackTrace().length == 0) { + // scalastyle:off println + printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + // scalastyle:on println + exitFn(1) + } else { + throw e + } + } + } else { + runMethod(applicationID, ArrayBuffer(), sysProps, "org.apache.spark.deploy.yarn.Client", + "yarnKillSubmission", args.verbose) + } + } else { + // Use Rest protocol. Standalone and Mesos cluster mode only.. + new RestSubmissionClient(args.master) + .killSubmission(args.submissionToKill) + } + } + + /** + * + */ + private def uploadCredentials(args: SparkSubmitArguments): Unit = { + val (_, _, sysProps, _) = prepareSubmitEnvironment(args) + val applicationID = Seq("--arg", args.submissionToUploadCred) + if (args.proxyUser != null) { + val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser, + UserGroupInformation.getCurrentUser()) + try { + proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { + override def run(): Unit = { + runMethod(applicationID, ArrayBuffer(), sysProps, + "org.apache.spark.deploy.yarn.Client", "yarnUploadCredentials", args.verbose) + } + }) + } catch { + case e: Exception => + // Hadoop's AuthorizationException suppresses the exception's stack trace, which + // makes the message printed to the output by the JVM not very helpful. Instead, + // detect exceptions with empty stack traces here, and treat them differently. + if (e.getStackTrace().length == 0) { + // scalastyle:off println + printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}") + // scalastyle:on println + exitFn(1) + } else { + throw e + } + } + } else { + runMethod(applicationID, ArrayBuffer(), sysProps, "org.apache.spark.deploy.yarn.Client", + "yarnUploadCredentials", args.verbose) + } } /** @@ -163,7 +235,7 @@ object SparkSubmit extends CommandLineUtils { try { proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMethod(childArgs, childClasspath, sysProps, childMainClass, "main", args.verbose) } }) } catch { @@ -181,7 +253,7 @@ object SparkSubmit extends CommandLineUtils { } } } else { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMethod(childArgs, childClasspath, sysProps, childMainClass, "main", args.verbose) } } @@ -681,20 +753,22 @@ object SparkSubmit extends CommandLineUtils { } /** - * Run the main method of the child class using the provided launch environment. + * Run a method of the child class using the provided launch environment. * * Note that this main class will not be the one provided by the user if we're * running cluster deploy mode or python applications. 
*/ - private def runMain( + private def runMethod( childArgs: Seq[String], childClasspath: Seq[String], sysProps: Map[String, String], childMainClass: String, + childMainMethod: String, verbose: Boolean): Unit = { // scalastyle:off println if (verbose) { printStream.println(s"Main class:\n$childMainClass") + printStream.println(s"Main method:\n$childMainMethod") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") @@ -751,7 +825,7 @@ object SparkSubmit extends CommandLineUtils { printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") } - val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) + val mainMethod = mainClass.getMethod(childMainMethod, new Array[String](0).getClass) if (!Modifier.isStatic(mainMethod.getModifiers)) { throw new IllegalStateException("The main method in the given main class must be static") } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index a7722e4f86023..09014ca8db3cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -78,6 +78,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S var supervise: Boolean = false var driverCores: String = null var submissionToKill: String = null + var submissionToUploadCred: String = null var submissionToRequestStatusFor: String = null var useRest: Boolean = true // used internally @@ -245,6 +246,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S action match { case SUBMIT => validateSubmitArguments() case KILL => validateKillArguments() + case UPLOAD_CREDENTIALS => validateUploadCredArguments() case REQUEST_STATUS => validateStatusRequestArguments() } } @@ -294,15 +296,26 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S } private def validateKillArguments(): Unit = { - if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { + if (!master.startsWith("spark://") && !master.startsWith("mesos://") + && !master.startsWith("yarn")) { SparkSubmit.printErrorAndExit( - "Killing submissions is only supported in standalone or Mesos mode!") + "Killing submissions is only supported in YARN, standalone or Mesos mode!") } if (submissionToKill == null) { SparkSubmit.printErrorAndExit("Please specify a submission to kill.") } } + private def validateUploadCredArguments(): Unit = { + if (!master.startsWith("yarn")) { + SparkSubmit.printErrorAndExit( + "Credential Uploading is only supported in Yarn mode!") + } + if (submissionToUploadCred == null) { + SparkSubmit.printErrorAndExit("Please specify a submission to upload credentials to.") + } + } + private def validateStatusRequestArguments(): Unit = { if (!master.startsWith("spark://") && !master.startsWith("mesos://")) { SparkSubmit.printErrorAndExit( @@ -406,6 +419,12 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") } action = KILL + case UPLOAD_CRED_SUBMISSION => + submissionToUploadCred = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $UPLOAD_CREDENTIALS.") + } + action = 
UPLOAD_CREDENTIALS case STATUS => submissionToRequestStatusFor = value @@ -564,9 +583,11 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | --driver-cores NUM Number of cores used by the driver, only in cluster mode | (Default: 1). | + | All Spark cluster deploy mode or Yarn client mode only: + | --kill SUBMISSION_ID If given, kills the driver specified. + | | Spark standalone or Mesos with cluster deploy mode only: | --supervise If given, restarts the driver on failure. - | --kill SUBMISSION_ID If given, kills the driver specified. | --status SUBMISSION_ID If given, requests the status of the driver specified. | | Spark standalone and Mesos only: @@ -577,6 +598,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | or all available cores on the worker in standalone mode) | | YARN-only: + | --upload_cred SUBMISSION_ID Upload credential of the given application | --queue QUEUE_NAME The YARN queue to submit to (Default: "default"). | --num-executors NUM Number of executors to launch (Default: 2). | If dynamic allocation is enabled, the initial number of diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index a2f1aa22b0063..2a4417579615f 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -17,15 +17,19 @@ package org.apache.spark.executor +import java.io.{ByteArrayInputStream, DataInputStream} import java.net.URL import java.nio.ByteBuffer import java.util.Locale import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.{Failure, Success} import scala.util.control.NonFatal +import org.apache.hadoop.security.{Credentials, UserGroupInformation} + import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil @@ -125,6 +129,25 @@ private[spark] class CoarseGrainedExecutorBackend( }.start() } + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case UploadCredentials(c) => + try { + val dataInput = new DataInputStream(new ByteArrayInputStream(c)) + val credentials = new Credentials + credentials.readFields(dataInput) + logInfo(s"Update credentials with Tokens " + + s"${credentials.getAllTokens.asScala.map(_.getKind.toString).mkString(",")} " + + "to executor") + UserGroupInformation.getCurrentUser.addCredentials(credentials) + context.reply(true) + } catch { + case NonFatal(e) => logWarning(s"Failed to update credentials", e) + context.sendFailure(e) + } + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { if (stopping.get()) { logInfo(s"Driver from $remoteAddress disconnected during shutdown") diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index 117f51c5b8f2a..32f84b95593bf 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -35,7 +35,12 @@ private[spark] trait RpcCallContext { def sendFailure(e: Throwable): Unit /** - * The sender of this message. + * The sender's address of this message. */ def senderAddress: RpcAddress + + /** + * The sender's User Name of this message. 
+ */ + def senderUserName: String } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 904c4d02dd2a4..f6c7f7e30b246 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -121,8 +121,8 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte /** Posts a message sent by a remote endpoint. */ def postRemoteMessage(message: RequestMessage, callback: RpcResponseCallback): Unit = { - val rpcCallContext = - new RemoteNettyRpcCallContext(nettyEnv, callback, message.senderAddress) + val rpcCallContext = new RemoteNettyRpcCallContext(nettyEnv, callback, + message.senderAddress, message.senderUserName) val rpcMessage = RpcMessage(message.senderAddress, message.content, rpcCallContext) postMessage(message.receiver.name, rpcMessage, (e) => callback.onFailure(e)) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala index 7dd7e610a28eb..8c2d9bae00a9c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcCallContext.scala @@ -23,7 +23,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc.{RpcAddress, RpcCallContext} -private[netty] abstract class NettyRpcCallContext(override val senderAddress: RpcAddress) +private[netty] abstract class NettyRpcCallContext( + override val senderAddress: RpcAddress, + override val senderUserName: String = null) extends RpcCallContext with Logging { protected def send(message: Any): Unit @@ -57,8 +59,9 @@ private[netty] class LocalNettyRpcCallContext( private[netty] class RemoteNettyRpcCallContext( nettyEnv: NettyRpcEnv, callback: RpcResponseCallback, - senderAddress: RpcAddress) - extends NettyRpcCallContext(senderAddress) { + senderAddress: RpcAddress, + senderUserName: String) + extends NettyRpcCallContext(senderAddress, senderUserName) { override protected def send(message: Any): Unit = { val reply = nettyEnv.serialize(message) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 1777e7a539751..a298a4cd70296 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -543,7 +543,8 @@ private[netty] class NettyRpcEndpointRef( private[netty] class RequestMessage( val senderAddress: RpcAddress, val receiver: NettyRpcEndpointRef, - val content: Any) { + val content: Any, + val senderUserName: String = null) { /** Manually serialize [[RequestMessage]] to minimize the size. */ def serialize(nettyEnv: NettyRpcEnv): ByteBuffer = { @@ -589,7 +590,11 @@ private[netty] object RequestMessage { } } - def apply(nettyEnv: NettyRpcEnv, client: TransportClient, bytes: ByteBuffer): RequestMessage = { + def apply( + nettyEnv: NettyRpcEnv, + client: TransportClient, + bytes: ByteBuffer, + senderUserName: String = null): RequestMessage = { val bis = new ByteBufferInputStream(bytes) val in = new DataInputStream(bis) try { @@ -601,7 +606,8 @@ private[netty] object RequestMessage { senderAddress, ref, // The remaining bytes in `bytes` are the message content. 
- nettyEnv.deserialize(client, bytes)) + nettyEnv.deserialize(client, bytes), + senderUserName) } finally { in.close() } @@ -652,10 +658,12 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostString, addr.getPort) - val requestMessage = RequestMessage(nettyEnv, client, message) + var requestMessage = RequestMessage(nettyEnv, client, message, client.getClientUser) + if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. - new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content) + new RequestMessage(clientAddr, requestMessage.receiver, requestMessage.content, + client.getClientUser) } else { // The remote RpcEnv listens to some port, we should also fire a RemoteProcessConnected for // the listening address diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 89a9ad6811e18..00c50f7e5ca00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -46,6 +46,9 @@ private[spark] object CoarseGrainedClusterMessages { case class KillExecutorsOnHost(host: String) extends CoarseGrainedClusterMessage + case class UploadCredentials(credential: Array[Byte]) + extends CoarseGrainedClusterMessage + sealed trait RegisterExecutorResponse case object RegisteredExecutor extends CoarseGrainedClusterMessage with RegisterExecutorResponse @@ -53,6 +56,10 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage with RegisterExecutorResponse + case class StopSparkContext() extends CoarseGrainedClusterMessage + + case class DelegateCredentials(credentials: Array[Byte]) extends CoarseGrainedClusterMessage + // Executors to driver case class RegisterExecutor( executorId: String, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index a46824a0c6fad..0f81979069c95 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -23,6 +23,11 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.concurrent.Future +import scala.util.{Failure, Success} +import scala.util.control.NonFatal + +import org.apache.hadoop.io.DataOutputBuffer +import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{ExecutorAllocationClient, SparkEnv, SparkException, TaskState} import org.apache.spark.internal.Logging @@ -169,10 +174,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // If the executor's rpc env is not listening for incoming connections, `hostPort` // will be null, and the client connection should be used to contact the executor. 
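// --- Illustrative sketch (not part of the patch) ------------------------------
// The UploadCredentials/DelegateCredentials messages above carry an Array[Byte]
// payload. The snippet below shows, with plain Hadoop APIs, how such a payload
// is typically produced from a Credentials object and merged back into the
// current UGI on the receiving side, which is what the executor and AM handlers
// in this diff do. Names are invented for the example.
import java.io.{ByteArrayInputStream, DataInputStream}

import org.apache.hadoop.io.DataOutputBuffer
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

object CredentialsCodecSketch {
  // Sender side: serialize the tokens into a compact byte array.
  def toBytes(credentials: Credentials): Array[Byte] = {
    val dob = new DataOutputBuffer()
    credentials.write(dob)
    java.util.Arrays.copyOf(dob.getData, dob.getLength)
  }

  // Receiver side: rebuild the Credentials and add them to the current user.
  def mergeIntoCurrentUser(bytes: Array[Byte]): Unit = {
    val credentials = new Credentials()
    credentials.readFields(new DataInputStream(new ByteArrayInputStream(bytes)))
    UserGroupInformation.getCurrentUser.addCredentials(credentials)
  }
}
// ------------------------------------------------------------------------------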
val executorAddress = if (executorRef.address != null) { - executorRef.address - } else { - context.senderAddress - } + executorRef.address + } else { + context.senderAddress + } logInfo(s"Registered executor $executorRef ($executorAddress) with ID $executorId") addressToExecutorId(executorAddress) = executorId totalCoreCount.addAndGet(cores) @@ -226,6 +231,19 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp val reply = SparkAppConfig(sparkProperties, SparkEnv.get.securityManager.getIOEncryptionKey()) context.reply(reply) + case UploadCredentials(c) => + logInfo("Came to upload Cred in Scheduler") + val futures = executorDataMap.map { case (_, e) => + e.executorEndpoint.ask[Boolean](UploadCredentials(c)) + } + + implicit val executor = ThreadUtils.sameThread + Future.sequence(futures) + .map { booleans => booleans.reduce(_ && _) } + .andThen { + case Success(b) => context.reply(b) + case Failure(NonFatal(e)) => context.sendFailure(e) + } } // Make fake resource offers on all executors @@ -675,6 +693,9 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp driverEndpoint.send(KillExecutorsOnHost(host)) true } + def updateCredentials(credential: Array[Byte]): Future[Boolean] = { + driverEndpoint.ask[Boolean](UploadCredentials(credential)) + } } private[spark] object CoarseGrainedSchedulerBackend { diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index e4a74556d4f26..a9a60897bd0ec 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -453,6 +453,23 @@ To use a custom metrics.properties for the application master and executors, upd name matches both the include and the exclude pattern, this file will be excluded eventually. + + spark.yarn.clientToAM.port + 0 + + Port the application master listens on for connections from the client. + This port is specified when registering the AM with YARN so that client can later know which + port to connect to from the application Report. + + + + spark.yarn.hardKillTimeout + 60s + + Number of milliseconds to wait before the job client kills the application. + After the wait, client will attempt to terminate the YARN application. 
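Usage sketch (not part of the patch): with these changes, `./bin/spark-submit --master yarn --kill <applicationId>` routes the kill request over the client-to-AM RPC connection (falling back to a YARN kill after `spark.yarn.hardKillTimeout`), and `./bin/spark-submit --master yarn --uploadCred <applicationId>` obtains fresh delegation tokens and ships them to the application master, which forwards them to the driver and executors. Both actions are accepted only if the caller passes the application's modify ACLs. Note that the option is registered as `--uploadCred` in SparkSubmitOptionParser, while the spark-submit usage text in this patch spells it `--upload_cred`.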
+ + # Important notes diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java index 5f2da036ff9f7..e7d9041a51126 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java @@ -442,6 +442,7 @@ protected boolean handle(String opt, String value) { } break; case KILL_SUBMISSION: + case UPLOAD_CRED_SUBMISSION: case STATUS: isAppResourceReq = false; sparkArgs.add(opt); diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java index 6767cc5079649..f0db4a1769564 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitOptionParser.java @@ -48,6 +48,7 @@ class SparkSubmitOptionParser { protected final String FILES = "--files"; protected final String JARS = "--jars"; protected final String KILL_SUBMISSION = "--kill"; + protected final String UPLOAD_CRED_SUBMISSION = "--uploadCred"; protected final String MASTER = "--master"; protected final String NAME = "--name"; protected final String PACKAGES = "--packages"; @@ -102,6 +103,7 @@ class SparkSubmitOptionParser { { JARS }, { KEYTAB }, { KILL_SUBMISSION }, + { UPLOAD_CRED_SUBMISSION }, { MASTER }, { NAME }, { NUM_EXECUTORS }, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e227bff88f71d..5a0728d88c2db 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -17,22 +17,29 @@ package org.apache.spark.deploy.yarn -import java.io.{File, IOException} +import java.io.{ByteArrayInputStream, DataInputStream, File, IOException} import java.lang.reflect.InvocationTargetException import java.net.{Socket, URI, URL} import java.util.concurrent.{TimeoutException, TimeUnit} +import javax.crypto.SecretKey +import javax.crypto.spec.SecretKeySpec import scala.collection.mutable.HashMap import scala.concurrent.Promise import scala.concurrent.duration.Duration import scala.util.control.NonFatal +import com.google.common.base.Charsets +import io.netty.buffer.ByteBuf +import io.netty.buffer.Unpooled +import io.netty.handler.codec.base64.Base64 import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.exceptions.ApplicationAttemptNotFoundException import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -41,9 +48,16 @@ import org.apache.spark.deploy.yarn.config._ import org.apache.spark.deploy.yarn.security.{AMCredentialRenewer, YARNHadoopDelegationTokenManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ +import org.apache.spark.network.{BlockDataManager, TransportContext} +import org.apache.spark.network.client.TransportClientBootstrap +import org.apache.spark.network.netty.{NettyBlockRpcServer, 
SparkTransportConf} +import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} +import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap} import org.apache.spark.rpc._ +import org.apache.spark.rpc.netty.NettyRpcCallContext import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ +import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util._ /** @@ -89,6 +103,7 @@ private[spark] class ApplicationMaster( @volatile private var reporterThread: Thread = _ @volatile private var allocator: YarnAllocator = _ + @volatile private var clientToAMPort: Int = _ // A flag to check whether user has initialized spark context @volatile private var registered = false @@ -247,7 +262,9 @@ private[spark] class ApplicationMaster( if (!unregistered) { // we only want to unregister if we don't want the RM to retry - if (finalStatus == FinalApplicationStatus.SUCCEEDED || isLastAttempt) { + if (finalStatus == FinalApplicationStatus.SUCCEEDED || + finalStatus == FinalApplicationStatus.KILLED || + isLastAttempt) { unregister(finalStatus, finalMsg) cleanupStagingDir() } @@ -283,6 +300,7 @@ private[spark] class ApplicationMaster( credentialRenewerThread.start() credentialRenewerThread.join() } + clientToAMPort = sparkConf.getInt("spark.yarn.clientToAM.port", 0) if (isClusterMode) { runDriver(securityMgr) @@ -402,7 +420,8 @@ private[spark] class ApplicationMaster( uiAddress, historyAddress, securityMgr, - localResources) + localResources, + clientToAMPort) // Initialize the AM endpoint *after* the allocator has been initialized. This ensures // that when the driver sends an initial executor request (e.g. after an AM restart), @@ -422,6 +441,89 @@ private[spark] class ApplicationMaster( YarnSchedulerBackend.ENDPOINT_NAME) } + /** + * Create an [[RpcEndpoint]] that communicates with the client. + * + * @return A reference to the application master's RPC endpoint. 
+ */ + private def runClientAMEndpoint( + port: Int, + driverRef: RpcEndpointRef, + securityManager: SecurityManager): RpcEndpointRef = { + val serversparkConf = new SparkConf() + serversparkConf.set("spark.rpc.connectionUsingTokens", "true") + + val amRpcEnv = + RpcEnv.create(ApplicationMaster.SYSTEM_NAME, Utils.localHostName(), port, serversparkConf, + securityManager) + clientToAMPort = amRpcEnv.address.port + + val clientAMEndpoint = + amRpcEnv.setupEndpoint(ApplicationMaster.ENDPOINT_NAME, + new ClientToAMEndpoint(amRpcEnv, driverRef, securityManager)) + clientAMEndpoint + } + + /** RpcEndpoint class for ClientToAM */ + private[spark] class ClientToAMEndpoint( + override val rpcEnv: RpcEnv, + driverRef: RpcEndpointRef, + securityManager: SecurityManager) + extends RpcEndpoint with Logging { + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case ApplicationMasterMessages.KillApplication => + if (securityManager.checkModifyPermissions(context.senderUserName)) { + driverRef.send(StopSparkContext) + finish(FinalApplicationStatus.KILLED, ApplicationMaster.EXIT_KILLED) + context.reply(true) + } else { + context.reply(false) + } + case ApplicationMasterMessages.UploadCredentials(c) => + if (securityManager.checkModifyPermissions(context.senderUserName)) { + context.reply(true) + val dataInput = new DataInputStream(new ByteArrayInputStream(c)) + val credentials = new Credentials + credentials.readFields(dataInput) + if (credentials != null) { + logInfo(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) + } + UserGroupInformation.getCurrentUser.addCredentials(credentials) + + try { + val timeout = RpcUtils.askRpcTimeout(sparkConf) + val success = timeout.awaitResult(driverRef.ask[Boolean](DelegateCredentials(c))) + if (!success) { + throw new SparkException(s"Current user doesn't have modify ACL") + context.reply(false) + } + } catch { + case e: TimeoutException => + throw new SparkException(s"Timed out waiting to upload credential") + context.reply(false) + } +// sc.schedulerBackend match { +// case s: CoarseGrainedSchedulerBackend => +// logInfo(s"Update credentials in driver") +// val f = s.updateCredentials(c) +// f onSuccess { +// case b => context.reply(b) +// } +// f onFailure { +// case NonFatal(e) => context.sendFailure(e) +// e +// } +// case _ => +// throw new SparkException(s"Update credentials on" + +// s" ${sc.schedulerBackend.getClass.getSimpleName} is not supported") +// } + } else { + context.reply(false) + } + } + } + private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter(None) userClassThread = startUserApplication() @@ -438,8 +540,12 @@ private[spark] class ApplicationMaster( val driverRef = createSchedulerRef( sc.getConf.get("spark.driver.host"), sc.getConf.get("spark.driver.port")) + val clientToAMSecurityManager = new SecurityManager(sparkConf) + runClientAMEndpoint(clientToAMPort, driverRef, clientToAMSecurityManager) registerAM(sc.getConf, rpcEnv, driverRef, sc.ui.map(_.webUrl), securityMgr) registered = true + clientToAMSecurityManager.setSecretKey(Base64.encode( + Unpooled.wrappedBuffer(client.getMasterKey)).toString(Charsets.UTF_8)); } else { // Sanity check; should never happen in normal operation, since sc should only be null // if the user app did not create a SparkContext. 
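// --- Illustrative sketch (not part of the patch) ------------------------------
// The AM code above seeds its client-facing SecurityManager with the
// Base64-encoded ClientToAMTokenMasterKey. The assumed counterpart on the
// submitting client is to pull the client-to-AM token out of the application
// report and register its encoded identifier as the SASL user and its encoded
// password as the secret, which is what the SparkSaslClient/AuthClientBootstrap
// changes earlier in this diff expect when spark.rpc.connectionUsingTokens is
// set. The helper name and wiring below are invented; this is an
// assumption-level sketch, not the patch's own client code.
// Placed under org.apache.spark so the private[spark] SecurityManager setters are visible.
package org.apache.spark.deploy.yarn

import java.net.InetSocketAddress
import java.nio.charset.StandardCharsets.UTF_8

import io.netty.buffer.Unpooled
import io.netty.handler.codec.base64.Base64
import org.apache.hadoop.yarn.api.records.ApplicationReport
import org.apache.hadoop.yarn.util.ConverterUtils

import org.apache.spark.SecurityManager

object ClientToAMTokenSketch {
  def seedSecurityManager(report: ApplicationReport, securityManager: SecurityManager): Unit = {
    val serviceAddr = new InetSocketAddress(report.getHost, report.getRpcPort)
    // Convert the YARN-level token from the report into a Hadoop security token.
    val token = ConverterUtils.convertFromYarn(report.getClientToAMToken, serviceAddr)
    def b64(bytes: Array[Byte]): String =
      Base64.encode(Unpooled.wrappedBuffer(bytes)).toString(UTF_8)
    // The encoded identifier doubles as the SASL user name; the encoded password is the secret.
    securityManager.setSaslUser(b64(token.getIdentifier))
    securityManager.setSecretKey(b64(token.getPassword))
  }
}
// ------------------------------------------------------------------------------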
@@ -464,10 +570,13 @@ private[spark] class ApplicationMaster( amCores, true) val driverRef = waitForSparkDriver() addAmIpFilter(Some(driverRef)) + val clientToAMSecurityManager = new SecurityManager(sparkConf) + runClientAMEndpoint(clientToAMPort, driverRef, clientToAMSecurityManager) registerAM(sparkConf, rpcEnv, driverRef, sparkConf.getOption("spark.driver.appUIAddress"), securityMgr) registered = true - + clientToAMSecurityManager.setSecretKey(Base64.encode( + Unpooled.wrappedBuffer(client.getMasterKey)).toString(Charsets.UTF_8)); // In client mode the actor will stop the reporter thread. reporterThread.join() } @@ -749,8 +858,20 @@ private[spark] class ApplicationMaster( } +sealed trait ApplicationMasterMessage extends Serializable + +private [spark] object ApplicationMasterMessages { + + case class KillApplication() extends ApplicationMasterMessage + + case class UploadCredentials(credentials: Array[Byte]) extends ApplicationMasterMessage +} + object ApplicationMaster extends Logging { + val SYSTEM_NAME = "sparkYarnAM" + val ENDPOINT_NAME = "clientToAM" + // exit codes for different causes, no reason behind the values private val EXIT_SUCCESS = 0 private val EXIT_UNCAUGHT_EXCEPTION = 10 @@ -760,6 +881,7 @@ object ApplicationMaster extends Logging { private val EXIT_SECURITY = 14 private val EXIT_EXCEPTION_USER_CLASS = 15 private val EXIT_EARLY = 16 + private val EXIT_KILLED = 17 private var master: ApplicationMaster = _ diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d408ca90a5d1c..87ef59337a6f5 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -18,19 +18,25 @@ package org.apache.spark.deploy.yarn import java.io.{File, FileOutputStream, IOException, OutputStreamWriter} -import java.net.{InetAddress, UnknownHostException, URI} +import java.net.{InetAddress, InetSocketAddress, UnknownHostException, URI} import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import java.security.PrivilegedExceptionAction import java.util.{Locale, Properties, UUID} +import java.util.concurrent.TimeoutException import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map} +import scala.concurrent.ExecutionContext +import scala.util.control.Breaks._ import scala.util.control.NonFatal +import com.google.common.base.Charsets.UTF_8 import com.google.common.base.Objects import com.google.common.io.Files +import io.netty.buffer.Unpooled +import io.netty.handler.codec.base64.Base64 import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.fs.permission.FsPermission @@ -45,16 +51,18 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.client.api.{YarnClient, YarnClientApplication} import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException -import org.apache.hadoop.yarn.util.Records +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.yarn.ApplicationMasterMessages.{KillApplication, UploadCredentials} import org.apache.spark.deploy.yarn.config._ import 
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils}
-import org.apache.spark.util.{CallerContext, Utils}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
+import org.apache.spark.util.{CallerContext, RpcUtils, SparkExitCode, ThreadUtils, Utils}
 
 private[spark] class Client(
     val args: ClientArguments,
@@ -1149,6 +1157,129 @@ private[spark] class Client(
     }
   }
 
+  /**
+   * Asks the running ApplicationMaster to stop the application, and falls back to a hard kill
+   * through the ResourceManager if the application is still not in a terminal state after
+   * spark.yarn.hardKillTimeout.
+   */
+  def killSparkApplication(securityManager: SecurityManager): Unit = {
+    setupCredentials()
+    yarnClient.init(yarnConf)
+    yarnClient.start()
+    val appId = ConverterUtils.toApplicationId(args.userArgs(0))
+    val amEndpoint = setupAMConnection(appId, securityManager)
+    try {
+      val timeout = RpcUtils.askRpcTimeout(sparkConf)
+      val success = timeout.awaitResult(amEndpoint.ask[Boolean](KillApplication))
+      if (!success) {
+        throw new SparkException(s"Current user does not have modify ACL for application $appId")
+      }
+    } catch {
+      case e: TimeoutException =>
+        throw new SparkException(s"Timed out waiting to kill the application: $appId", e)
+    }
+
+    // Wait for the AM to shut the application down; if it has not done so within the configured
+    // timeout, kill the application through the ResourceManager instead.
+    var currentTime = System.currentTimeMillis
+    val timeKillIssued = currentTime
+    val killTimeOut = sparkConf.get(CLIENT_TO_AM_HARD_KILL_TIMEOUT)
+    breakable {
+      while (currentTime < timeKillIssued + killTimeOut && !isAppInTerminalState(appId)) {
+        try {
+          Thread.sleep(1000L)
+        } catch {
+          case ie: InterruptedException => break
+        }
+        currentTime = System.currentTimeMillis
+      }
+    }
+    if (!isAppInTerminalState(appId)) {
+      yarnClient.killApplication(appId)
+    }
+  }
+
+  /** Obtains fresh delegation tokens and ships them to the running ApplicationMaster. */
+  def uploadCredentials(securityManager: SecurityManager): Unit = {
+    yarnClient.init(yarnConf)
+    yarnClient.start()
+
+    setupCredentials()
+    credentialManager.obtainDelegationTokens(hadoopConf, credentials)
+
+    val dob = new DataOutputBuffer()
+    if (credentials != null) {
+      UserGroupInformation.getCurrentUser.addCredentials(credentials)
+      credentials.write(dob)
+      logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n"))
+    }
+
+    val appId = ConverterUtils.toApplicationId(args.userArgs(0))
+    val amEndpoint = setupAMConnection(appId, securityManager)
+    try {
+      val timeout = RpcUtils.askRpcTimeout(sparkConf)
+      val success = timeout.awaitResult(amEndpoint.ask[Boolean](UploadCredentials(dob.getData)))
+      if (!success) {
+        throw new SparkException(s"Current user does not have modify ACL for application $appId")
+      }
+    } catch {
+      case e: TimeoutException =>
+        throw new SparkException(s"Timed out waiting to upload credentials to application $appId", e)
+    }
+  }
+
+  /**
+   * Connects to the ClientToAM endpoint of the given application. When security is enabled, the
+   * client-to-AM token from the application report is used to authenticate the connection.
+   */
+  private def setupAMConnection(
+      appId: ApplicationId,
+      securityManager: SecurityManager): RpcEndpointRef = {
+    logInfo(s"Connecting to the ApplicationMaster of $appId")
+    val report = getApplicationReport(appId)
+    val state = report.getYarnApplicationState
+    if (report.getHost() == null || "".equals(report.getHost()) || "N/A".equals(report.getHost())) {
+      throw new SparkException(
+        s"AM for $appId is not assigned yet, or the current user does not have view ACL for it")
+    }
+    if (state != YarnApplicationState.RUNNING) {
+      throw new SparkException(s"Application $appId must be in the RUNNING state")
+    }
+
+    if (UserGroupInformation.isSecurityEnabled()) {
+      val serviceAddr = new InetSocketAddress(report.getHost(), report.getRpcPort())
+      val clientToAMToken = report.getClientToAMToken
+      val token = ConverterUtils.convertFromYarn(clientToAMToken, serviceAddr)
+
+      // Base64-encode the token identifier and password from the application report and register
+      // them with the SecurityManager as the SASL user name and secret key.
+      val userName = token.getIdentifier
+      val userString = Base64.encode(Unpooled.wrappedBuffer(userName)).toString(UTF_8)
+      securityManager.setSaslUser(userString)
+      val secretKey = token.getPassword
+      val secretKeyString = Base64.encode(Unpooled.wrappedBuffer(secretKey)).toString(UTF_8)
+      securityManager.setSecretKey(secretKeyString)
+    }
+
+    sparkConf.set("spark.rpc.connectionUsingTokens", "true")
+    val rpcEnv =
+      RpcEnv.create("yarnDriverClient", Utils.localHostName(), 0, sparkConf, securityManager)
+    val amHostPort = RpcAddress(report.getHost, report.getRpcPort)
+    val amEndpoint = rpcEnv.setupEndpointRef(amHostPort, ApplicationMaster.ENDPOINT_NAME)
+    amEndpoint
+  }
+
+  private def checkAppStatus(appId: ApplicationId): YarnApplicationState = {
+    val report = getApplicationReport(appId)
+    report.getYarnApplicationState
+  }
+
+  private def isAppInTerminalState(appId: ApplicationId): Boolean = {
+    val status = checkAppStatus(appId)
+    status == YarnApplicationState.KILLED ||
+      status == YarnApplicationState.FAILED ||
+      status == YarnApplicationState.FINISHED
+  }
+
   private def findPySparkArchives(): Seq[String] = {
     sys.env.get("PYSPARK_ARCHIVES_PATH")
       .map(_.split(",").toSeq)
@@ -1186,6 +1317,22 @@ private object Client extends Logging {
     new Client(args, sparkConf).run()
   }
 
+  def yarnKillSubmission(argStrings: Array[String]): Unit = {
+    System.setProperty("SPARK_YARN_MODE", "true")
+    val sparkConf = new SparkConf
+    val args = new ClientArguments(argStrings)
+    new Client(args, sparkConf).killSparkApplication(new SecurityManager(sparkConf))
+  }
+
+  def yarnUploadCredentials(argStrings: Array[String]): Unit = {
+    System.setProperty("SPARK_YARN_MODE", "true")
+    val sparkConf = new SparkConf
+    val args = new ClientArguments(argStrings)
+    new Client(args, sparkConf).uploadCredentials(new SecurityManager(sparkConf))
+  }
+
   // Alias for the user jar
   val APP_JAR_NAME: String = "__app__.jar"
 
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index 72f4d273ab53b..68c8134d5527a 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -17,6 +17,8 @@ package org.apache.spark.deploy.yarn
 
+import java.nio.ByteBuffer
+
 import scala.collection.JavaConverters._
 
 import org.apache.hadoop.yarn.api.records._
@@ -39,6 +41,7 @@ private[spark] class YarnRMClient extends Logging {
   private var amClient: AMRMClient[ContainerRequest] = _
   private var uiHistoryAddress: String = _
   private var registered: Boolean = false
+  private var masterKey: ByteBuffer = _
 
   /**
    * Registers the application master with the RM.
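For reference, a sketch of how the two entry points added to object Client above might be driven; the surrounding wiring (for example a new spark-submit action) is not part of this diff, and it is assumed the application id is passed as the first user argument so that args.userArgs(0) resolves it:

    // Hypothetical caller inside org.apache.spark.deploy.yarn (object Client is package-private).
    // "--arg <appId>" is assumed to populate ClientArguments.userArgs, which
    // killSparkApplication() and uploadCredentials() read via args.userArgs(0).
    val appId = "application_1500000000000_0001"       // example id only
    Client.yarnKillSubmission(Array("--arg", appId))    // ask the AM to stop; hard-kill on timeout
    Client.yarnUploadCredentials(Array("--arg", appId)) // push fresh delegation tokens to the AM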
@@ -58,7 +61,8 @@ private[spark] class YarnRMClient extends Logging {
       uiAddress: Option[String],
       uiHistoryAddress: String,
       securityMgr: SecurityManager,
-      localResources: Map[String, LocalResource]
+      localResources: Map[String, LocalResource],
+      port: Int = 0
     ): YarnAllocator = {
     amClient = AMRMClient.createAMRMClient()
     amClient.init(conf)
@@ -71,8 +75,9 @@ private[spark] class YarnRMClient extends Logging {
     logInfo("Registering the ApplicationMaster")
     synchronized {
-      amClient.registerApplicationMaster(Utils.localHostName(), 0, trackingUrl)
+      val response = amClient.registerApplicationMaster(Utils.localHostName(), port, trackingUrl)
       registered = true
+      masterKey = response.getClientToAMTokenMasterKey()
     }
     new YarnAllocator(driverUrl, driverRef, conf, sparkConf, amClient, getAttemptId(), securityMgr,
       localResources, new SparkRackResolver())
@@ -89,6 +94,9 @@ private[spark] class YarnRMClient extends Logging {
       amClient.unregisterApplicationMaster(status, diagnostics, uiHistoryAddress)
     }
   }
+
+  /** Returns the client-to-AM token master key reported back by YARN when the AM registered. */
+  def getMasterKey(): ByteBuffer = masterKey
+
   /** Returns the attempt ID. */
   def getAttemptId(): ApplicationAttemptId = {
 
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
index 187803cc6050b..26520578e066a 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala
@@ -187,6 +187,11 @@ package object config {
     .toSequence
     .createWithDefault(Nil)
 
+  private[spark] val CLIENT_TO_AM_HARD_KILL_TIMEOUT = ConfigBuilder("spark.yarn.hardKillTimeout")
+    .internal()
+    .timeConf(TimeUnit.MILLISECONDS)
+    .createWithDefaultString("60s")
+
   /* Client-mode AM configuration.
    */
 
   private[spark] val AM_CORES = ConfigBuilder("spark.yarn.am.cores")
 
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 415a29fd887e8..cb3a53bf02c6d 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -23,9 +23,11 @@ import scala.concurrent.{ExecutionContext, Future}
 import scala.util.{Failure, Success}
 import scala.util.control.NonFatal
 
+import org.apache.hadoop.io.DataOutputBuffer
+import org.apache.hadoop.security.UserGroupInformation
 import org.apache.hadoop.yarn.api.records.{ApplicationAttemptId, ApplicationId}
 
-import org.apache.spark.SparkContext
+import org.apache.spark.{SparkContext, SparkException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc._
 import org.apache.spark.scheduler._
@@ -271,6 +273,9 @@ private[spark] abstract class YarnSchedulerBackend(
         logError("Error requesting driver to remove executor" +
           s" $executorId for reason $reason", e)
       }(ThreadUtils.sameThread)
+
+    case StopSparkContext =>
+      sc.stop()
   }
 
@@ -305,6 +310,24 @@ private[spark] abstract class YarnSchedulerBackend(
 
       case RetrieveLastAllocatedExecutorId =>
         context.reply(currentExecutorIdCounter)
+
+      case DelegateCredentials(c) =>
+        sc.schedulerBackend match {
+          case s: CoarseGrainedSchedulerBackend =>
+            logInfo("Updating delegated credentials in the driver")
+            // Push the serialized credentials to the executors and reply once that completes.
+            s.updateCredentials(c).onComplete {
+              case Success(b) => context.reply(b)
+              case Failure(e) => context.sendFailure(e)
+            }(ThreadUtils.sameThread)
+          case other =>
+            context.sendFailure(new SparkException(
+              s"Updating credentials on ${other.getClass.getSimpleName} is not supported"))
+        }
     }
 
   override def onDisconnected(remoteAddress: RpcAddress): Unit = {
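The hard-kill timeout introduced in config.scala above is consumed by Client.killSparkApplication, and the client-to-AM ask() round trips use the regular RPC ask timeout. A short sketch of tuning both on the client side (the values shown are examples only):

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      // Timeout for the ask() calls against the ClientToAM endpoint (RpcUtils.askRpcTimeout).
      .set("spark.rpc.askTimeout", "30s")
      // How long the client waits for the AM to shut the application down before it falls back
      // to yarnClient.killApplication(); CLIENT_TO_AM_HARD_KILL_TIMEOUT defaults to 60s.
      .set("spark.yarn.hardKillTimeout", "120s")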