diff --git a/assembly/pom.xml b/assembly/pom.xml index 6c31ec745b5bd..58e7ae5bb0c7f 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -136,10 +136,6 @@ spark-yarn_${scala.binary.version} ${project.version} - - org.apache.hadoop - hadoop-yarn-server-web-proxy - diff --git a/dev/deps/spark-deps-hadoop-3-hive-2.3 b/dev/deps/spark-deps-hadoop-3-hive-2.3 index 4b6f5dda585b3..13e74a1627fb3 100644 --- a/dev/deps/spark-deps-hadoop-3-hive-2.3 +++ b/dev/deps/spark-deps-hadoop-3-hive-2.3 @@ -80,7 +80,6 @@ hadoop-client-runtime/3.4.0//hadoop-client-runtime-3.4.0.jar hadoop-cloud-storage/3.4.0//hadoop-cloud-storage-3.4.0.jar hadoop-huaweicloud/3.4.0//hadoop-huaweicloud-3.4.0.jar hadoop-shaded-guava/1.2.0//hadoop-shaded-guava-1.2.0.jar -hadoop-yarn-server-web-proxy/3.4.0//hadoop-yarn-server-web-proxy-3.4.0.jar hive-beeline/2.3.10//hive-beeline-2.3.10.jar hive-cli/2.3.10//hive-cli-2.3.10.jar hive-common/2.3.10//hive-common-2.3.10.jar diff --git a/pom.xml b/pom.xml index d92d210a5ffc3..1d11d3840e250 100644 --- a/pom.xml +++ b/pom.xml @@ -1769,83 +1769,6 @@ ${yarn.version} test - - org.apache.hadoop - hadoop-yarn-server-web-proxy - ${yarn.version} - ${hadoop.deps.scope} - - - org.apache.hadoop - hadoop-yarn-server-common - - - org.apache.hadoop - hadoop-yarn-common - - - org.apache.hadoop - hadoop-yarn-api - - - org.bouncycastle - bcprov-jdk15on - - - org.bouncycastle - bcpkix-jdk15on - - - org.fusesource.leveldbjni - leveldbjni-all - - - asm - asm - - - org.ow2.asm - asm - - - org.jboss.netty - netty - - - javax.servlet - servlet-api - - - javax.servlet - javax.servlet-api - - - commons-logging - commons-logging - - - com.sun.jersey - * - - - com.sun.jersey.jersey-test-framework - * - - - com.sun.jersey.contribs - * - - - - com.zaxxer - HikariCP-java7 - - - com.microsoft.sqlserver - mssql-jdbc - - - org.apache.hadoop hadoop-yarn-client diff --git a/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpFilter.java b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpFilter.java new file mode 100644 index 0000000000000..60e880d1ac4aa --- /dev/null +++ b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpFilter.java @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.deploy.yarn;
+
+import org.apache.hadoop.classification.InterfaceAudience.Public;
+import org.apache.hadoop.classification.VisibleForTesting;
+import org.apache.hadoop.security.UserGroupInformation;
+import org.apache.hadoop.util.Time;
+
+import jakarta.servlet.*;
+import jakarta.servlet.http.Cookie;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.net.*;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.spark.internal.SparkLogger;
+import org.apache.spark.internal.SparkLoggerFactory;
+
+// This class is copied from Hadoop 3.4.0
+// org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter
+//
+// Modifications:
+// - Migrate from javax.servlet to jakarta.servlet
+// - Copy constant string definitions to strip the external dependency:
+//   - RM_HA_URLS
+//   - PROXY_USER_COOKIE_NAME
+@Public
+public class AmIpFilter implements Filter {
+  private static final SparkLogger LOG = SparkLoggerFactory.getLogger(AmIpFilter.class);
+
+  @Deprecated
+  public static final String PROXY_HOST = "PROXY_HOST";
+  @Deprecated
+  public static final String PROXY_URI_BASE = "PROXY_URI_BASE";
+  public static final String PROXY_HOSTS = "PROXY_HOSTS";
+  public static final String PROXY_HOSTS_DELIMITER = ",";
+  public static final String PROXY_URI_BASES = "PROXY_URI_BASES";
+  public static final String PROXY_URI_BASES_DELIMITER = ",";
+  private static final String PROXY_PATH = "/proxy";
+  // RM_HA_URLS is defined in AmFilterInitializer in the original Hadoop code
+  private static final String RM_HA_URLS = "RM_HA_URLS";
+  // PROXY_USER_COOKIE_NAME is defined in WebAppProxyServlet in the original Hadoop code
+  public static final String PROXY_USER_COOKIE_NAME = "proxy-user";
+  // update the proxy IP list about every 5 min
+  private static long updateInterval = TimeUnit.MINUTES.toMillis(5);
+
+  private String[] proxyHosts;
+  private Set<String> proxyAddresses = null;
+  private long lastUpdate;
+  @VisibleForTesting
+  Map<String, String> proxyUriBases;
+  String[] rmUrls = null;
+
+  @Override
+  public void init(FilterConfig conf) throws ServletException {
+    // Maintain for backwards compatibility
+    if (conf.getInitParameter(PROXY_HOST) != null
+        && conf.getInitParameter(PROXY_URI_BASE) != null) {
+      proxyHosts = new String[]{conf.getInitParameter(PROXY_HOST)};
+      proxyUriBases = new HashMap<>(1);
+      proxyUriBases.put("dummy", conf.getInitParameter(PROXY_URI_BASE));
+    } else {
+      proxyHosts = conf.getInitParameter(PROXY_HOSTS)
+          .split(PROXY_HOSTS_DELIMITER);
+
+      String[] proxyUriBasesArr = conf.getInitParameter(PROXY_URI_BASES)
+          .split(PROXY_URI_BASES_DELIMITER);
+      proxyUriBases = new HashMap<>(proxyUriBasesArr.length);
+      for (String proxyUriBase : proxyUriBasesArr) {
+        try {
+          URL url = new URL(proxyUriBase);
+          proxyUriBases.put(url.getHost() + ":" + url.getPort(), proxyUriBase);
+        } catch (MalformedURLException e) {
+          LOG.warn(proxyUriBase + " does not appear to be a valid URL", e);
+        }
+      }
+    }
+
+    if (conf.getInitParameter(RM_HA_URLS) != null) {
+      rmUrls = conf.getInitParameter(RM_HA_URLS).split(",");
+    }
+  }
+
+  protected Set<String> getProxyAddresses() throws ServletException {
+    long now = Time.monotonicNow();
+    synchronized (this) {
+      if (proxyAddresses == null || (lastUpdate + updateInterval) <= now) {
+        proxyAddresses = new HashSet<>();
+        for (String proxyHost : proxyHosts) {
+          try {
+            for (InetAddress add :
InetAddress.getAllByName(proxyHost)) { + LOG.debug("proxy address is: {}", add.getHostAddress()); + proxyAddresses.add(add.getHostAddress()); + } + lastUpdate = now; + } catch (UnknownHostException e) { + LOG.warn("Could not locate " + proxyHost + " - skipping", e); + } + } + if (proxyAddresses.isEmpty()) { + throw new ServletException("Could not locate any of the proxy hosts"); + } + } + return proxyAddresses; + } + } + + @Override + public void destroy() { + // Empty + } + + @Override + public void doFilter(ServletRequest req, ServletResponse resp, + FilterChain chain) throws IOException, ServletException { + ProxyUtils.rejectNonHttpRequests(req); + + HttpServletRequest httpReq = (HttpServletRequest)req; + HttpServletResponse httpResp = (HttpServletResponse)resp; + + LOG.debug("Remote address for request is: {}", httpReq.getRemoteAddr()); + + if (!getProxyAddresses().contains(httpReq.getRemoteAddr())) { + StringBuilder redirect = new StringBuilder(findRedirectUrl()); + + redirect.append(httpReq.getRequestURI()); + + int insertPoint = redirect.indexOf(PROXY_PATH); + + if (insertPoint >= 0) { + // Add /redirect as the second component of the path so that the RM web + // proxy knows that this request was a redirect. + insertPoint += PROXY_PATH.length(); + redirect.insert(insertPoint, "/redirect"); + } + // add the query parameters on the redirect if there were any + String queryString = httpReq.getQueryString(); + if (queryString != null && !queryString.isEmpty()) { + redirect.append("?"); + redirect.append(queryString); + } + + ProxyUtils.sendRedirect(httpReq, httpResp, redirect.toString()); + } else { + String user = null; + + if (httpReq.getCookies() != null) { + for (Cookie c: httpReq.getCookies()) { + if (PROXY_USER_COOKIE_NAME.equals(c.getName())){ + user = c.getValue(); + break; + } + } + } + if (user == null) { + LOG.debug("Could not find {} cookie, so user will not be set", + PROXY_USER_COOKIE_NAME); + + chain.doFilter(req, resp); + } else { + AmIpPrincipal principal = new AmIpPrincipal(user); + ServletRequest requestWrapper = new AmIpServletRequestWrapper(httpReq, + principal); + + chain.doFilter(requestWrapper, resp); + } + } + } + + @VisibleForTesting + public String findRedirectUrl() throws ServletException { + String addr = null; + if (proxyUriBases.size() == 1) { + // external proxy or not RM HA + addr = proxyUriBases.values().iterator().next(); + } else if (rmUrls != null) { + for (String url : rmUrls) { + String host = proxyUriBases.get(url); + if (isValidUrl(host)) { + addr = host; + break; + } + } + } + + if (addr == null) { + throw new ServletException( + "Could not determine the proxy server for redirection"); + } + return addr; + } + + @VisibleForTesting + public boolean isValidUrl(String url) { + boolean isValid = false; + try { + HttpURLConnection conn = (HttpURLConnection) new URL(url).openConnection(); + conn.connect(); + isValid = conn.getResponseCode() == HttpURLConnection.HTTP_OK; + // If security is enabled, any valid RM which can give 401 Unauthorized is + // good enough to access. Since AM doesn't have enough credential, auth + // cannot be completed and hence 401 is fine in such case. 
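+      // (A secured RM may also answer 403 Forbidden before authentication
+      // completes, so the check below accepts that status as well.)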
+ if (!isValid && UserGroupInformation.isSecurityEnabled()) { + isValid = (conn.getResponseCode() == HttpURLConnection.HTTP_UNAUTHORIZED) + || (conn.getResponseCode() == HttpURLConnection.HTTP_FORBIDDEN); + return isValid; + } + } catch (Exception e) { + LOG.warn("Failed to connect to " + url + ": " + e.toString()); + } + return isValid; + } + + @VisibleForTesting + protected static void setUpdateInterval(long updateInterval) { + AmIpFilter.updateInterval = updateInterval; + } +} diff --git a/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpPrincipal.java b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpPrincipal.java new file mode 100644 index 0000000000000..9d5a5e3b04568 --- /dev/null +++ b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpPrincipal.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn; + +import java.security.Principal; + +// This class is copied from Hadoop 3.4.0 +// org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpPrincipal +public class AmIpPrincipal implements Principal { + private final String name; + + public AmIpPrincipal(String name) { + this.name = name; + } + + @Override + public String getName() { + return name; + } +} diff --git a/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpServletRequestWrapper.java b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpServletRequestWrapper.java new file mode 100644 index 0000000000000..9082378fe89c7 --- /dev/null +++ b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/AmIpServletRequestWrapper.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.deploy.yarn;
+
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletRequestWrapper;
+import java.security.Principal;
+
+// This class is copied from Hadoop 3.4.0
+// org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpServletRequestWrapper
+//
+// Modification:
+// Migrate from javax.servlet to jakarta.servlet
+public class AmIpServletRequestWrapper extends HttpServletRequestWrapper {
+  private final AmIpPrincipal principal;
+
+  public AmIpServletRequestWrapper(HttpServletRequest request,
+      AmIpPrincipal principal) {
+    super(request);
+    this.principal = principal;
+  }
+
+  @Override
+  public Principal getUserPrincipal() {
+    return principal;
+  }
+
+  @Override
+  public String getRemoteUser() {
+    return principal.getName();
+  }
+
+  @Override
+  public boolean isUserInRole(String role) {
+    // No role info so far
+    return false;
+  }
+
+}
diff --git a/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/ProxyUtils.java b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/ProxyUtils.java
new file mode 100644
index 0000000000000..c7a49a76c655f
--- /dev/null
+++ b/resource-managers/yarn/src/main/java/org/apache/spark/deploy/yarn/ProxyUtils.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn;
+
+import org.apache.hadoop.yarn.webapp.MimeType;
+import org.apache.hadoop.yarn.webapp.hamlet2.Hamlet;
+
+import jakarta.servlet.ServletException;
+import jakarta.servlet.ServletRequest;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.EnumSet;
+
+import org.apache.spark.internal.SparkLogger;
+import org.apache.spark.internal.SparkLoggerFactory;
+
+// Class containing general purpose proxy utilities
+//
+// This class is copied from Hadoop 3.4.0
+// org.apache.hadoop.yarn.server.webproxy.ProxyUtils
+//
+// Modification:
+// Migrate from javax.servlet to jakarta.servlet
+public class ProxyUtils {
+  private static final SparkLogger LOG = SparkLoggerFactory.getLogger(ProxyUtils.class);
+  public static final String E_HTTP_HTTPS_ONLY =
+      "This filter only works for HTTP/HTTPS";
+  public static final String LOCATION = "Location";
+
+  public static class __ implements Hamlet.__ {
+    // Empty
+  }
+
+  public static class Page extends Hamlet {
+    Page(PrintWriter out) {
+      super(out, 0, false);
+    }
+
+    public HTML<ProxyUtils.__> html() {
+      return new HTML<>("html", null, EnumSet.of(EOpt.ENDTAG));
+    }
+  }
+
+  /**
+   * Handle redirects with a status code that can in future support verbs other
+   * than GET, thus supporting full REST functionality.
+   * <p>
+   * The target URL is included in the redirect text returned
+   * <p>
+ * At the end of this method, the output stream is closed. + * + * @param request request (hence: the verb and any other information + * relevant to a redirect) + * @param response the response + * @param target the target URL -unencoded + * + */ + public static void sendRedirect(HttpServletRequest request, + HttpServletResponse response, + String target) + throws IOException { + LOG.debug("Redirecting {} {} to {}", + request.getMethod(), + request.getRequestURI(), + target); + String location = response.encodeRedirectURL(target); + response.setStatus(HttpServletResponse.SC_FOUND); + response.setHeader(LOCATION, location); + response.setContentType(MimeType.HTML); + PrintWriter writer = response.getWriter(); + Page p = new Page(writer); + p.html() + .head().title("Moved").__() + .body() + .h1("Moved") + .div() + .__("Content has moved ") + .a(location, "here").__() + .__().__(); + writer.close(); + } + + + /** + * Output 404 with appropriate message. + * @param resp the http response. + * @param message the message to include on the page. + * @throws IOException on any error. + */ + public static void notFound(HttpServletResponse resp, String message) + throws IOException { + resp.setStatus(HttpServletResponse.SC_NOT_FOUND); + resp.setContentType(MimeType.HTML); + Page p = new Page(resp.getWriter()); + p.html().h1(message).__(); + } + + /** + * Reject any request that isn't from an HTTP servlet + * @param req request + * @throws ServletException if the request is of the wrong type + */ + public static void rejectNonHttpRequests(ServletRequest req) throws + ServletException { + if (!(req instanceof HttpServletRequest)) { + throw new ServletException(E_HTTP_HTTPS_ONLY); + } + } +} diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 8f20f6602ec5c..4b5f9be3193f9 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -696,7 +696,7 @@ private[spark] class ApplicationMaster( /** Add the Yarn IP filter that is required for properly securing the UI. */ private def addAmIpFilter(driver: Option[RpcEndpointRef], proxyBase: String) = { - val amFilter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val amFilter = classOf[AmIpFilter].getName val params = client.getAmIpFilterParams(yarnConf, proxyBase) driver match { case Some(d) => diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala new file mode 100644 index 0000000000000..e25bd665dec0d --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/AmIpFilterSuite.scala @@ -0,0 +1,342 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.io.{IOException, PrintWriter, StringWriter} +import java.net.HttpURLConnection +import java.util +import java.util.{Collections, Locale} +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicBoolean + +import scala.jdk.CollectionConverters._ + +import jakarta.servlet.{FilterChain, FilterConfig, ServletContext, ServletException, ServletOutputStream, ServletRequest, ServletResponse} +import jakarta.servlet.http.{Cookie, HttpServlet, HttpServletRequest, HttpServletResponse} +import jakarta.ws.rs.core.MediaType +import org.eclipse.jetty.server.{Server, ServerConnector} +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} +import org.eclipse.jetty.util.thread.QueuedThreadPool +import org.mockito.Mockito.{mock, when} +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkFunSuite + +// A port of org.apache.hadoop.yarn.server.webproxy.amfilter.TestAmFilter +class AmIpFilterSuite extends SparkFunSuite { + + private val proxyHost = "localhost" + private val proxyUri = "http://bogus" + + class TestAmIpFilter extends AmIpFilter { + override def getProxyAddresses: util.Set[String] = Set(proxyHost).asJava + } + + class DummyFilterConfig (val map: util.Map[String, String]) extends FilterConfig { + override def getFilterName: String = "dummy" + + override def getInitParameter(arg0: String): String = map.get(arg0) + + override def getInitParameterNames: util.Enumeration[String] = + Collections.enumeration(map.keySet) + + override def getServletContext: ServletContext = null + } + + test("filterNullCookies") { + val request = mock(classOf[HttpServletRequest]) + + when(request.getCookies).thenReturn(null) + when(request.getRemoteAddr).thenReturn(proxyHost) + + val response = mock(classOf[HttpServletResponse]) + val invoked = new AtomicBoolean + + val chain = new FilterChain() { + @throws[IOException] + @throws[ServletException] + override def doFilter(req: ServletRequest, resp: ServletResponse): Unit = { + invoked.set(true) + } + } + + val params = new util.HashMap[String, String] + params.put(AmIpFilter.PROXY_HOST, proxyHost) + params.put(AmIpFilter.PROXY_URI_BASE, proxyUri) + val conf = new DummyFilterConfig(params) + val filter = new TestAmIpFilter + filter.init(conf) + filter.doFilter(request, response, chain) + assert(invoked.get) + filter.destroy() + } + + test("testFindRedirectUrl") { + class EchoServlet extends HttpServlet { + @throws[IOException] + @throws[ServletException] + override def doGet(request: HttpServletRequest, response: HttpServletResponse): Unit = { + response.setContentType(MediaType.TEXT_PLAIN + "; charset=utf-8") + val out = response.getWriter + request.getParameterNames.asScala.toSeq.sorted.foreach { key => + out.print(key) + out.print(':') + out.print(request.getParameter(key)) + out.print('\n') + } + out.close() + } + } + + def withHttpEchoServer(body: String => Unit): Unit = { + val server = new Server(0) + server.getThreadPool.asInstanceOf[QueuedThreadPool].setMaxThreads(20) + val context = new 
ServletContextHandler
+      context.setContextPath("/foo")
+      server.setHandler(context)
+      val servletPath = "/bar"
+      context.addServlet(new ServletHolder(new EchoServlet), servletPath)
+      server.getConnectors.head.asInstanceOf[ServerConnector].setHost("localhost")
+      try {
+        server.start()
+        body(server.getURI.toString + servletPath)
+      } finally {
+        server.stop()
+      }
+    }
+
+    // generate a valid URL
+    withHttpEchoServer { rm1Url =>
+      val rm1 = "rm1"
+      val rm2 = "rm2"
+      // invalid url
+      val rm2Url = "host2:8088"
+
+      val filter = new TestAmIpFilter
+      // make sure findRedirectUrl() goes to the HA branch
+      filter.proxyUriBases = Map(rm1 -> rm1Url, rm2 -> rm2Url).asJava
+      filter.rmUrls = Array[String](rm1, rm2)
+
+      assert(filter.findRedirectUrl === rm1Url)
+    }
+  }
+
+  test("testProxyUpdate") {
+    var params = new util.HashMap[String, String]
+    params.put(AmIpFilter.PROXY_HOSTS, proxyHost)
+    params.put(AmIpFilter.PROXY_URI_BASES, proxyUri)
+
+    var conf = new DummyFilterConfig(params)
+    val filter = new AmIpFilter
+    val updateInterval = TimeUnit.SECONDS.toMillis(1)
+    AmIpFilter.setUpdateInterval(updateInterval)
+    filter.init(conf)
+
+    // check that the configuration was applied
+    assert(filter.getProxyAddresses.contains("127.0.0.1"))
+
+    // change proxy configurations
+    params = new util.HashMap[String, String]
+    params.put(AmIpFilter.PROXY_HOSTS, "unknownhost")
+    params.put(AmIpFilter.PROXY_URI_BASES, proxyUri)
+    conf = new DummyFilterConfig(params)
+    filter.init(conf)
+
+    // configurations shouldn't be updated yet, since the refresh interval
+    // has not elapsed
+    assert(!filter.getProxyAddresses.isEmpty)
+    // wait for the configuration update; once "unknownhost" fails to resolve,
+    // getProxyAddresses throws
+    eventually(timeout(5.seconds), interval(500.millis)) {
+      assertThrows[ServletException] {
+        filter.getProxyAddresses.isEmpty
+      }
+    }
+  }
+
+  test("testFilter") {
+    var doFilterRequest: String = null
+    var servletWrapper: AmIpServletRequestWrapper = null
+
+    val params = new util.HashMap[String, String]
+    params.put(AmIpFilter.PROXY_HOST, proxyHost)
+    params.put(AmIpFilter.PROXY_URI_BASE, proxyUri)
+    val config = new DummyFilterConfig(params)
+
+    // dummy filter
+    val chain = new FilterChain() {
+      @throws[IOException]
+      @throws[ServletException]
+      override def doFilter(req: ServletRequest, resp: ServletResponse): Unit = {
+        doFilterRequest = req.getClass.getName
+        req match {
+          case wrapper: AmIpServletRequestWrapper => servletWrapper = wrapper
+          case _ =>
+        }
+      }
+    }
+    val testFilter = new AmIpFilter
+    testFilter.init(config)
+
+    val response = new HttpServletResponseForTest
+
+    // a request that does not implement HttpServletRequest is rejected
+    val failRequest = mock(classOf[ServletRequest])
+    val throws = intercept[ServletException] {
+      testFilter.doFilter(failRequest, response, chain)
+    }
+    assert(ProxyUtils.E_HTTP_HTTPS_ONLY === throws.getMessage)
+
+    // request with HttpServletRequest
+    val request = mock(classOf[HttpServletRequest])
+    when(request.getRemoteAddr).thenReturn("nowhere")
+    when(request.getRequestURI).thenReturn("/app/application_00_0")
+
+    // remote address is not in the proxy host list, so a non-proxy URI is
+    // redirected unchanged
+    testFilter.doFilter(request, response, chain)
+    assert(HttpURLConnection.HTTP_MOVED_TEMP === response.status)
+    var redirect = response.getHeader(ProxyUtils.LOCATION)
+    assert("http://bogus/app/application_00_0" === redirect)
+
+    // remote address is not in the proxy host list, so a proxy URI is
+    // redirected with "/redirect" inserted
+    when(request.getRequestURI).thenReturn("/proxy/application_00_0")
+    testFilter.doFilter(request, response, chain)
+    assert(HttpURLConnection.HTTP_MOVED_TEMP === response.status)
+    redirect =
response.getHeader(ProxyUtils.LOCATION)
+    assert("http://bogus/proxy/redirect/application_00_0" === redirect)
+
+    // query parameters are preserved on the redirect
+    when(request.getRequestURI).thenReturn("/proxy/application_00_0")
+    when(request.getQueryString).thenReturn("id=0")
+    testFilter.doFilter(request, response, chain)
+    assert(HttpURLConnection.HTTP_MOVED_TEMP === response.status)
+    redirect = response.getHeader(ProxyUtils.LOCATION)
+    assert("http://bogus/proxy/redirect/application_00_0?id=0" === redirect)
+
+    // "127.0.0.1" is in the proxy host list; no proxy-user cookie yet
+    when(request.getRemoteAddr).thenReturn("127.0.0.1")
+    testFilter.doFilter(request, response, chain)
+    assert(doFilterRequest.contains("HttpServletRequest"))
+
+    // cookie added
+    val cookies = Array[Cookie](new Cookie(AmIpFilter.PROXY_USER_COOKIE_NAME, "user"))
+
+    when(request.getCookies).thenReturn(cookies)
+    testFilter.doFilter(request, response, chain)
+
+    assert(doFilterRequest === classOf[AmIpServletRequestWrapper].getName)
+    // the wrapped request carries the principal from the cookie
+    assert(servletWrapper.getUserPrincipal.getName === "user")
+    assert(servletWrapper.getRemoteUser === "user")
+    assert(!servletWrapper.isUserInRole(""))
+  }
+
+  private class HttpServletResponseForTest extends HttpServletResponse {
+    private var redirectLocation = ""
+    var status = 0
+    private var contentType: String = _
+    final private val headers = new util.HashMap[String, String](1)
+    private var body: StringWriter = _
+
+    def getRedirect: String = redirectLocation
+
+    @throws[IOException]
+    override def sendRedirect(location: String): Unit = redirectLocation = location
+
+    override def setDateHeader(name: String, date: Long): Unit = {}
+
+    override def addDateHeader(name: String, date: Long): Unit = {}
+
+    override def addCookie(cookie: Cookie): Unit = {}
+
+    override def containsHeader(name: String): Boolean = false
+
+    override def encodeURL(url: String): String = null
+
+    override def encodeRedirectURL(url: String): String = url
+
+    override def encodeUrl(url: String): String = null
+
+    override def encodeRedirectUrl(url: String): String = null
+
+    @throws[IOException]
+    override def sendError(sc: Int, msg: String): Unit = {}
+
+    @throws[IOException]
+    override def sendError(sc: Int): Unit = {}
+
+    override def setStatus(status: Int): Unit = this.status = status
+
+    override def setStatus(sc: Int, sm: String): Unit = {}
+
+    override def getStatus: Int = 0
+
+    override def setContentType(contentType: String): Unit = this.contentType = contentType
+
+    override def setBufferSize(size: Int): Unit = {}
+
+    override def getBufferSize: Int = 0
+
+    @throws[IOException]
+    override def flushBuffer(): Unit = {}
+
+    override def resetBuffer(): Unit = {}
+
+    override def isCommitted: Boolean = false
+
+    override def reset(): Unit = {}
+
+    override def setLocale(loc: Locale): Unit = {}
+
+    override def getLocale: Locale = null
+
+    override def setHeader(name: String, value: String): Unit = headers.put(name, value)
+
+    override def addHeader(name: String, value: String): Unit = {}
+
+    override def setIntHeader(name: String, value: Int): Unit = {}
+
+    override def addIntHeader(name: String, value: Int): Unit = {}
+
+    override def getHeader(name: String): String = headers.get(name)
+
+    override def getHeaders(name: String): util.Collection[String] = null
+
+    override def getHeaderNames: util.Collection[String] = null
+
+    override def getCharacterEncoding: String = null
+
+    override def getContentType: String = null
+
+    @throws[IOException]
+    override def getOutputStream:
ServletOutputStream = null + + @throws[IOException] + override def getWriter: PrintWriter = { + body = new StringWriter + new PrintWriter(body) + } + + override def setCharacterEncoding(charset: String): Unit = {} + + override def setContentLength(len: Int): Unit = {} + + override def setContentLengthLong(len: Long): Unit = {} + } + +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index bed048c4b5dfc..6cbc74a75a064 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -86,7 +86,7 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) } // Add Yarn proxy filter specific configurations to the recovered SparkConf - val filter = "org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter" + val filter = "org.apache.spark.deploy.yarn.AmIpFilter" val filterPrefix = s"spark.$filter.param." newReloadConf.getAll.foreach { case (k, v) => if (k.startsWith(filterPrefix) && k.length > filterPrefix.length) {
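
---

Reviewer aside: a minimal, self-contained sketch of how the relocated filter can be
exercised outside YARN, to illustrate the init-parameter contract that
ApplicationMaster.addAmIpFilter relies on. The standalone Jetty wiring and the
example host/URI values are illustrative assumptions, not part of this patch; only
the AmIpFilter constants are real.

    import java.util.EnumSet;

    import jakarta.servlet.DispatcherType;
    import org.eclipse.jetty.server.Server;
    import org.eclipse.jetty.servlet.FilterHolder;
    import org.eclipse.jetty.servlet.ServletContextHandler;

    import org.apache.spark.deploy.yarn.AmIpFilter;

    public class AmIpFilterExample {
      public static void main(String[] args) throws Exception {
        Server server = new Server(8080);
        ServletContextHandler context = new ServletContextHandler();
        context.setContextPath("/");

        // PROXY_HOSTS and PROXY_URI_BASES are parallel, comma-separated lists;
        // requests arriving from any address other than the resolved proxy
        // hosts are redirected to a matching proxy URI base.
        FilterHolder holder = new FilterHolder(new AmIpFilter());
        holder.setInitParameter(AmIpFilter.PROXY_HOSTS,
            "rm1.example.com,rm2.example.com");
        holder.setInitParameter(AmIpFilter.PROXY_URI_BASES,
            "http://rm1.example.com:8088/proxy/app,http://rm2.example.com:8088/proxy/app");
        context.addFilter(holder, "/*", EnumSet.of(DispatcherType.REQUEST));

        server.setHandler(context);
        server.start();
        server.join();
      }
    }

In Spark itself these parameters are not set by hand: client.getAmIpFilterParams
derives them from the YARN configuration and addAmIpFilter applies them to the
driver's UI, exactly as before the class was relocated.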