Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix][broker] TokenAuthenticationState: authenticate token only once #19314

Merged
merged 10 commits into from
Feb 1, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -174,17 +174,17 @@ public AuthenticationState newAuthState(AuthData authData, SocketAddress remoteA
final List<AuthenticationState> states = new ArrayList<>(providers.size());

AuthenticationException authenticationException = null;
try {
applyAuthProcessor(
providers,
provider -> {
AuthenticationState state = provider.newAuthState(authData, remoteAddress, sslSession);
states.add(state);
return state;
for (AuthenticationProvider provider : providers) {
try {
AuthenticationState state = provider.newAuthState(authData, remoteAddress, sslSession);
states.add(state);
} catch (AuthenticationException ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae);
}
);
} catch (AuthenticationException ae) {
authenticationException = ae;
// Store the exception so we can throw it later instead of a generic one
authenticationException = ae;
}
}
if (states.isEmpty()) {
log.debug("Failed to initialize a new auth state from {}", remoteAddress, authenticationException);
Expand All @@ -203,17 +203,17 @@ public AuthenticationState newHttpAuthState(HttpServletRequest request) throws A
final List<AuthenticationState> states = new ArrayList<>(providers.size());

AuthenticationException authenticationException = null;
try {
applyAuthProcessor(
providers,
provider -> {
AuthenticationState state = provider.newHttpAuthState(request);
states.add(state);
return state;
}
);
} catch (AuthenticationException ae) {
authenticationException = ae;
for (AuthenticationProvider provider : providers) {
try {
AuthenticationState state = provider.newHttpAuthState(request);
states.add(state);
} catch (AuthenticationException ae) {
if (log.isDebugEnabled()) {
log.debug("Authentication failed for auth provider " + provider.getClass() + ": ", ae);
}
// Store the exception so we can throw it later instead of a generic one
authenticationException = ae;
}
}
if (states.isEmpty()) {
log.debug("Failed to initialize a new http auth state from {}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -343,9 +343,6 @@ private static final class TokenAuthenticationState implements AuthenticationSta
SocketAddress remoteAddress,
SSLSession sslSession) throws AuthenticationException {
this.provider = provider;
String token = new String(authData.getBytes(), UTF_8);
this.authenticationDataSource = new AuthenticationDataCommand(token, remoteAddress, sslSession);
this.checkExpiration(token);
this.remoteAddress = remoteAddress;
this.sslSession = sslSession;
}
Expand All @@ -354,15 +351,9 @@ private static final class TokenAuthenticationState implements AuthenticationSta
AuthenticationProviderToken provider,
HttpServletRequest request) throws AuthenticationException {
this.provider = provider;
String httpHeaderValue = request.getHeader(HTTP_HEADER_NAME);
if (httpHeaderValue == null || !httpHeaderValue.startsWith(HTTP_HEADER_VALUE_PREFIX)) {
throw new AuthenticationException("Invalid HTTP Authorization header");
}

// Remove prefix
String token = httpHeaderValue.substring(HTTP_HEADER_VALUE_PREFIX.length());
// Set this for backwards compatibility with AuthenticationProvider#newHttpAuthState
this.authenticationDataSource = new AuthenticationDataHttps(request);
this.checkExpiration(token);

// These are not used when this constructor is invoked, set them to null.
this.sslSession = null;
Expand All @@ -371,6 +362,9 @@ private static final class TokenAuthenticationState implements AuthenticationSta

@Override
public String getAuthRole() throws AuthenticationException {
if (jwt == null) {
throw new AuthenticationException("Must authenticate before calling getAuthRole");
}
return provider.getPrincipal(jwt);
}

Expand Down Expand Up @@ -404,8 +398,8 @@ public AuthenticationDataSource getAuthDataSource() {

@Override
public boolean isComplete() {
// The authentication of tokens is always done in one single stage
return true;
// The authentication of tokens is always done in one single stage, so once jwt is set, it is "complete"
return jwt != null;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,12 @@ public AuthenticationService(ServiceConfiguration conf) throws PulsarServerExcep
}

for (Map.Entry<String, List<AuthenticationProvider>> entry : providerMap.entrySet()) {
AuthenticationProviderList provider = new AuthenticationProviderList(entry.getValue());
AuthenticationProvider provider;
if (entry.getValue().size() == 1) {
provider = entry.getValue().get(0);
} else {
provider = new AuthenticationProviderList(entry.getValue());
}
provider.initialize(conf);
providers.put(provider.getAuthMethodName(), provider);
LOG.info("[{}] has been loaded.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,12 @@
package org.apache.pulsar.broker.authentication;

import static java.nio.charset.StandardCharsets.UTF_8;
import javax.servlet.http.HttpServletRequest;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedDataAttributeName;
import static org.apache.pulsar.broker.web.AuthenticationFilter.AuthenticatedRoleAttributeName;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
Expand All @@ -35,6 +39,7 @@
import java.util.Optional;
import java.util.Properties;
import java.util.concurrent.TimeUnit;
import javax.servlet.http.HttpServletRequest;
import org.apache.pulsar.broker.ServiceConfiguration;
import org.apache.pulsar.broker.authentication.utils.AuthTokenUtils;
import org.apache.pulsar.common.api.AuthData;
Expand Down Expand Up @@ -157,19 +162,13 @@ public void testAuthenticate() throws Exception {
}

private AuthenticationState newAuthState(String token, String expectedSubject) throws Exception {
// Must pass the token to the newAuthState for legacy reasons.
AuthenticationState authState = authProvider.newAuthState(
AuthData.of(token.getBytes(UTF_8)),
null,
null
);
assertEquals(authState.getAuthRole(), expectedSubject);
assertTrue(authState.isComplete());
assertFalse(authState.isExpired());
return authState;
}

private AuthenticationState newHttpAuthState(HttpServletRequest request, String expectedSubject) throws Exception {
AuthenticationState authState = authProvider.newHttpAuthState(request);
authState.authenticateAsync(AuthData.of(token.getBytes(UTF_8))).get();
assertEquals(authState.getAuthRole(), expectedSubject);
assertTrue(authState.isComplete());
assertFalse(authState.isExpired());
Expand Down Expand Up @@ -200,37 +199,42 @@ public void testNewAuthState() throws Exception {
}

@Test
public void testNewHttpAuthState() throws Exception {
public void testAuthenticateHttpRequest() throws Exception {
HttpServletRequest requestAA = mock(HttpServletRequest.class);
when(requestAA.getRemoteAddr()).thenReturn("127.0.0.1");
when(requestAA.getRemotePort()).thenReturn(8080);
when(requestAA.getHeader("Authorization")).thenReturn("Bearer " + expiringTokenAA);
AuthenticationState authStateAA = newHttpAuthState(requestAA, SUBJECT_A);
boolean doFilterAA = authProvider.authenticateHttpRequest(requestAA, null);
assertTrue(doFilterAA);
verify(requestAA).setAttribute(eq(AuthenticatedRoleAttributeName), eq(SUBJECT_A));
verify(requestAA).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));

HttpServletRequest requestAB = mock(HttpServletRequest.class);
when(requestAB.getRemoteAddr()).thenReturn("127.0.0.1");
when(requestAB.getRemotePort()).thenReturn(8080);
when(requestAB.getHeader("Authorization")).thenReturn("Bearer " + expiringTokenAB);
AuthenticationState authStateAB = newHttpAuthState(requestAB, SUBJECT_B);
boolean doFilterAB = authProvider.authenticateHttpRequest(requestAB, null);
assertTrue(doFilterAB);
verify(requestAB).setAttribute(eq(AuthenticatedRoleAttributeName), eq(SUBJECT_B));
verify(requestAB).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));

HttpServletRequest requestBA = mock(HttpServletRequest.class);
when(requestBA.getRemoteAddr()).thenReturn("127.0.0.1");
when(requestBA.getRemotePort()).thenReturn(8080);
when(requestBA.getHeader("Authorization")).thenReturn("Bearer " + expiringTokenBA);
AuthenticationState authStateBA = newHttpAuthState(requestBA, SUBJECT_A);
boolean doFilterBA = authProvider.authenticateHttpRequest(requestBA, null);
assertTrue(doFilterBA);
verify(requestBA).setAttribute(eq(AuthenticatedRoleAttributeName), eq(SUBJECT_A));
verify(requestBA).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));

HttpServletRequest requestBB = mock(HttpServletRequest.class);
when(requestBB.getRemoteAddr()).thenReturn("127.0.0.1");
when(requestBB.getRemotePort()).thenReturn(8080);
when(requestBB.getHeader("Authorization")).thenReturn("Bearer " + expiringTokenBB);
AuthenticationState authStateBB = newHttpAuthState(requestBB, SUBJECT_B);

Thread.sleep(TimeUnit.SECONDS.toMillis(6));

verifyAuthStateExpired(authStateAA, SUBJECT_A);
verifyAuthStateExpired(authStateAB, SUBJECT_B);
verifyAuthStateExpired(authStateBA, SUBJECT_A);
verifyAuthStateExpired(authStateBB, SUBJECT_B);
boolean doFilterBB = authProvider.authenticateHttpRequest(requestBB, null);
assertTrue(doFilterBB);
verify(requestBB).setAttribute(eq(AuthenticatedRoleAttributeName), eq(SUBJECT_B));
verify(requestBB).setAttribute(eq(AuthenticatedDataAttributeName), isA(AuthenticationDataSource.class));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static org.testng.Assert.assertNotEquals;
import static org.testng.Assert.assertNotNull;
import static org.testng.Assert.assertNull;
import static org.testng.Assert.assertThrows;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.fail;
import com.google.common.collect.Lists;
Expand Down Expand Up @@ -682,6 +683,7 @@ public void testExpiringToken() throws Exception {
Optional.of(new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(3))));

AuthenticationState authState = provider.newAuthState(AuthData.of(expiringToken.getBytes()), null, null);
authState.authenticate(AuthData.of(expiringToken.getBytes()));
assertTrue(authState.isComplete());
assertFalse(authState.isExpired());

Expand All @@ -693,6 +695,34 @@ public void testExpiringToken() throws Exception {
assertEquals(brokerData, AuthData.REFRESH_AUTH_DATA);
}

@SuppressWarnings("deprecation")
@Test
public void testExpiredTokenFailsOnAuthenticate() throws Exception {
SecretKey secretKey = AuthTokenUtils.createSecretKey(SignatureAlgorithm.HS256);

@Cleanup
AuthenticationProviderToken provider = new AuthenticationProviderToken();

Properties properties = new Properties();
properties.setProperty(AuthenticationProviderToken.CONF_TOKEN_SECRET_KEY,
AuthTokenUtils.encodeKeyBase64(secretKey));

ServiceConfiguration conf = new ServiceConfiguration();
conf.setProperties(properties);
provider.initialize(conf);

// Create a token that is already expired
String expiringToken = AuthTokenUtils.createToken(secretKey, SUBJECT,
Optional.of(new Date(System.currentTimeMillis() - TimeUnit.SECONDS.toMillis(3))));

AuthData expiredAuthData = AuthData.of(expiringToken.getBytes());

// It is important that this call doesn't fail because we no longer authenticate the auth data at construction
AuthenticationState authState = provider.newAuthState(expiredAuthData,null, null);
// The call to authenticate the token is the call that fails
assertThrows(AuthenticationException.class, () -> authState.authenticate(expiredAuthData));
}

// tests for Token Audience
@Test
public void testRightTokenAudienceClaim() throws Exception {
Expand Down Expand Up @@ -916,6 +946,7 @@ public void testTokenFromHttpHeaders() throws Exception {
assertTrue(doFilter, "Authentication should have passed");
}

@SuppressWarnings("deprecation")
@Test
public void testTokenStateUpdatesAuthenticationDataSource() throws Exception {
SecretKey secretKey = AuthTokenUtils.createSecretKey(SignatureAlgorithm.HS256);
Expand All @@ -931,20 +962,26 @@ public void testTokenStateUpdatesAuthenticationDataSource() throws Exception {
conf.setProperties(properties);
provider.initialize(conf);

String firstToken = AuthTokenUtils.createToken(secretKey, SUBJECT, Optional.empty());
AuthenticationState authState = provider.newAuthState(null,null, null);

// Haven't authenticated yet, so cannot get role when using constructor with no auth data
assertThrows(AuthenticationException.class, authState::getAuthRole);
assertNull(authState.getAuthDataSource(), "Haven't created a source yet.");

AuthenticationState authState = provider.newAuthState(AuthData.of(firstToken.getBytes()),null, null);
String firstToken = AuthTokenUtils.createToken(secretKey, SUBJECT, Optional.empty());

AuthData firstChallenge = authState.authenticate(AuthData.of(firstToken.getBytes()));
AuthenticationDataSource firstAuthDataSource = authState.getAuthDataSource();
assertNotNull(firstAuthDataSource, "Should be initialized.");

String secondToken = AuthTokenUtils.createToken(secretKey, SUBJECT,
Optional.of(new Date(System.currentTimeMillis() + TimeUnit.SECONDS.toMillis(3))));
assertNull(firstChallenge, "TokenAuth doesn't respond with challenges");
assertNotNull(firstAuthDataSource, "Created authDataSource");

String secondToken = AuthTokenUtils.createToken(secretKey, SUBJECT, Optional.empty());

AuthData challenge = authState.authenticate(AuthData.of(secondToken.getBytes()));
AuthData secondChallenge = authState.authenticate(AuthData.of(secondToken.getBytes()));
AuthenticationDataSource secondAuthDataSource = authState.getAuthDataSource();

assertNull(challenge, "TokenAuth doesn't respond with challenges");
assertNull(secondChallenge, "TokenAuth doesn't respond with challenges");
assertNotNull(secondAuthDataSource, "Created authDataSource");

assertNotEquals(firstAuthDataSource, secondAuthDataSource);
Expand Down