Skip to content

Commit

Permalink
polish and ut
Browse files Browse the repository at this point in the history
  • Loading branch information
pan3793 committed May 17, 2024
1 parent b772645 commit 0b8dd36
Show file tree
Hide file tree
Showing 2 changed files with 440 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@ package org.apache.spark.deploy.yarn
import java.io.IOException
import java.net.{HttpURLConnection, InetAddress, MalformedURLException, UnknownHostException, URL}
import java.security.Principal
import java.util
import java.util.concurrent.TimeUnit

import jakarta.servlet.{Filter, FilterChain, FilterConfig, ServletException, ServletRequest, ServletResponse}
import jakarta.servlet.http.{HttpServletRequest, HttpServletRequestWrapper, HttpServletResponse}
import org.apache.commons.lang3.StringUtils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.util.Time
import org.apache.hadoop.yarn.webapp.MimeType

import org.apache.spark.internal.Logging

Expand All @@ -36,126 +34,114 @@ class YarnAMIpFilter extends Filter with Logging {

import YarnAMIpFilter._

/**
 * Minimal [[java.security.Principal]] whose only attribute is a user name.
 *
 * @param name the name reported by `getName`
 */
private[spark] class AmIpPrincipal(name: String) extends Principal {

  /** Returns the name this principal was constructed with. */
  override def getName: String = {
    name
  }
}

// Host names of the trusted proxies; set by init and resolved by getProxyAddresses.
private[spark] var proxyHosts: Array[String] = _

// Cached IP addresses resolved from proxyHosts. Read outside the lock by the fast path in
// getProxyAddresses, hence @volatile so a refreshed set is visible to other request threads.
@volatile private[spark] var proxyAddresses: Set[String] = _

// System.nanoTime() reading taken at the last successful resolution. @volatile for the same
// reason as proxyAddresses (and to rule out torn reads of a 64-bit value).
@volatile private[spark] var lastUpdate: Long = 0L

// Proxy URI bases keyed by "host:port", or by "dummy" in the legacy single-proxy configuration.
private[spark] var proxyUriBases: Map[String, String] = _

// Values of the RM_HA_URLS init parameter; stays null when that parameter is not configured.
private[spark] var rmUrls: Array[String] = _

/**
 * Reads the proxy configuration from the filter's init parameters.
 *
 * When both legacy single-proxy parameters (PROXY_HOST and PROXY_URI_BASE) are present they are
 * used as-is (YARN-1811 backwards compatibility). Otherwise the delimited PROXY_HOSTS and
 * PROXY_URI_BASES parameters are parsed; URI bases that are not valid URLs are logged and
 * skipped. RM_HA_URLS is optional and leaves `rmUrls` untouched when absent.
 *
 * @param conf filter configuration supplied by the servlet container
 * @throws ServletException if the legacy pair is incomplete and PROXY_HOSTS or PROXY_URI_BASES
 *                          is not set
 */
@throws[ServletException]
override def init(conf: FilterConfig): Unit = {
  // Report a missing parameter as a ServletException with a clear message instead of letting
  // `.split` fail with a bare NullPointerException.
  def requiredParam(name: String): String =
    Option(conf.getInitParameter(name)).getOrElse {
      throw new ServletException(s"Missing required filter init parameter: $name")
    }

  val legacyProxyHostOpt = Option(conf.getInitParameter(PROXY_HOST))
  val legacyProxyUriBaseOpt = Option(conf.getInitParameter(PROXY_URI_BASE))
  (legacyProxyHostOpt, legacyProxyUriBaseOpt) match {
    // YARN-1811: Maintain for backwards compatibility
    case (Some(legacyProxyHost), Some(legacyProxyUriBase)) =>
      proxyHosts = Array(legacyProxyHost)
      proxyUriBases = Map("dummy" -> legacyProxyUriBase)
    case _ =>
      // Validate both parameters before mutating any state.
      val hostsParam = requiredParam(PROXY_HOSTS)
      val uriBasesParam = requiredParam(PROXY_URI_BASES)
      proxyHosts = hostsParam.split(PROXY_HOSTS_DELIMITER)
      proxyUriBases = uriBasesParam.split(PROXY_URI_BASES_DELIMITER)
        .flatMap { proxyUriBase =>
          try {
            val url = new URL(proxyUriBase)
            // Explicit pair (no reliance on argument auto-tupling).
            Some((url.getHost + ":" + url.getPort) -> proxyUriBase)
          } catch {
            case e: MalformedURLException =>
              logWarning(s"$proxyUriBase does not appear to be a valid URL", e)
              None
          }
        }.toMap
  }

  Option(conf.getInitParameter(RM_HA_URLS)).foreach { rmHaUrls =>
    rmUrls = rmHaUrls.split(RM_HA_URLS_DELIMITER)
  }
}

/**
 * Returns the IP addresses of the configured proxy hosts, re-resolving them when nothing is
 * cached yet or the cached set is at least `updateInterval` nanoseconds old.
 *
 * Hosts that cannot be resolved are logged and skipped. The cache and `lastUpdate` are only
 * written after at least one address was resolved, so a failed refresh is retried on the next
 * call instead of leaving an empty set behind.
 *
 * @return the non-empty set of resolved proxy addresses
 * @throws ServletException if none of the proxy hosts could be resolved
 */
@throws[ServletException]
protected[spark] def getProxyAddresses: Set[String] = this.synchronized {
  // The whole check runs under the lock: the cached fields are shared between request threads
  // and an unsynchronised pre-check is only safe if they are volatile.
  val now = System.nanoTime()
  // System.nanoTime() values may wrap, so compare the elapsed difference rather than
  // `lastUpdate + updateInterval <= now`.
  if (proxyAddresses == null || now - lastUpdate >= updateInterval) {
    val resolved = proxyHosts.flatMap { proxyHost =>
      try {
        InetAddress.getAllByName(proxyHost).map { proxyAddr =>
          logDebug(s"proxy address is: ${proxyAddr.getHostAddress}")
          proxyAddr.getHostAddress
        }
      } catch {
        case e: UnknownHostException =>
          logWarning(s"Could not locate $proxyHost - skipping", e)
          Array.empty[String]
      }
    }.toSet
    if (resolved.isEmpty) {
      throw new ServletException("Could not locate any of the proxy hosts")
    }
    proxyAddresses = resolved
    lastUpdate = now
  }
  proxyAddresses
}

override def destroy(): Unit = {
  // Empty
}

/**
 * Lets requests coming from a known proxy address through and redirects every other request
 * to the proxy.
 *
 * For a proxied request, the value of the PROXY_USER_COOKIE_NAME cookie (when present) is
 * exposed as the request's user principal via an [[AmIpServletRequestWrapper]]. For any other
 * request a redirect URL is built from `findRedirectUrl`, the request URI and the original
 * query string.
 *
 * @param req   incoming request; `rejectNonHttpRequests` (defined outside this view) is called
 *              on it first, presumably to refuse non-HTTP requests
 * @param resp  response used for the redirect
 * @param chain remaining filter chain, invoked only for proxied requests
 */
@throws[IOException]
@throws[ServletException]
override def doFilter(req: ServletRequest, resp: ServletResponse, chain: FilterChain): Unit = {
  rejectNonHttpRequests(req)
  val httpReq = req.asInstanceOf[HttpServletRequest]
  val httpResp = resp.asInstanceOf[HttpServletResponse]
  logDebug(s"Remote address for request is: ${httpReq.getRemoteAddr}")

  val cameThroughProxy = getProxyAddresses.contains(httpReq.getRemoteAddr)
  if (cameThroughProxy) {
    val proxyUser = for {
      cookies <- Option(httpReq.getCookies)
      cookie <- cookies.find(_.getName == PROXY_USER_COOKIE_NAME)
    } yield cookie.getValue

    proxyUser match {
      case None =>
        logDebug(s"Could not find $PROXY_USER_COOKIE_NAME cookie, so user will not be set")
        chain.doFilter(req, resp)
      case Some(user) =>
        val wrapped = new AmIpServletRequestWrapper(httpReq, new AmIpPrincipal(user))
        chain.doFilter(wrapped, resp)
    }
  } else {
    val target = new StringBuilder(findRedirectUrl).append(httpReq.getRequestURI)
    // Mark the request as a redirect for the RM web proxy by making "/redirect" the path
    // component that directly follows the proxy path.
    val proxyPathIdx = target.indexOf(PROXY_PATH)
    if (proxyPathIdx >= 0) {
      target.insert(proxyPathIdx + PROXY_PATH.length, "/redirect")
    }
    // Carry over the original query string, if there is one.
    val query = httpReq.getQueryString
    if (!StringUtils.isEmpty(query)) {
      target.append("?").append(query)
    }
    sendRedirect(httpReq, httpResp, target.toString)
  }
}

@throws[ServletException]
private def findRedirectUrl: String = {
private[spark] def findRedirectUrl: String = {
val addr = if (proxyUriBases.size == 1) {
// external proxy or not RM HA
Some(proxyUriBases.values.iterator.next)
Some(proxyUriBases.values.iterator.next())
} else if (rmUrls != null) {
rmUrls.find { url => isValidUrl(proxyUriBases.get(url)) }
rmUrls.map(url => proxyUriBases(url)).find { host => isValidUrl(host) }
} else {
None
}
Expand Down Expand Up @@ -193,20 +179,20 @@ class YarnAMIpFilter extends Filter with Logging {
* <p>
* At the end of this method, the output stream is closed.
*
* @param request request (hence: the verb and any other information
* @param req request (hence: the verb and any other information
* relevant to a redirect)
* @param response the response
* @param target the target URL -unencoded
* @param resp the response
* @param target the target URL -unencoded
*
*/
@throws[IOException]
private def sendRedirect(request: HttpServletRequest,
response: HttpServletResponse, target: String): Unit = {
logDebug(s"Redirecting ${request.getMethod} ${request.getRequestURI} to $target")
val location = response.encodeRedirectURL(target)
response.setStatus(HttpServletResponse.SC_FOUND)
response.setHeader(LOCATION, location)
response.setContentType(MimeType.HTML)
private def sendRedirect(
req: HttpServletRequest, resp: HttpServletResponse, target: String): Unit = {
logDebug(s"Redirecting ${req.getMethod} ${req.getRequestURI} to $target")
val location = resp.encodeRedirectURL(target)
resp.setStatus(HttpServletResponse.SC_FOUND)
resp.setHeader(LOCATION, location)
resp.setContentType("text/html")
val content = s"""
|<html>
|<head>
Expand All @@ -219,7 +205,7 @@ class YarnAMIpFilter extends Filter with Logging {
|</html>
""".stripMargin

val writer = response.getWriter
val writer = resp.getWriter
writer.write(content)
writer.close()
}
Expand Down Expand Up @@ -247,8 +233,24 @@ private[spark] object YarnAMIpFilter {
// Path fragment doFilter searches for in the redirect URL; "/redirect" is inserted right after it.
val PROXY_PATH = "/proxy"
// Name of the cookie whose value doFilter exposes as the request's user principal.
val PROXY_USER_COOKIE_NAME = "proxy-user"
// Name of the optional filter init parameter that init reads into rmUrls.
val RM_HA_URLS = "RM_HA_URLS"
// Separator used by init to split the RM_HA_URLS parameter value.
val RM_HA_URLS_DELIMITER = ","
// NOTE(review): presumably the message used when a non-HTTP request is rejected; its use is
// not visible in this part of the file.
val E_HTTP_HTTPS_ONLY = "This filter only works for HTTP/HTTPS"
// Name of the response header sendRedirect sets to the redirect target.
val LOCATION = "Location"
// Refresh the resolved proxy IP list about every 5 minutes. The value is in nanoseconds
// because getProxyAddresses compares it against System.nanoTime() readings.
var updateInterval: Long = TimeUnit.MINUTES.toNanos(5)

// Only for testing: overrides the refresh interval (in nanoseconds).
def setUpdateInterval(ns: Long): Unit = {
  updateInterval = ns
}

/**
 * Minimal [[java.security.Principal]] whose only attribute is a user name.
 *
 * @param name the name reported by `getName`
 */
private[spark] class AmIpPrincipal(name: String) extends Principal {

  /** Returns the name this principal was constructed with. */
  override def getName: String = {
    name
  }
}

/**
 * Request wrapper that reports a fixed principal as the authenticated user.
 *
 * @param httpReq   the request being wrapped
 * @param principal the principal returned by `getUserPrincipal`; its name is returned by
 *                  `getRemoteUser`
 */
private[spark] class AmIpServletRequestWrapper(
    httpReq: HttpServletRequest,
    principal: AmIpPrincipal)
  extends HttpServletRequestWrapper(httpReq) {

  /** The principal supplied at construction. */
  override def getUserPrincipal: Principal = principal

  /** The name of the principal supplied at construction. */
  override def getRemoteUser: String = principal.getName

  /** Always `false`: no role is ever reported for the wrapped user. */
  override def isUserInRole(role: String): Boolean = false
}
}
Loading

0 comments on commit 0b8dd36

Please sign in to comment.