Skip to content

Commit

Permalink
[fix][broker] Continue using the next provider for authentication if …
Browse files Browse the repository at this point in the history
…one fails (#23797)

Signed-off-by: Zixuan Liu <nodeces@gmail.com>
  • Loading branch information
nodece authored Jan 2, 2025
1 parent 9850605 commit 7619e2f
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class AuthenticationProviderList implements AuthenticationProvider {

private interface AuthProcessor<T, W> {

T apply(W process) throws AuthenticationException;
T apply(W process) throws Exception;

}

Expand All @@ -51,21 +51,30 @@ private enum ErrorCode {
AUTH_REQUIRED,
}

private static AuthenticationException newAuthenticationException(String message, Exception e) {
AuthenticationException authenticationException = new AuthenticationException(message);
if (e != null) {
authenticationException.initCause(e);
}
return authenticationException;
}

private static <T, W> T applyAuthProcessor(List<W> processors, AuthenticationMetrics metrics,
AuthProcessor<T, W> authFunc)
throws AuthenticationException {
AuthenticationException authenticationException = null;
Exception authenticationException = null;
String errorCode = ErrorCode.UNKNOWN.name();
for (W ap : processors) {
try {
return authFunc.apply(ap);
} catch (AuthenticationException ae) {
} catch (Exception ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + ap.getClass() + ": ", ae);
}
// Store the exception so we can throw it later instead of a generic one
authenticationException = ae;
errorCode = ap.getClass().getSimpleName() + "-INVALID-AUTH";
if (ae instanceof AuthenticationException) {
errorCode = ap.getClass().getSimpleName() + "-INVALID-AUTH";
}
}
}

Expand All @@ -76,7 +85,7 @@ private static <T, W> T applyAuthProcessor(List<W> processors, AuthenticationMet
} else {
metrics.recordFailure(AuthenticationProviderList.class.getSimpleName(),
"authentication-provider-list", errorCode);
throw authenticationException;
throw newAuthenticationException("Authentication failed", authenticationException);
}
}

Expand Down Expand Up @@ -290,12 +299,12 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA
throws AuthenticationException {
final List<AuthenticationState> states = new ArrayList<>(providers.size());

AuthenticationException authenticationException = null;
Exception authenticationException = null;
for (AuthenticationProvider provider : providers) {
try {
AuthenticationState state = provider.newAuthState(authData, remoteAddress, sslSession);
states.add(state);
} catch (AuthenticationException ae) {
} catch (Exception ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae);
}
Expand All @@ -305,11 +314,8 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA
}
if (states.isEmpty()) {
log.debug("Failed to initialize a new auth state from {}", remoteAddress, authenticationException);
if (authenticationException != null) {
throw authenticationException;
} else {
throw new AuthenticationException("Failed to initialize a new auth state from " + remoteAddress);
}
throw newAuthenticationException("Failed to initialize a new auth state from " + remoteAddress,
authenticationException);
} else {
return new AuthenticationListState(states, authenticationMetrics);
}
Expand All @@ -319,12 +325,12 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA
public AuthenticationState newHttpAuthState(HttpServletRequest request) throws AuthenticationException {
final List<AuthenticationState> states = new ArrayList<>(providers.size());

AuthenticationException authenticationException = null;
Exception authenticationException = null;
for (AuthenticationProvider provider : providers) {
try {
AuthenticationState state = provider.newHttpAuthState(request);
states.add(state);
} catch (AuthenticationException ae) {
} catch (Exception ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae);
}
Expand All @@ -335,35 +341,21 @@ public AuthenticationState newHttpAuthState(HttpServletRequest request) throws A
if (states.isEmpty()) {
log.debug("Failed to initialize a new http auth state from {}",
request.getRemoteHost(), authenticationException);
if (authenticationException != null) {
throw authenticationException;
} else {
throw new AuthenticationException(
"Failed to initialize a new http auth state from " + request.getRemoteHost());
}
throw newAuthenticationException(
"Failed to initialize a new http auth state from " + request.getRemoteHost(),
authenticationException);
} else {
return new AuthenticationListState(states, authenticationMetrics);
}
}

