diff --git a/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/AuthenticationProviderList.java b/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/AuthenticationProviderList.java index 0e5559b3c3aab..02ed52b5ec042 100644 --- a/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/AuthenticationProviderList.java +++ b/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/AuthenticationProviderList.java @@ -42,7 +42,7 @@ public class AuthenticationProviderList implements AuthenticationProvider { private interface AuthProcessor { - T apply(W process) throws AuthenticationException; + T apply(W process) throws Exception; } @@ -51,21 +51,30 @@ private enum ErrorCode { AUTH_REQUIRED, } + private static AuthenticationException newAuthenticationException(String message, Exception e) { + AuthenticationException authenticationException = new AuthenticationException(message); + if (e != null) { + authenticationException.initCause(e); + } + return authenticationException; + } + private static T applyAuthProcessor(List processors, AuthenticationMetrics metrics, AuthProcessor authFunc) throws AuthenticationException { - AuthenticationException authenticationException = null; + Exception authenticationException = null; String errorCode = ErrorCode.UNKNOWN.name(); for (W ap : processors) { try { return authFunc.apply(ap); - } catch (AuthenticationException ae) { + } catch (Exception ae) { if (log.isDebugEnabled()) { log.debug("Authentication failed for auth provider " + ap.getClass() + ": ", ae); } - // Store the exception so we can throw it later instead of a generic one authenticationException = ae; - errorCode = ap.getClass().getSimpleName() + "-INVALID-AUTH"; + if (ae instanceof AuthenticationException) { + errorCode = ap.getClass().getSimpleName() + "-INVALID-AUTH"; + } } } @@ -76,7 +85,7 @@ private static T applyAuthProcessor(List processors, AuthenticationMet } else { metrics.recordFailure(AuthenticationProviderList.class.getSimpleName(), "authentication-provider-list", errorCode); - throw authenticationException; + throw newAuthenticationException("Authentication failed", authenticationException); } } @@ -290,12 +299,12 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA throws AuthenticationException { final List states = new ArrayList<>(providers.size()); - AuthenticationException authenticationException = null; + Exception authenticationException = null; for (AuthenticationProvider provider : providers) { try { AuthenticationState state = provider.newAuthState(authData, remoteAddress, sslSession); states.add(state); - } catch (AuthenticationException ae) { + } catch (Exception ae) { if (log.isDebugEnabled()) { log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae); } @@ -305,11 +314,8 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA } if (states.isEmpty()) { log.debug("Failed to initialize a new auth state from {}", remoteAddress, authenticationException); - if (authenticationException != null) { - throw authenticationException; - } else { - throw new AuthenticationException("Failed to initialize a new auth state from " + remoteAddress); - } + throw newAuthenticationException("Failed to initialize a new auth state from " + remoteAddress, + authenticationException); } else { return new AuthenticationListState(states, authenticationMetrics); } @@ -319,12 +325,12 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA public AuthenticationState newHttpAuthState(HttpServletRequest request) throws AuthenticationException { final List states = new ArrayList<>(providers.size()); - AuthenticationException authenticationException = null; + Exception authenticationException = null; for (AuthenticationProvider provider : providers) { try { AuthenticationState state = provider.newHttpAuthState(request); states.add(state); - } catch (AuthenticationException ae) { + } catch (Exception ae) { if (log.isDebugEnabled()) { log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae); } @@ -335,12 +341,9 @@ public AuthenticationState newHttpAuthState(HttpServletRequest request) throws A if (states.isEmpty()) { log.debug("Failed to initialize a new http auth state from {}", request.getRemoteHost(), authenticationException); - if (authenticationException != null) { - throw authenticationException; - } else { - throw new AuthenticationException( - "Failed to initialize a new http auth state from " + request.getRemoteHost()); - } + throw newAuthenticationException( + "Failed to initialize a new http auth state from " + request.getRemoteHost(), + authenticationException); } else { return new AuthenticationListState(states, authenticationMetrics); } @@ -348,22 +351,11 @@ public AuthenticationState newHttpAuthState(HttpServletRequest request) throws A @Override public boolean authenticateHttpRequest(HttpServletRequest request, HttpServletResponse response) throws Exception { - Boolean authenticated = applyAuthProcessor( + return applyAuthProcessor( providers, authenticationMetrics, - provider -> { - try { - return provider.authenticateHttpRequest(request, response); - } catch (Exception e) { - if (e instanceof AuthenticationException) { - throw (AuthenticationException) e; - } else { - throw new AuthenticationException("Failed to authentication http request"); - } - } - } + provider -> provider.authenticateHttpRequest(request, response) ); - return authenticated; } @Override diff --git a/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/AuthenticationProviderListTest.java b/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/AuthenticationProviderListTest.java index e81198217b5b6..f139bb384a4be 100644 --- a/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/AuthenticationProviderListTest.java +++ b/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/AuthenticationProviderListTest.java @@ -21,6 +21,8 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedDataAttributeName; import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedRoleAttributeName; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.mock; @@ -35,13 +37,17 @@ import java.security.KeyPair; import java.security.PrivateKey; import java.util.Date; +import java.util.List; import java.util.Optional; import java.util.Properties; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; +import javax.naming.AuthenticationException; import javax.servlet.http.HttpServletRequest; import org.apache.pulsar.broker.ServiceConfiguration; import org.apache.pulsar.broker.authentication.utils.AuthTokenUtils; import org.apache.pulsar.common.api.AuthData; +import org.apache.pulsar.common.util.FutureUtil; import org.assertj.core.util.Lists; import org.testng.annotations.AfterMethod; import org.testng.annotations.BeforeMethod; @@ -260,4 +266,125 @@ public void testAuthenticateHttpRequest() throws Exception { verify(requestBB).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class)); } -} + @Test + public void testAuthenticateWithMultipleProviders() throws Exception { + HttpServletRequest httpRequest = mock(HttpServletRequest.class); + AuthenticationDataSource authenticationDataSource = mock(AuthenticationDataSource.class); + + AuthenticationProvider failingProvider = mock(AuthenticationProvider.class); + List providers = Lists.newArrayList( + failingProvider + ); + try (AuthenticationProvider provider = new AuthenticationProviderList(providers)) { + provider.initialize(new ServiceConfiguration()); + RuntimeException authenticateException = new RuntimeException("authenticateException"); + + when(failingProvider.authenticateAsync(authenticationDataSource)) + .thenReturn(FutureUtil.failedFuture(authenticateException)); + when(failingProvider.authenticate(authenticationDataSource)) + .thenThrow(authenticateException); + assertThat(provider.authenticateAsync(authenticationDataSource)) + .failsWithin(3, TimeUnit.SECONDS) + .withThrowableThat().withCause(authenticateException); + assertThatThrownBy(() -> provider.authenticate(authenticationDataSource)) + .isInstanceOf(AuthenticationException.class) + .hasCause(authenticateException); + + RuntimeException authenticateHttpRequestException = new RuntimeException("authenticateHttpRequestAsync"); + when(failingProvider.authenticateHttpRequestAsync(httpRequest, null)) + .thenReturn(FutureUtil.failedFuture(authenticateHttpRequestException)); + when(failingProvider.authenticateHttpRequest(httpRequest, null)) + .thenThrow(authenticateHttpRequestException); + assertThat(provider.authenticateHttpRequestAsync(httpRequest, null)) + .failsWithin(3, TimeUnit.SECONDS) + .withThrowableThat() + .havingCause() + .withCause(authenticateHttpRequestException); + assertThatThrownBy(() -> provider.authenticateHttpRequest(httpRequest, null)) + .isInstanceOf(AuthenticationException.class) + .hasCause(authenticateHttpRequestException); + + RuntimeException newAuthStateException = new RuntimeException("newAuthState"); + when(failingProvider.newAuthState(null, null, null)) + .thenThrow(newAuthStateException); + assertThatThrownBy(() -> provider.newAuthState(null, null, null)) + .isInstanceOf(AuthenticationException.class) + .hasCause(newAuthStateException); + + RuntimeException newHttpAuthStateException = new RuntimeException("newHttpAuthState"); + when(failingProvider.newHttpAuthState(httpRequest)) + .thenThrow(newHttpAuthStateException); + assertThatThrownBy(() -> provider.newHttpAuthState(httpRequest)) + .isInstanceOf(AuthenticationException.class) + .hasCause(newHttpAuthStateException); + } + + AuthenticationProvider successfulProvider = mock(AuthenticationProvider.class); + providers.add(successfulProvider); + String subject = "test-role"; + + try (AuthenticationProvider provider = new AuthenticationProviderList(providers)) { + provider.initialize(new ServiceConfiguration()); + + when(successfulProvider.authenticateAsync(authenticationDataSource)) + .thenReturn(CompletableFuture.completedFuture(subject)); + when(successfulProvider.authenticate(authenticationDataSource)) + .thenReturn(subject); + assertThat(provider.authenticateAsync(authenticationDataSource)) + .succeedsWithin(3, TimeUnit.SECONDS) + .matches(subject::equals); + assertThat(provider.authenticate(authenticationDataSource)) + .isEqualTo(subject); + + when(successfulProvider.authenticateHttpRequestAsync(httpRequest, null)) + .thenReturn(CompletableFuture.completedFuture(true)); + when(successfulProvider.authenticateHttpRequest(httpRequest, null)) + .thenReturn(true); + assertThat(provider.authenticateHttpRequestAsync(httpRequest, null)) + .succeedsWithin(3, TimeUnit.SECONDS) + .isEqualTo(true); + assertThat(provider.authenticateHttpRequest(httpRequest, null)) + .isEqualTo(true); + + AuthenticationState authenticationState = new AuthenticationState() { + @Override + public String getAuthRole() { + return subject; + } + + @Override + public AuthData authenticate(AuthData authData) { + return null; + } + + @Override + public AuthenticationDataSource getAuthDataSource() { + return null; + } + + @Override + public boolean isComplete() { + return false; + } + }; + when(successfulProvider.newAuthState(null, null, null)) + .thenReturn(authenticationState); + when(successfulProvider.newHttpAuthState(httpRequest)).thenReturn(authenticationState); + verifyAuthenticationStateSuccess(provider.newAuthState(null, null, null), true, subject); + verifyAuthenticationStateSuccess(provider.newAuthState(null, null, null), false, subject); + verifyAuthenticationStateSuccess(provider.newHttpAuthState(httpRequest), true, subject); + verifyAuthenticationStateSuccess(provider.newHttpAuthState(httpRequest), false, subject); + } + } + + private void verifyAuthenticationStateSuccess(AuthenticationState authState, boolean isAsync, String expectedRole) + throws Exception { + assertThat(authState).isNotNull(); + if (isAsync) { + assertThat(authState.authenticateAsync(null)).succeedsWithin(3, TimeUnit.SECONDS); + } else { + assertThat(authState.authenticate(null)).isNull(); + } + assertThat(authState.getAuthRole()).isEqualTo(expectedRole); + } +} \ No newline at end of file