Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safety fixes in interpretation of X-Forwarded-Host #6

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
*******************************************************************************/
package org.everrest.core.servlet;

import org.everrest.core.ExtHttpHeaders;
import org.everrest.core.impl.ContainerRequest;
import org.everrest.core.impl.InputHeadersMap;
import org.everrest.core.impl.MultivaluedMapImpl;
import org.everrest.core.util.Logger;

import javax.servlet.http.HttpServletRequest;
import javax.ws.rs.core.MultivaluedMap;
Expand All @@ -26,12 +28,13 @@
import java.security.Principal;
import java.util.Enumeration;

import static org.everrest.core.ExtHttpHeaders.FORWARDED_HOST;

/** @author andrew00x */
public final class ServletContainerRequest extends ContainerRequest {

private static final Logger LOG = Logger.getLogger(ServletContainerRequest.class);

public static ServletContainerRequest create(final HttpServletRequest req) {
// If the URL is forwarded, obtain the forwarding information
final URL forwardedUrl = getForwardedUrl(req);
String host;
int port;
Expand All @@ -44,15 +47,31 @@ public static ServletContainerRequest create(final HttpServletRequest req) {
if (port < 0) {
port = forwardedUrl.getDefaultPort();
}
if (LOG.isInfoEnabled()) {
LOG.info("Assuming forwarded URL: " + forwardedUrl);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can replace this block with

LOG.debug("Assuming forwarded URL: {}" , forwardedUrl);

}
}

// The common URI prefix for both baseUri and requestUri
final StringBuilder commonUriBuilder = new StringBuilder();
final String scheme = getScheme(req);
final StringBuilder baseUriBuilder = uriBuilder(scheme, host, port);
commonUriBuilder.append(scheme);
commonUriBuilder.append("://");
commonUriBuilder.append(host);
if (!(port < 0 || (port == 80 && "http".equals(scheme)) || (port == 443 && "https".equals(scheme)))) {
commonUriBuilder.append(':');
commonUriBuilder.append(port);
}
final String commonUriPrefix = commonUriBuilder.toString();

// The Base URI - up to the servlet path
final StringBuilder baseUriBuilder = new StringBuilder(commonUriPrefix);
baseUriBuilder.append(req.getContextPath());
baseUriBuilder.append(req.getServletPath());
final URI baseUri = URI.create(baseUriBuilder.toString());

final StringBuilder requestUriBuilder = uriBuilder(scheme, host, port);
// The RequestURI - everything in the URL
final StringBuilder requestUriBuilder = new StringBuilder(commonUriPrefix);
requestUriBuilder.append(req.getRequestURI());
final String queryString = req.getQueryString();
if (queryString != null) {
Expand Down Expand Up @@ -84,51 +103,48 @@ private static String getScheme(HttpServletRequest servletRequest) {
return servletRequest.getScheme();
}

/**
* Get the URL that is forwarded using the standard X-Forwarded-Host header.
*
* @param servletRequest
* @return The URL of the forwarded host. If the header is missing or invalid, null is returned.
*/
private static URL getForwardedUrl(HttpServletRequest servletRequest) {
final String scheme = getScheme(servletRequest);
final String forwardedHostAndPort = servletRequest.getHeader(FORWARDED_HOST);
if (forwardedHostAndPort != null && !forwardedHostAndPort.isEmpty()) {
final String host = getForwardedHost(forwardedHostAndPort);
final int port = getForwardedPort(forwardedHostAndPort);
try {
// Use the standard URI to verify the host details
return new URI(scheme, null, host, port, null, null, null).toURL();
} catch (URISyntaxException | MalformedURLException e) {
return null;
}
if (forwardedHostAndPort == null || forwardedHostAndPort.isEmpty()) {
return null;
}
return null;
}

public static String getForwardedHost(String forwardedHostAndPort) {
final int colonIndex = forwardedHostAndPort.indexOf(':');
if (colonIndex < 0) {
return forwardedHostAndPort;
URL url = parseForwardedHostHeader(forwardedHostAndPort, servletRequest);
if (url == null && LOG.isWarnEnabled()) {
LOG.warn("Ignoring invalid " + ExtHttpHeaders.FORWARDED_HOST + ": " + forwardedHostAndPort);
}
return forwardedHostAndPort.substring(0, colonIndex);
return url;
}

public static int getForwardedPort(String forwardedHostAndPort) {
final int colonIndex = forwardedHostAndPort.indexOf(':');
if (colonIndex >= 0) {
/** Parse according to IETF standard for Host field: http://tools.ietf.org/html/rfc7230#section-5.4 */
private static URL parseForwardedHostHeader(String forwardedHostAndPort, HttpServletRequest servletRequest) {
final String[] parts = forwardedHostAndPort.split(":");
if (parts.length > 2) {
return null;
}
int fwdPort = -1;
if (parts.length == 2) {
try {
return Integer.parseInt(forwardedHostAndPort.substring(colonIndex + 1, forwardedHostAndPort.length()));
} catch (NumberFormatException ignored) {
fwdPort = Integer.parseInt(parts[1]);
} catch (NumberFormatException e) {
return null;
}
if (fwdPort < 0) {
return null;
}
}
return -1;
}

private static StringBuilder uriBuilder(String scheme, String host, int port) {
final StringBuilder uriBuilder = new StringBuilder();
uriBuilder.append(scheme);
uriBuilder.append("://");
uriBuilder.append(host);
if (!(port == 80 || (port == 443 && "https".equals(scheme)))) {
uriBuilder.append(':');
uriBuilder.append(port);
final String fwdHost = parts[0];
final String scheme = getScheme(servletRequest);
try {
return new URI(scheme, null, fwdHost, fwdPort, null, null, null).toURL();
} catch (URISyntaxException | MalformedURLException e) {
}
return uriBuilder;
return null;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class ServletContainerRequestTest {

private static final String TEST_BASE_PATH = TEST_CONTEXT_PATH + TEST_SERVLET_PATH;
private static final String TEST_FULL_PATH = TEST_BASE_PATH + TEST_SUBPATH;

private static final String TEST_BASE_URI = TEST_SCHEME + TEST_HOST + TEST_BASE_PATH;
private static final String TEST_REQUEST_URI = TEST_BASE_URI + TEST_SUBPATH;

Expand Down Expand Up @@ -92,18 +92,27 @@ public String getPathInfo() {

@Test
public void testSimpleRequest() {
// A simple HTTP request
MockHttpServletRequest httpReq = new MockEmptyBodyHttpRequest(null, null);
ServletContainerRequest req = ServletContainerRequest.create(httpReq);
// Validate the fields
assertEquals(TEST_BASE_URI, req.getBaseUri().toString());
assertEquals(TEST_REQUEST_URI, req.getRequestUri().toString());
assertIgnoredForwardedHost(null);
}

@Test
public void testInvalidForwardedHost1() {
assertIgnoredForwardedHost("a b c");
}

@Test
public void testInvalidForwardedHost() {
public void testInvalidForwardedHost2() {
assertIgnoredForwardedHost("myhost.com:8877:200");
}

@Test
public void testInvalidForwardedHost3() {
assertIgnoredForwardedHost("myhost..com");
}

private static void assertIgnoredForwardedHost(String forwardedHostHeader) {
// A simple HTTP request
MockHttpServletRequest httpReq = new MockEmptyBodyHttpRequest("a b c", null);
MockHttpServletRequest httpReq = new MockEmptyBodyHttpRequest(forwardedHostHeader, null);
ServletContainerRequest req = ServletContainerRequest.create(httpReq);
// Validate the fields
assertEquals(TEST_BASE_URI, req.getBaseUri().toString());
Expand Down