Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix][broker] Continue using the next provider for http authentication if one fails #23842

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
import java.io.Closeable;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;
import javax.naming.AuthenticationException;
import javax.servlet.http.HttpServletRequest;
Expand All @@ -49,7 +48,7 @@ public class AuthenticationService implements Closeable {
private static final Logger LOG = LoggerFactory.getLogger(AuthenticationService.class);
private final String anonymousUserRole;

private final Map<String, AuthenticationProvider> providers = new HashMap<>();
private final Map<String, AuthenticationProvider> providers = new LinkedHashMap<>();

public AuthenticationService(ServiceConfiguration conf) throws PulsarServerException {
this(conf, OpenTelemetry.noop());
Expand All @@ -60,7 +59,7 @@ public AuthenticationService(ServiceConfiguration conf, OpenTelemetry openTeleme
anonymousUserRole = conf.getAnonymousUserRole();
if (conf.isAuthenticationEnabled()) {
try {
Map<String, List<AuthenticationProvider>> providerMap = new HashMap<>();
Map<String, List<AuthenticationProvider>> providerMap = new LinkedHashMap<>();
for (String className : conf.getAuthenticationProviders()) {
if (className.isEmpty()) {
continue;
Expand Down Expand Up @@ -131,7 +130,7 @@ public boolean authenticateHttpRequest(HttpServletRequest request, HttpServletRe
AuthenticationProvider providerToUse = getAuthProvider(authMethodName);
try {
return providerToUse.authenticateHttpRequest(request, response);
} catch (AuthenticationException e) {
} catch (Exception e) {
if (LOG.isDebugEnabled()) {
LOG.debug("Authentication failed for provider " + providerToUse.getAuthMethodName() + " : "
+ e.getMessage(), e);
Expand All @@ -142,7 +141,7 @@ public boolean authenticateHttpRequest(HttpServletRequest request, HttpServletRe
for (AuthenticationProvider provider : providers.values()) {
try {
return provider.authenticateHttpRequest(request, response);
} catch (AuthenticationException e) {
} catch (Exception e) {
if (LOG.isDebugEnabled()) {
LOG.debug("Authentication failed for provider " + provider.getAuthMethodName() + ": "
+ e.getMessage(), e);
Expand Down Expand Up @@ -183,25 +182,18 @@ public String authenticateHttpRequest(HttpServletRequest request, Authentication
}
// Backward compatible, the authData value was null in the previous implementation
return providerToUse.authenticateAsync(authData).get();
} catch (AuthenticationException e) {
} catch (Exception e) {
if (LOG.isDebugEnabled()) {
LOG.debug("Authentication failed for provider " + providerToUse.getAuthMethodName() + " : "
+ e.getMessage(), e);
}
throw e;
} catch (ExecutionException | InterruptedException e) {
if (LOG.isDebugEnabled()) {
LOG.debug("Authentication failed for provider " + providerToUse.getAuthMethodName() + " : "
+ e.getMessage(), e);
}
throw new RuntimeException(e);
}
} else {
for (AuthenticationProvider provider : providers.values()) {
try {
AuthenticationState authenticationState = provider.newHttpAuthState(request);
return provider.authenticateAsync(authenticationState.getAuthDataSource()).get();
} catch (ExecutionException | InterruptedException | AuthenticationException e) {
} catch (Exception e) {
if (LOG.isDebugEnabled()) {
LOG.debug("Authentication failed for provider " + provider.getAuthMethodName() + ": "
+ e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedDataAttributeName;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedRoleAttributeName;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
Expand All @@ -29,15 +31,22 @@
import static org.testng.Assert.assertTrue;
import com.google.common.collect.Sets;
import java.io.IOException;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import javax.naming.AuthenticationException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import lombok.Cleanup;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.authentication.AuthenticationDataCommand;
import org.apache.pulsar.broker.authentication.AuthenticationDataHttps;
import org.apache.pulsar.broker.authentication.AuthenticationDataSource;
import org.apache.pulsar.broker.authentication.AuthenticationProvider;
import org.apache.pulsar.broker.authentication.AuthenticationService;
import org.apache.pulsar.broker.authentication.AuthenticationState;
import org.apache.pulsar.broker.web.AuthenticationFilter;
import org.apache.pulsar.common.api.AuthData;
import org.testng.Assert;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -166,6 +175,123 @@ public void testAuthenticationHttpRequestResponseWithAnonymousRole() throws Exce
service.close();
}

@Test
public void testHttpRequestWithMultipleProviders() throws Exception {
ServiceConfiguration config = new ServiceConfiguration();
Set<String> providersClassNames = new LinkedHashSet<>();
providersClassNames.add(MockAuthenticationProviderAlwaysFail.class.getName());
providersClassNames.add(MockHttpAuthenticationProvider.class.getName());
config.setAuthenticationProviders(providersClassNames);
config.setAuthenticationEnabled(true);
@Cleanup
AuthenticationService service = new AuthenticationService(config);

HttpServletRequest request = mock(HttpServletRequest.class);

when(request.getParameter("role")).thenReturn("success-role1");
assertTrue(service.authenticateHttpRequest(request, (HttpServletResponse) null));

when(request.getParameter("role")).thenReturn("");
assertThatThrownBy(() -> service.authenticateHttpRequest(request, (HttpServletResponse) null))
.isInstanceOf(AuthenticationException.class);

when(request.getParameter("role")).thenReturn("error-role1");
assertThatThrownBy(() -> service.authenticateHttpRequest(request, (HttpServletResponse) null))
.isInstanceOf(AuthenticationException.class);

when(request.getHeader(AuthenticationFilter.PULSAR_AUTH_METHOD_NAME)).thenReturn("http-auth");
assertThatThrownBy(() -> service.authenticateHttpRequest(request, (HttpServletResponse) null))
.isInstanceOf(RuntimeException.class);

HttpServletRequest requestForAuthenticationDataSource = mock(HttpServletRequest.class);
assertThatThrownBy(() -> service.authenticateHttpRequest(requestForAuthenticationDataSource,
(AuthenticationDataSource) null))
.isInstanceOf(AuthenticationException.class);

when(requestForAuthenticationDataSource.getParameter("role")).thenReturn("error-role2");
assertThatThrownBy(() -> service.authenticateHttpRequest(requestForAuthenticationDataSource,
(AuthenticationDataSource) null))
.isInstanceOf(AuthenticationException.class);

when(requestForAuthenticationDataSource.getParameter("role")).thenReturn("success-role2");
assertThat(service.authenticateHttpRequest(requestForAuthenticationDataSource,
(AuthenticationDataSource) null)).isEqualTo("role2");
}

public static class MockHttpAuthenticationProvider implements AuthenticationProvider {
@Override
public void close() throws IOException {
}

@Override
public void initialize(ServiceConfiguration config) throws IOException {
}

@Override
public String getAuthMethodName() {
return "http-auth";
}

private String getRole(HttpServletRequest request) {
String role = request.getParameter("role");
if (role != null) {
String[] s = role.split("-");
if (s.length == 2 && s[0].equals("success")) {
return s[1];
}
}
return null;
}

@Override
public boolean authenticateHttpRequest(HttpServletRequest request, HttpServletResponse response) {
String role = getRole(request);
if (role != null) {
return true;
}
throw new RuntimeException("test authentication failed");
}

@Override
public String authenticate(AuthenticationDataSource authData) throws AuthenticationException {
return authData.getCommandData();
}

@Override
public AuthenticationState newHttpAuthState(HttpServletRequest request) throws AuthenticationException {
String role = getRole(request);
if (role != null) {
return new AuthenticationState() {
@Override
public String getAuthRole() throws AuthenticationException {
return role;
}

@Override
public AuthData authenticate(AuthData authData) throws AuthenticationException {
return null;
}

@Override
public AuthenticationDataSource getAuthDataSource() {
return new AuthenticationDataCommand(role);
}

@Override
public boolean isComplete() {
return true;
}

@Override
public CompletableFuture<AuthData> authenticateAsync(AuthData authData) {
return AuthenticationState.super.authenticateAsync(authData);
}
};
}
throw new RuntimeException("new http auth failed");
}
}

public static class MockAuthenticationProvider implements AuthenticationProvider {

@Override
Expand Down
Loading