@Override
public boolean authenticateHttpRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
Boolean authenticated = applyAuthProcessor(
return applyAuthProcessor(
providers,
authenticationMetrics,
provider -> {
try {
return provider.authenticateHttpRequest(request, response);
} catch (Exception e) {
if (e instanceof AuthenticationException) {
throw (AuthenticationException) e;
} else {
throw new AuthenticationException("Failed to authentication http request");
}
}
}
provider -> provider.authenticateHttpRequest(request, response)
);
return authenticated;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedDataAttributeName;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedRoleAttributeName;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
Expand All @@ -35,13 +37,17 @@
import java.security.KeyPair;
import java.security.PrivateKey;
import java.util.Date;
import java.util.List;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import javax.naming.AuthenticationException;
import javax.servlet.http.HttpServletRequest;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.authentication.utils.AuthTokenUtils;
import org.apache.pulsar.common.api.AuthData;
import org.apache.pulsar.common.util.FutureUtil;
import org.assertj.core.util.Lists;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
Expand Down Expand Up @@ -260,4 +266,125 @@ public void testAuthenticateHttpRequest() throws Exception {
verify(requestBB).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));
}

}
@Test
public void testAuthenticateWithMultipleProviders() throws Exception {
HttpServletRequest httpRequest = mock(HttpServletRequest.class);
AuthenticationDataSource authenticationDataSource = mock(AuthenticationDataSource.class);

AuthenticationProvider failingProvider = mock(AuthenticationProvider.class);
List<AuthenticationProvider> providers = Lists.newArrayList(
failingProvider
);
try (AuthenticationProvider provider = new AuthenticationProviderList(providers)) {
provider.initialize(new ServiceConfiguration());
RuntimeException authenticateException = new RuntimeException("authenticateException");

when(failingProvider.authenticateAsync(authenticationDataSource))
.thenReturn(FutureUtil.failedFuture(authenticateException));
when(failingProvider.authenticate(authenticationDataSource))
.thenThrow(authenticateException);
assertThat(provider.authenticateAsync(authenticationDataSource))
.failsWithin(3, TimeUnit.SECONDS)
.withThrowableThat().withCause(authenticateException);
assertThatThrownBy(() -> provider.authenticate(authenticationDataSource))
.isInstanceOf(AuthenticationException.class)
.hasCause(authenticateException);

RuntimeException authenticateHttpRequestException = new RuntimeException("authenticateHttpRequestAsync");
when(failingProvider.authenticateHttpRequestAsync(httpRequest, null))
.thenReturn(FutureUtil.failedFuture(authenticateHttpRequestException));
when(failingProvider.authenticateHttpRequest(httpRequest, null))
.thenThrow(authenticateHttpRequestException);
assertThat(provider.authenticateHttpRequestAsync(httpRequest, null))
.failsWithin(3, TimeUnit.SECONDS)
.withThrowableThat()
.havingCause()
.withCause(authenticateHttpRequestException);
assertThatThrownBy(() -> provider.authenticateHttpRequest(httpRequest, null))
.isInstanceOf(AuthenticationException.class)
.hasCause(authenticateHttpRequestException);

RuntimeException newAuthStateException = new RuntimeException("newAuthState");
when(failingProvider.newAuthState(null, null, null))
.thenThrow(newAuthStateException);
assertThatThrownBy(() -> provider.newAuthState(null, null, null))
.isInstanceOf(AuthenticationException.class)
.hasCause(newAuthStateException);

RuntimeException newHttpAuthStateException = new RuntimeException("newHttpAuthState");
when(failingProvider.newHttpAuthState(httpRequest))
.thenThrow(newHttpAuthStateException);
assertThatThrownBy(() -> provider.newHttpAuthState(httpRequest))
.isInstanceOf(AuthenticationException.class)
.hasCause(newHttpAuthStateException);
}

AuthenticationProvider successfulProvider = mock(AuthenticationProvider.class);
providers.add(successfulProvider);
String subject = "test-role";

try (AuthenticationProvider provider = new AuthenticationProviderList(providers)) {
provider.initialize(new ServiceConfiguration());

when(successfulProvider.authenticateAsync(authenticationDataSource))
.thenReturn(CompletableFuture.completedFuture(subject));
when(successfulProvider.authenticate(authenticationDataSource))
.thenReturn(subject);
assertThat(provider.authenticateAsync(authenticationDataSource))
.succeedsWithin(3, TimeUnit.SECONDS)
.matches(subject::equals);
assertThat(provider.authenticate(authenticationDataSource))
.isEqualTo(subject);

when(successfulProvider.authenticateHttpRequestAsync(httpRequest, null))
.thenReturn(CompletableFuture.completedFuture(true));
when(successfulProvider.authenticateHttpRequest(httpRequest, null))
.thenReturn(true);
assertThat(provider.authenticateHttpRequestAsync(httpRequest, null))
.succeedsWithin(3, TimeUnit.SECONDS)
.isEqualTo(true);
assertThat(provider.authenticateHttpRequest(httpRequest, null))
.isEqualTo(true);

AuthenticationState authenticationState = new AuthenticationState() {
@Override
public String getAuthRole() {
return subject;
}

@Override
public AuthData authenticate(AuthData authData) {
return null;
}

@Override
public AuthenticationDataSource getAuthDataSource() {
return null;
}

@Override
public boolean isComplete() {
return false;
}
};
when(successfulProvider.newAuthState(null, null, null))
.thenReturn(authenticationState);
when(successfulProvider.newHttpAuthState(httpRequest)).thenReturn(authenticationState);
verifyAuthenticationStateSuccess(provider.newAuthState(null, null, null), true, subject);
verifyAuthenticationStateSuccess(provider.newAuthState(null, null, null), false, subject);
verifyAuthenticationStateSuccess(provider.newHttpAuthState(httpRequest), true, subject);
verifyAuthenticationStateSuccess(provider.newHttpAuthState(httpRequest), false, subject);
}
}

private void verifyAuthenticationStateSuccess(AuthenticationState authState, boolean isAsync, String expectedRole)
throws Exception {
assertThat(authState).isNotNull();
if (isAsync) {
assertThat(authState.authenticateAsync(null)).succeedsWithin(3, TimeUnit.SECONDS);
} else {
assertThat(authState.authenticate(null)).isNull();
}
assertThat(authState.getAuthRole()).isEqualTo(expectedRole);
}
}

0 comments on commit 7619e2f

Please sign in to comment.