From 1e80e351d8398293d8c652fe55415502090772f2 Mon Sep 17 00:00:00 2001 From: Tareq Sharafy Date: Fri, 5 Jun 2015 14:39:27 +0300 Subject: [PATCH] Safety fixes in interpretation of X-Forwarded-Host 1- Follow IETF standard in host syntax 2- Correctly detect default HTTP/HTTPs port in the URI Signed-off-by: Tareq Sharafy --- .../core/servlet/ServletContainerRequest.java | 94 +++++++++++-------- .../servlet/ServletContainerRequestTest.java | 27 ++++-- 2 files changed, 73 insertions(+), 48 deletions(-) diff --git a/everrest-core/src/main/java/org/everrest/core/servlet/ServletContainerRequest.java b/everrest-core/src/main/java/org/everrest/core/servlet/ServletContainerRequest.java index d5219287..ba1b5ead 100644 --- a/everrest-core/src/main/java/org/everrest/core/servlet/ServletContainerRequest.java +++ b/everrest-core/src/main/java/org/everrest/core/servlet/ServletContainerRequest.java @@ -10,9 +10,11 @@ *******************************************************************************/ package org.everrest.core.servlet; +import org.everrest.core.ExtHttpHeaders; import org.everrest.core.impl.ContainerRequest; import org.everrest.core.impl.InputHeadersMap; import org.everrest.core.impl.MultivaluedMapImpl; +import org.everrest.core.util.Logger; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.MultivaluedMap; @@ -26,12 +28,13 @@ import java.security.Principal; import java.util.Enumeration; -import static org.everrest.core.ExtHttpHeaders.FORWARDED_HOST; - /** @author andrew00x */ public final class ServletContainerRequest extends ContainerRequest { + + private static final Logger LOG = Logger.getLogger(ServletContainerRequest.class); public static ServletContainerRequest create(final HttpServletRequest req) { + // If the URL is forwarded, obtain the forwarding information final URL forwardedUrl = getForwardedUrl(req); String host; int port; @@ -44,15 +47,31 @@ public static ServletContainerRequest create(final HttpServletRequest req) { if (port < 0) { port = forwardedUrl.getDefaultPort(); } + if (LOG.isInfoEnabled()) { + LOG.info("Assuming forwarded URL: " + forwardedUrl); + } } + // The common URI prefix for both baseUri and requestUri + final StringBuilder commonUriBuilder = new StringBuilder(); final String scheme = getScheme(req); - final StringBuilder baseUriBuilder = uriBuilder(scheme, host, port); + commonUriBuilder.append(scheme); + commonUriBuilder.append("://"); + commonUriBuilder.append(host); + if (!(port < 0 || (port == 80 && "http".equals(scheme)) || (port == 443 && "https".equals(scheme)))) { + commonUriBuilder.append(':'); + commonUriBuilder.append(port); + } + final String commonUriPrefix = commonUriBuilder.toString(); + + // The Base URI - up to the servlet path + final StringBuilder baseUriBuilder = new StringBuilder(commonUriPrefix); baseUriBuilder.append(req.getContextPath()); baseUriBuilder.append(req.getServletPath()); final URI baseUri = URI.create(baseUriBuilder.toString()); - final StringBuilder requestUriBuilder = uriBuilder(scheme, host, port); + // The RequestURI - everything in the URL + final StringBuilder requestUriBuilder = new StringBuilder(commonUriPrefix); requestUriBuilder.append(req.getRequestURI()); final String queryString = req.getQueryString(); if (queryString != null) { @@ -84,51 +103,48 @@ private static String getScheme(HttpServletRequest servletRequest) { return servletRequest.getScheme(); } + /** + * Get the URL that is forwarded using the standard X-Forwarded-Host header. + * + * @param servletRequest + * @return The URL of the forwarded host. If the header is missing or invalid, null is returned. + */ private static URL getForwardedUrl(HttpServletRequest servletRequest) { - final String scheme = getScheme(servletRequest); final String forwardedHostAndPort = servletRequest.getHeader(FORWARDED_HOST); - if (forwardedHostAndPort != null && !forwardedHostAndPort.isEmpty()) { - final String host = getForwardedHost(forwardedHostAndPort); - final int port = getForwardedPort(forwardedHostAndPort); - try { - // Use the standard URI to verify the host details - return new URI(scheme, null, host, port, null, null, null).toURL(); - } catch (URISyntaxException | MalformedURLException e) { - return null; - } + if (forwardedHostAndPort == null || forwardedHostAndPort.isEmpty()) { + return null; } - return null; - } - - public static String getForwardedHost(String forwardedHostAndPort) { - final int colonIndex = forwardedHostAndPort.indexOf(':'); - if (colonIndex < 0) { - return forwardedHostAndPort; + URL url = parseForwardedHostHeader(forwardedHostAndPort, servletRequest); + if (url == null && LOG.isWarnEnabled()) { + LOG.warn("Ignoring invalid " + ExtHttpHeaders.FORWARDED_HOST + ": " + forwardedHostAndPort); } - return forwardedHostAndPort.substring(0, colonIndex); + return url; } - public static int getForwardedPort(String forwardedHostAndPort) { - final int colonIndex = forwardedHostAndPort.indexOf(':'); - if (colonIndex >= 0) { + /** Parse according to IETF standard for Host field: http://tools.ietf.org/html/rfc7230#section-5.4 */ + private static URL parseForwardedHostHeader(String forwardedHostAndPort, HttpServletRequest servletRequest) { + final String[] parts = forwardedHostAndPort.split(":"); + if (parts.length > 2) { + return null; + } + int fwdPort = -1; + if (parts.length == 2) { try { - return Integer.parseInt(forwardedHostAndPort.substring(colonIndex + 1, forwardedHostAndPort.length())); - } catch (NumberFormatException ignored) { + fwdPort = Integer.parseInt(parts[1]); + } catch (NumberFormatException e) { + return null; + } + if (fwdPort < 0) { + return null; } } - return -1; - } - - private static StringBuilder uriBuilder(String scheme, String host, int port) { - final StringBuilder uriBuilder = new StringBuilder(); - uriBuilder.append(scheme); - uriBuilder.append("://"); - uriBuilder.append(host); - if (!(port == 80 || (port == 443 && "https".equals(scheme)))) { - uriBuilder.append(':'); - uriBuilder.append(port); + final String fwdHost = parts[0]; + final String scheme = getScheme(servletRequest); + try { + return new URI(scheme, null, fwdHost, fwdPort, null, null, null).toURL(); + } catch (URISyntaxException | MalformedURLException e) { } - return uriBuilder; + return null; } /** diff --git a/everrest-core/src/test/java/org/everrest/core/servlet/ServletContainerRequestTest.java b/everrest-core/src/test/java/org/everrest/core/servlet/ServletContainerRequestTest.java index ebdc058c..10399c59 100644 --- a/everrest-core/src/test/java/org/everrest/core/servlet/ServletContainerRequestTest.java +++ b/everrest-core/src/test/java/org/everrest/core/servlet/ServletContainerRequestTest.java @@ -41,7 +41,7 @@ public class ServletContainerRequestTest { private static final String TEST_BASE_PATH = TEST_CONTEXT_PATH + TEST_SERVLET_PATH; private static final String TEST_FULL_PATH = TEST_BASE_PATH + TEST_SUBPATH; - + private static final String TEST_BASE_URI = TEST_SCHEME + TEST_HOST + TEST_BASE_PATH; private static final String TEST_REQUEST_URI = TEST_BASE_URI + TEST_SUBPATH; @@ -92,18 +92,27 @@ public String getPathInfo() { @Test public void testSimpleRequest() { - // A simple HTTP request - MockHttpServletRequest httpReq = new MockEmptyBodyHttpRequest(null, null); - ServletContainerRequest req = ServletContainerRequest.create(httpReq); - // Validate the fields - assertEquals(TEST_BASE_URI, req.getBaseUri().toString()); - assertEquals(TEST_REQUEST_URI, req.getRequestUri().toString()); + assertIgnoredForwardedHost(null); + } + + @Test + public void testInvalidForwardedHost1() { + assertIgnoredForwardedHost("a b c"); } @Test - public void testInvalidForwardedHost() { + public void testInvalidForwardedHost2() { + assertIgnoredForwardedHost("myhost.com:8877:200"); + } + + @Test + public void testInvalidForwardedHost3() { + assertIgnoredForwardedHost("myhost..com"); + } + + private static void assertIgnoredForwardedHost(String forwardedHostHeader) { // A simple HTTP request - MockHttpServletRequest httpReq = new MockEmptyBodyHttpRequest("a b c", null); + MockHttpServletRequest httpReq = new MockEmptyBodyHttpRequest(forwardedHostHeader, null); ServletContainerRequest req = ServletContainerRequest.create(httpReq); // Validate the fields assertEquals(TEST_BASE_URI, req.getBaseUri().toString());