Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix][broker] Continue using the next provider for authentication if one fails #23797

Merged
merged 1 commit into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public class AuthenticationProviderList implements AuthenticationProvider {

private interface AuthProcessor<T, W> {

T apply(W process) throws AuthenticationException;
T apply(W process) throws Exception;

}

Expand All @@ -51,21 +51,30 @@ private enum ErrorCode {
AUTH_REQUIRED,
}

private static AuthenticationException newAuthenticationException(String message, Exception e) {
AuthenticationException authenticationException = new AuthenticationException(message);
if (e != null) {
authenticationException.initCause(e);
}
return authenticationException;
}

private static <T, W> T applyAuthProcessor(List<W> processors, AuthenticationMetrics metrics,
AuthProcessor<T, W> authFunc)
throws AuthenticationException {
AuthenticationException authenticationException = null;
Exception authenticationException = null;
String errorCode = ErrorCode.UNKNOWN.name();
for (W ap : processors) {
try {
return authFunc.apply(ap);
} catch (AuthenticationException ae) {
} catch (Exception ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + ap.getClass() + ": ", ae);
}
// Store the exception so we can throw it later instead of a generic one
authenticationException = ae;
errorCode = ap.getClass().getSimpleName() + "-INVALID-AUTH";
if (ae instanceof AuthenticationException) {
errorCode = ap.getClass().getSimpleName() + "-INVALID-AUTH";
}
}
}

Expand All @@ -76,7 +85,7 @@ private static <T, W> T applyAuthProcessor(List<W> processors, AuthenticationMet
} else {
metrics.recordFailure(AuthenticationProviderList.class.getSimpleName(),
"authentication-provider-list", errorCode);
throw authenticationException;
throw newAuthenticationException("Authentication failed", authenticationException);
}
}

Expand Down Expand Up @@ -290,12 +299,12 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA
throws AuthenticationException {
final List<AuthenticationState> states = new ArrayList<>(providers.size());

AuthenticationException authenticationException = null;
Exception authenticationException = null;
for (AuthenticationProvider provider : providers) {
try {
AuthenticationState state = provider.newAuthState(authData, remoteAddress, sslSession);
states.add(state);
} catch (AuthenticationException ae) {
} catch (Exception ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae);
}
Expand All @@ -305,11 +314,8 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA
}
if (states.isEmpty()) {
log.debug("Failed to initialize a new auth state from {}", remoteAddress, authenticationException);
if (authenticationException != null) {
throw authenticationException;
} else {
throw new AuthenticationException("Failed to initialize a new auth state from " + remoteAddress);
}
throw newAuthenticationException("Failed to initialize a new auth state from " + remoteAddress,
authenticationException);
} else {
return new AuthenticationListState(states, authenticationMetrics);
}
Expand All @@ -319,12 +325,12 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA
public AuthenticationState newHttpAuthState(HttpServletRequest request) throws AuthenticationException {
final List<AuthenticationState> states = new ArrayList<>(providers.size());

AuthenticationException authenticationException = null;
Exception authenticationException = null;
for (AuthenticationProvider provider : providers) {
try {
AuthenticationState state = provider.newHttpAuthState(request);
states.add(state);
} catch (AuthenticationException ae) {
} catch (Exception ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae);
}
Expand All @@ -335,35 +341,21 @@ public AuthenticationState newHttpAuthState(HttpServletRequest request) throws A
if (states.isEmpty()) {
log.debug("Failed to initialize a new http auth state from {}",
request.getRemoteHost(), authenticationException);
if (authenticationException != null) {
throw authenticationException;
} else {
throw new AuthenticationException(
"Failed to initialize a new http auth state from " + request.getRemoteHost());
}
throw newAuthenticationException(
"Failed to initialize a new http auth state from " + request.getRemoteHost(),
authenticationException);
} else {
return new AuthenticationListState(states, authenticationMetrics);
}
}

