Skip to content

Commit

Permalink
refactor(headers): Refactor header propagation logic
Browse files Browse the repository at this point in the history
This also adds ability to propagate any `X-SPINNAKER-*` header automatically
And, ability to notify with a metric and a log when a request doesn't have
headers needed to successfully pass identity of the caller
(i.e. `X-SPINNAKER-USER` and `X-SPINNAKER-ACCOUNTS`)

symmetric PR in orca
  • Loading branch information
marchello2000 committed Apr 30, 2019
1 parent da4a880 commit 54bd383
Show file tree
Hide file tree
Showing 6 changed files with 306 additions and 164 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,64 @@
import org.springframework.util.CollectionUtils;

public class AuthenticatedRequest {
public static final String SPINNAKER_USER = "X-SPINNAKER-USER";
public static final String SPINNAKER_ACCOUNTS = "X-SPINNAKER-ACCOUNTS";
public static final String SPINNAKER_USER_ORIGIN = "X-SPINNAKER-USER-ORIGIN";
public static final String SPINNAKER_REQUEST_ID = "X-SPINNAKER-REQUEST-ID";
public static final String SPINNAKER_EXECUTION_ID = "X-SPINNAKER-EXECUTION-ID";
/**
* Known X-SPINNAKER headers, but any X-SPINNAKER-* key in the MDC will be automatically
* propagated to the HTTP headers.
*
* <p>Use makeCustomerHeader() to add customer headers
*/
public enum Header {
USER("X-SPINNAKER-USER", true),
ACCOUNTS("X-SPINNAKER-ACCOUNTS", true),
USER_ORIGIN("X-SPINNAKER-USER-ORIGIN", false),
REQUEST_ID("X-SPINNAKER-REQUEST-ID", false),
EXECUTION_ID("X-SPINNAKER-EXECUTION-ID", false),
APPLICATION("X-SPINNAKER-APPLICATION", false);

private String header;
private boolean isRequired;

Header(String header, boolean isRequired) {
this.header = header;
this.isRequired = isRequired;
}

public String getHeader() {
return header;
}

public boolean isRequired() {
return isRequired;
}

public static String XSpinnakerPrefix = "X-SPINNAKER-";
public static String XSpinnakerAnonymous = XSpinnakerPrefix + "ANONYMOUS";

public static String makeCustomHeader(String header) {
return XSpinnakerPrefix + header.toUpperCase();
}
}

/**
* Allow a given HTTP call to be anonymous. Normally, all requests to Spinnaker services should be
* authenticated (i.e. include USER & ACCOUNTS HTTP headers). However, in specific cases it is
* necessary to make an anonymous call. If an anonymous call is made that is not wrapped in this
* method, it will result in a log message and a metric being logged (indicating a potential bug).
* Use this method to avoid the log and metric. To make an anonymous call wrap it in this
* function, e.g.
*
* <p>AuthenticatedRequest.allowAnonymous(() -> { // do HTTP call here });
*/
public static <V> V allowAnonymous(Callable<V> closure) throws Exception {
String originalValue = MDC.get(Header.XSpinnakerAnonymous);
MDC.put(Header.XSpinnakerAnonymous, "anonymous");

try {
return closure.call();
} finally {
setOrRemoveMdc(Header.XSpinnakerAnonymous, originalValue);
}
}

public static <V> Callable<V> propagate(Callable<V> closure) {
return propagate(closure, true, principal());
Expand All @@ -55,43 +108,20 @@ public static <V> Callable<V> propagate(
String executionId = getSpinnakerExecutionId().orElse(null);
String requestId = getSpinnakerRequestId().orElse(null);
String spinnakerAccounts = getSpinnakerAccounts(principal).orElse(null);
String spinnakerApp = getSpinnakerApplication().orElse(null);

return () -> {
String originalSpinnakerUser = MDC.get(SPINNAKER_USER);
String originalSpinnakerUserOrigin = MDC.get(SPINNAKER_USER_ORIGIN);
String originalSpinnakerAccounts = MDC.get(SPINNAKER_ACCOUNTS);
String originalSpinnakerRequestId = MDC.get(SPINNAKER_REQUEST_ID);
String originalSpinnakerExecutionId = MDC.get(SPINNAKER_EXECUTION_ID);
try {
if (spinnakerUser != null) {
MDC.put(SPINNAKER_USER, spinnakerUser);
} else {
MDC.remove(SPINNAKER_USER);
}

if (userOrigin != null) {
MDC.put(SPINNAKER_USER_ORIGIN, userOrigin);
} else {
MDC.remove(SPINNAKER_USER_ORIGIN);
}
// Deal with (set/reset) known X-SPINNAKER headers, all others will just stick around
Map originalMdc = MDC.getCopyOfContextMap();

if (spinnakerAccounts != null) {
MDC.put(SPINNAKER_ACCOUNTS, spinnakerAccounts);
} else {
MDC.remove(SPINNAKER_ACCOUNTS);
}

if (executionId != null) {
MDC.put(SPINNAKER_EXECUTION_ID, executionId);
} else {
MDC.remove(SPINNAKER_EXECUTION_ID);
}
try {
setOrRemoveMdc(Header.USER.getHeader(), spinnakerUser);
setOrRemoveMdc(Header.USER_ORIGIN.getHeader(), userOrigin);
setOrRemoveMdc(Header.ACCOUNTS.getHeader(), spinnakerAccounts);
setOrRemoveMdc(Header.REQUEST_ID.getHeader(), executionId);
setOrRemoveMdc(Header.EXECUTION_ID.getHeader(), requestId);
setOrRemoveMdc(Header.APPLICATION.getHeader(), spinnakerApp);

if (requestId != null) {
MDC.put(SPINNAKER_REQUEST_ID, requestId);
} else {
MDC.remove(SPINNAKER_REQUEST_ID);
}
return closure.call();
} finally {
MDC.clear();
Expand All @@ -103,38 +133,43 @@ public static <V> Callable<V> propagate(
} catch (Exception ignored) {
}

if (restoreOriginalContext) {
if (originalSpinnakerUser != null) {
MDC.put(SPINNAKER_USER, originalSpinnakerUser);
}

if (originalSpinnakerUserOrigin != null) {
MDC.put(SPINNAKER_USER_ORIGIN, originalSpinnakerUserOrigin);
}

if (originalSpinnakerAccounts != null) {
MDC.put(SPINNAKER_ACCOUNTS, originalSpinnakerAccounts);
}

if (originalSpinnakerRequestId != null) {
MDC.put(SPINNAKER_REQUEST_ID, originalSpinnakerRequestId);
}

if (originalSpinnakerExecutionId != null) {
MDC.put(SPINNAKER_EXECUTION_ID, originalSpinnakerExecutionId);
}
if (restoreOriginalContext && originalMdc != null) {
MDC.setContextMap(originalMdc);
}
}
};
}

private static void setOrRemoveMdc(String key, String value) {
if (value != null) {
MDC.put(key, value);
} else {
MDC.remove(key);
}
}

public static Map<String, Optional<String>> getAuthenticationHeaders() {
Map<String, Optional<String>> headers = new HashMap<>();
headers.put(SPINNAKER_USER, getSpinnakerUser());
headers.put(SPINNAKER_ACCOUNTS, getSpinnakerAccounts());
headers.put(SPINNAKER_USER_ORIGIN, getSpinnakerUserOrigin());
headers.put(SPINNAKER_REQUEST_ID, getSpinnakerRequestId());
headers.put(SPINNAKER_EXECUTION_ID, getSpinnakerExecutionId());
headers.put(Header.USER.getHeader(), getSpinnakerUser());
headers.put(Header.ACCOUNTS.getHeader(), getSpinnakerAccounts());

// Copy all headers that look like X-SPINNAKER*
Map<String, String> allMdcEntries = MDC.getCopyOfContextMap();

for (Map.Entry<String, String> mdcEntry : allMdcEntries.entrySet()) {
String header = mdcEntry.getKey();

boolean isSpinnakerHeader =
header.toLowerCase().startsWith(Header.XSpinnakerPrefix.toLowerCase());
boolean isSpinnakerAuthHeader =
Header.USER.getHeader().equalsIgnoreCase(header)
|| Header.ACCOUNTS.getHeader().equalsIgnoreCase(header);

if (isSpinnakerHeader && !isSpinnakerAuthHeader) {
headers.put(header, Optional.of(mdcEntry.getValue()));
}
}

return headers;
}

Expand All @@ -143,7 +178,7 @@ public static Optional<String> getSpinnakerUser() {
}

public static Optional<String> getSpinnakerUser(Object principal) {
Object spinnakerUser = MDC.get(SPINNAKER_USER);
Object spinnakerUser = MDC.get(Header.USER.getHeader());

if (principal != null && principal instanceof User) {
spinnakerUser = ((User) principal).getUsername();
Expand All @@ -157,7 +192,7 @@ public static Optional<String> getSpinnakerAccounts() {
}

public static Optional<String> getSpinnakerAccounts(Object principal) {
Object spinnakerAccounts = MDC.get(SPINNAKER_ACCOUNTS);
Object spinnakerAccounts = MDC.get(Header.ACCOUNTS.getHeader());

if (principal instanceof User && !CollectionUtils.isEmpty(((User) principal).allowedAccounts)) {
spinnakerAccounts = String.join(",", ((User) principal).getAllowedAccounts());
Expand All @@ -166,10 +201,6 @@ public static Optional<String> getSpinnakerAccounts(Object principal) {
return Optional.ofNullable((String) spinnakerAccounts);
}

public static Optional<String> getSpinnakerUserOrigin() {
return Optional.ofNullable(MDC.get(SPINNAKER_USER_ORIGIN));
}

/**
* Returns or creates a spinnaker request ID.
*
Expand All @@ -182,15 +213,23 @@ public static Optional<String> getSpinnakerUserOrigin() {
*/
public static Optional<String> getSpinnakerRequestId() {
return Optional.of(
Optional.ofNullable(MDC.get(SPINNAKER_REQUEST_ID))
Optional.ofNullable(MDC.get(Header.REQUEST_ID.getHeader()))
.orElse(
getSpinnakerExecutionId()
.map(id -> format("%s:%s", id, UUID.randomUUID().toString()))
.orElse(UUID.randomUUID().toString())));
}

public static Optional<String> getSpinnakerUserOrigin() {
return Optional.ofNullable(MDC.get(Header.USER_ORIGIN.getHeader()));
}

public static Optional<String> getSpinnakerExecutionId() {
return Optional.ofNullable(MDC.get(SPINNAKER_EXECUTION_ID));
return Optional.ofNullable(MDC.get(Header.EXECUTION_ID.getHeader()));
}

private static Optional<String> getSpinnakerApplication() {
return Optional.ofNullable(MDC.get(Header.APPLICATION.getHeader()));
}

/** @return the Spring Security principal or null if there is no authority. */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class AuthenticatedRequestSpec extends Specification {
void "should extract user details by priority (Principal > MDC)"() {
when:
MDC.clear()
MDC.put(AuthenticatedRequest.SPINNAKER_USER, "spinnaker-user")
MDC.put(AuthenticatedRequest.Header.USER.header, "spinnaker-user")
then:
AuthenticatedRequest.getSpinnakerUser().get() == "spinnaker-user"
Expand All @@ -33,7 +33,7 @@ class AuthenticatedRequestSpec extends Specification {
void "should extract allowed account details by priority (Principal > MDC"() {
when:
MDC.clear()
MDC.put(AuthenticatedRequest.SPINNAKER_ACCOUNTS, "account1,account2")
MDC.put(AuthenticatedRequest.Header.ACCOUNTS.header, "account1,account2")
then:
AuthenticatedRequest.getSpinnakerAccounts().get() == "account1,account2"
Expand All @@ -51,23 +51,23 @@ class AuthenticatedRequestSpec extends Specification {
void "should propagate user/allowed account details"() {
when:
MDC.put(AuthenticatedRequest.SPINNAKER_USER, "spinnaker-user")
MDC.put(AuthenticatedRequest.SPINNAKER_ACCOUNTS, "account1,account2")
MDC.put(AuthenticatedRequest.Header.USER.header, "spinnaker-user")
MDC.put(AuthenticatedRequest.Header.ACCOUNTS.header, "account1,account2")
def closure = AuthenticatedRequest.propagate({
assert AuthenticatedRequest.getSpinnakerUser().get() == "spinnaker-user"
assert AuthenticatedRequest.getSpinnakerAccounts().get() == "account1,account2"
return true
})
MDC.put(AuthenticatedRequest.SPINNAKER_USER, "spinnaker-another-user")
MDC.put(AuthenticatedRequest.SPINNAKER_ACCOUNTS, "account1,account3")
MDC.put(AuthenticatedRequest.Header.USER.header, "spinnaker-another-user")
MDC.put(AuthenticatedRequest.Header.ACCOUNTS.header, "account1,account3")
closure.call()
then:
// ensure MDC context is restored
MDC.get(AuthenticatedRequest.SPINNAKER_USER) == "spinnaker-another-user"
MDC.get(AuthenticatedRequest.SPINNAKER_ACCOUNTS) == "account1,account3"
MDC.get(AuthenticatedRequest.Header.USER.header) == "spinnaker-another-user"
MDC.get(AuthenticatedRequest.Header.ACCOUNTS.header) == "account1,account3"
when:
MDC.clear()
Expand All @@ -76,4 +76,17 @@ class AuthenticatedRequestSpec extends Specification {
closure.call()
MDC.clear()
}
void "should propagate headers"() {
when:
MDC.clear()
MDC.put(AuthenticatedRequest.Header.USER.header, "spinnaker-another-user")
MDC.put(AuthenticatedRequest.Header.makeCustomHeader("cloudprovider"), "aws")
then:
AuthenticatedRequest.getAuthenticationHeaders() == [
'X-SPINNAKER-USER': Optional.of("spinnaker-another-user"),
'X-SPINNAKER-ACCOUNTS': Optional.empty(),
'X-SPINNAKER-CLOUDPROVIDER': Optional.of("aws")]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.netflix.spinnaker.filters


import com.netflix.spinnaker.security.User
import groovy.util.logging.Slf4j
import org.slf4j.MDC
Expand Down Expand Up @@ -69,9 +70,7 @@ class AuthenticatedRequestFilter implements Filter {
void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
def spinnakerUser = null
def spinnakerAccounts = null
def spinnakerUserOrigin = null
def spinnakerRequestId = null
def spinnakerExecutionId = null
HashMap<String, String> otherSpinnakerHeaders = new HashMap<>()

try {
def session = ((HttpServletRequest) request).getSession(false)
Expand All @@ -90,17 +89,28 @@ class AuthenticatedRequestFilter implements Filter {

if (extractSpinnakerHeaders) {
def httpServletRequest = (HttpServletRequest) request
spinnakerUser = spinnakerUser ?: httpServletRequest.getHeader(SPINNAKER_USER)
spinnakerAccounts = spinnakerAccounts ?: httpServletRequest.getHeader(SPINNAKER_ACCOUNTS)
spinnakerUserOrigin = httpServletRequest.getHeader(SPINNAKER_USER_ORIGIN)
spinnakerRequestId = httpServletRequest.getHeader(SPINNAKER_REQUEST_ID)
spinnakerExecutionId = httpServletRequest.getHeader(SPINNAKER_EXECUTION_ID)
spinnakerUser = spinnakerUser ?: httpServletRequest.getHeader(Header.USER.getHeader())
spinnakerAccounts = spinnakerAccounts ?: httpServletRequest.getHeader(Header.ACCOUNTS.getHeader())

Enumeration<String> headers = httpServletRequest.getHeaderNames()

for (header in headers) {
if (header.startsWith(Header.XSpinnakerPrefix)) {
otherSpinnakerHeaders.put(header, httpServletRequest.getHeader(header))
}
}
}
if (extractSpinnakerUserOriginHeader) {
spinnakerUserOrigin = "deck".equalsIgnoreCase(((HttpServletRequest) request).getHeader("X-RateLimit-App")) ? "deck" : "api"
otherSpinnakerHeaders.put(
Header.USER_ORIGIN.getHeader(),
"deck".equalsIgnoreCase(((HttpServletRequest) request).getHeader("X-RateLimit-App")) ? "deck" : "api"
)
}
if (forceNewSpinnakerRequestId) {
spinnakerRequestId = UUID.randomUUID().toString()
otherSpinnakerHeaders.put(
Header.REQUEST_ID.getHeader(),
UUID.randomUUID().toString()
)
}

// only extract from the x509 certificate if `spinnakerUser` has not been supplied as a header
Expand All @@ -116,19 +126,13 @@ class AuthenticatedRequestFilter implements Filter {

try {
if (spinnakerUser) {
MDC.put(SPINNAKER_USER, spinnakerUser)
MDC.put(Header.USER.getHeader(), spinnakerUser)
}
if (spinnakerAccounts) {
MDC.put(SPINNAKER_ACCOUNTS, spinnakerAccounts)
}
if (spinnakerUserOrigin) {
MDC.put(SPINNAKER_USER_ORIGIN, spinnakerUserOrigin)
}
if (spinnakerRequestId) {
MDC.put(SPINNAKER_REQUEST_ID, spinnakerRequestId)
MDC.put(Header.ACCOUNTS.getHeader(), spinnakerAccounts)
}
if (spinnakerExecutionId) {
MDC.put(SPINNAKER_EXECUTION_ID, spinnakerExecutionId)
for (header in otherSpinnakerHeaders) {
MDC.put(header.key, header.value)
}

chain.doFilter(request, response)
Expand Down
Loading

0 comments on commit 54bd383

Please sign in to comment.