@Override
public boolean authenticateHttpRequest(HttpServletRequest request, HttpServletResponse response) throws Exception {
Boolean authenticated = applyAuthProcessor(
return applyAuthProcessor(
providers,
authenticationMetrics,
provider -> {
try {
return provider.authenticateHttpRequest(request, response);
} catch (Exception e) {
if (e instanceof AuthenticationException) {
throw (AuthenticationException) e;
} else {
throw new AuthenticationException("Failed to authentication http request");
}
}
}
provider -> provider.authenticateHttpRequest(request, response)
);
return authenticated;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedDataAttributeName;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedRoleAttributeName;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
Expand All @@ -35,13 +37,17 @@
import java.security.KeyPair;
import java.security.PrivateKey;
import java.util.Date;
import java.util.List;
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import javax.naming.AuthenticationException;
import javax.servlet.http.HttpServletRequest;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.authentication.utils.AuthTokenUtils;
import org.apache.pulsar.common.api.AuthData;
import org.apache.pulsar.common.util.FutureUtil;
import org.assertj.core.util.Lists;
import org.testng.annotations.AfterMethod;
import org.testng.annotations.BeforeMethod;
Expand Down Expand Up @@ -260,4 +266,125 @@ public void testAuthenticateHttpRequest() throws Exception {
verify(requestBB).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));
}

}
@Test
public void testAuthenticateWithMultipleProviders() throws Exception {
HttpServletRequest httpRequest = mock(HttpServletRequest.class);
AuthenticationDataSource authenticationDataSource = mock(AuthenticationDataSource.class);

AuthenticationProvider failingProvider = mock(AuthenticationProvider.class);
List<AuthenticationProvider> providers = Lists.newArrayList(
failingProvider
);
try (AuthenticationProvider provider = new AuthenticationProviderList(providers)) {
provider.initialize(new ServiceConfiguration());
RuntimeException authenticateException = new RuntimeException("authenticateException");

when(failingProvider.authenticateAsync(authenticationDataSource))
.thenReturn(FutureUtil.failedFuture(authenticateException));
when(failingProvider.authenticate(authenticationDataSource))
.thenThrow(authenticateException);
assertThat(provider.authenticateAsync(authenticationDataSource))
.failsWithin(3, TimeUnit.SECONDS)
.withThrowableThat().withCause(authenticateException);
assertThatThrownBy(() -> provider.authenticate(authenticationDataSource))
.isInstanceOf(AuthenticationException.class)
.hasCause(authenticateException);

RuntimeException authenticateHttpRequestException = new RuntimeException("authenticateHttpRequestAsync");
when(failingProvider.authenticateHttpRequestAsync(httpRequest, null))
.thenReturn(FutureUtil.failedFuture(authenticateHttpRequestException));
when(failingProvider.authenticateHttpRequest(httpRequest, null))
.thenThrow(authenticateHttpRequestException);
assertThat(provider.authenticateHttpRequestAsync(httpRequest, null))
.failsWithin(3, TimeUnit.SECONDS)
.withThrowableThat()
.havingCause()
.withCause(authenticateHttpRequestException);
assertThatThrownBy(() -> provider.authenticateHttpRequest(httpRequest, null))
.isInstanceOf(AuthenticationException.class)
.hasCause(authenticateHttpRequestException);

RuntimeException newAuthStateException = new RuntimeException("newAuthState");
when(failingProvider.newAuthState(null, null, null))
.thenThrow(newAuthStateException);
assertThatThrownBy(() -> provider.newAuthState(null, null, null))
.isInstanceOf(AuthenticationException.class)
.hasCause(newAuthStateException);

RuntimeException newHttpAuthStateException = new RuntimeException("newHttpAuthState");
when(failingProvider.newHttpAuthState(httpRequest))
.thenThrow(newHttpAuthStateException);
assertThatThrownBy(() -> provider.newHttpAuthState(httpRequest))
.isInstanceOf(AuthenticationException.class)
.hasCause(newHttpAuthStateException);
}

AuthenticationProvider successfulProvider = mock(AuthenticationProvider.class);
providers.add(successfulProvider);
String subject = "test-role";

try (AuthenticationProvider provider = new AuthenticationProviderList(providers)) {
provider.initialize(new ServiceConfiguration());

when(successfulProvider.authenticateAsync(authenticationDataSource))
.thenReturn(CompletableFuture.completedFuture(subject));
when(successfulProvider.authenticate(authenticationDataSource))
.thenReturn(subject);
assertThat(provider.authenticateAsync(authenticationDataSource))
.succeedsWithin(3, TimeUnit.SECONDS)
.matches(subject::equals);
assertThat(provider.authenticate(authenticationDataSource))
.isEqualTo(subject);

when(successfulProvider.authenticateHttpRequestAsync(httpRequest, null))
.thenReturn(CompletableFuture.completedFuture(true));
when(successfulProvider.authenticateHttpRequest(httpRequest, null))
.thenReturn(true);
assertThat(provider.authenticateHttpRequestAsync(httpRequest, null))
.succeedsWithin(3, TimeUnit.SECONDS)
.isEqualTo(true);
assertThat(provider.authenticateHttpRequest(httpRequest, null))
.isEqualTo(true);

AuthenticationState authenticationState = new AuthenticationState() {
@Override
public String getAuthRole() {
return subject;
}

@Override
public AuthData authenticate(AuthData authData) {
return null;
}

@Override
public AuthenticationDataSource getAuthDataSource() {
return null;
}

@Override
public boolean isComplete() {
return false;
}
};
when(successfulProvider.newAuthState(null, null, null))
.thenReturn(authenticationState);
when(successfulProvider.newHttpAuthState(httpRequest)).thenReturn(authenticationState);
verifyAuthenticationStateSuccess(provider.newAuthState(null, null, null), true, subject);
verifyAuthenticationStateSuccess(provider.newAuthState(null, null, null), false, subject);
verifyAuthenticationStateSuccess(provider.newHttpAuthState(httpRequest), true, subject);
verifyAuthenticationStateSuccess(provider.newHttpAuthState(httpRequest), false, subject);
}
}

private void verifyAuthenticationStateSuccess(AuthenticationState authState, boolean isAsync, String expectedRole)
throws Exception {
assertThat(authState).isNotNull();
if (isAsync) {
assertThat(authState.authenticateAsync(null)).succeedsWithin(3, TimeUnit.SECONDS);
} else {
assertThat(authState.authenticate(null)).isNull();
}
assertThat(authState.getAuthRole()).isEqualTo(expectedRole);
}
}
Loading