diff --git a/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationState.java b/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationState.java
index 6e5b889ff3c9c..2f8c79a79191e 100644
--- a/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationState.java
+++ b/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationState.java
@@ -21,39 +21,60 @@
import static java.nio.charset.StandardCharsets.UTF_8;
import java.net.SocketAddress;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
import javax.naming.AuthenticationException;
import javax.net.ssl.SSLSession;
import javax.servlet.http.HttpServletRequest;
import org.apache.pulsar.common.api.AuthData;
/**
- * Interface for authentication state.
- *
- * It tell broker whether the authentication is completed or not,
- * if completed, what is the AuthRole is.
+ * A class to track single stage authentication. This class assumes that:
+ * 1. {@link #authenticateAsync(AuthData)} is called once and when the {@link CompletableFuture} completes,
+ * authentication is complete.
+ * 2. Authentication does not expire, so {@link #isExpired()} always returns false.
+ *
+ * See {@link AuthenticationState} for Pulsar's contract on how this interface is used by Pulsar.
*/
public class OneStageAuthenticationState implements AuthenticationState {
- private final AuthenticationDataSource authenticationDataSource;
- private final String authRole;
+ private AuthenticationDataSource authenticationDataSource;
+ private final SocketAddress remoteAddress;
+ private final SSLSession sslSession;
+ private final AuthenticationProvider provider;
+ private volatile String authRole;
+
+ /**
+ * Constructor for a {@link OneStageAuthenticationState} where there is no authentication performed during
+ * initialization.
+ * @param remoteAddress - remoteAddress associated with the {@link AuthenticationState}
+ * @param sslSession - sslSession associated with the {@link AuthenticationState}
+ * @param provider - {@link AuthenticationProvider} to use to verify {@link AuthData}
+ */
public OneStageAuthenticationState(AuthData authData,
SocketAddress remoteAddress,
SSLSession sslSession,
- AuthenticationProvider provider) throws AuthenticationException {
- this.authenticationDataSource = new AuthenticationDataCommand(
- new String(authData.getBytes(), UTF_8), remoteAddress, sslSession);
- this.authRole = provider.authenticate(authenticationDataSource);
+ AuthenticationProvider provider) {
+ this.provider = provider;
+ this.remoteAddress = remoteAddress;
+ this.sslSession = sslSession;
}
- public OneStageAuthenticationState(HttpServletRequest request, AuthenticationProvider provider)
- throws AuthenticationException {
+ public OneStageAuthenticationState(HttpServletRequest request, AuthenticationProvider provider) {
+ // Must initialize this here for backwards compatibility with http authentication
this.authenticationDataSource = new AuthenticationDataHttps(request);
- this.authRole = provider.authenticate(authenticationDataSource);
+ this.provider = provider;
+ // These are not used when invoking this constructor.
+ this.remoteAddress = null;
+ this.sslSession = null;
}
@Override
- public String getAuthRole() {
+ public String getAuthRole() throws AuthenticationException {
+ if (authRole == null) {
+ throw new AuthenticationException("Must authenticate before calling getAuthRole");
+ }
return authRole;
}
@@ -62,13 +83,47 @@ public AuthenticationDataSource getAuthDataSource() {
return authenticationDataSource;
}
+ /**
+ * Warning: this method is not intended to be called concurrently.
+ */
+ @Override
+ public CompletableFuture authenticateAsync(AuthData authData) {
+ if (authRole != null) {
+ // Authentication is already completed
+ return CompletableFuture.completedFuture(null);
+ }
+ this.authenticationDataSource = new AuthenticationDataCommand(
+ new String(authData.getBytes(), UTF_8), remoteAddress, sslSession);
+
+ return provider
+ .authenticateAsync(authenticationDataSource)
+ .thenApply(role -> {
+ this.authRole = role;
+ // Single stage authentication always returns null
+ return null;
+ });
+ }
+
+ /**
+ * @deprecated use {@link #authenticateAsync(AuthData)}
+ */
+ @Deprecated
@Override
- public AuthData authenticate(AuthData authData) {
- return null;
+ public AuthData authenticate(AuthData authData) throws AuthenticationException {
+ try {
+ return authenticateAsync(authData).get();
+ } catch (InterruptedException | ExecutionException e) {
+ throw new RuntimeException(e);
+ }
}
+ /**
+ * @deprecated rely on result from {@link #authenticateAsync(AuthData)}. For more information, see the Javadoc
+ * for {@link AuthenticationState#isComplete()}.
+ */
+ @Deprecated
@Override
public boolean isComplete() {
- return true;
+ return authRole != null;
}
}
diff --git a/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationStateTest.java b/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationStateTest.java
new file mode 100644
index 0000000000000..7ec1222c651c3
--- /dev/null
+++ b/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationStateTest.java
@@ -0,0 +1,135 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.pulsar.broker.authentication;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertNull;
+import static org.testng.Assert.assertSame;
+import static org.testng.Assert.assertThrows;
+import static org.testng.Assert.assertTrue;
+import org.apache.pulsar.broker.ServiceConfiguration;
+import org.apache.pulsar.common.api.AuthData;
+import org.testng.annotations.Test;
+import javax.naming.AuthenticationException;
+import javax.servlet.http.HttpServletRequest;
+import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.atomic.LongAdder;
+
+public class OneStageAuthenticationStateTest {
+
+ public static class CountingAuthenticationProvider implements AuthenticationProvider {
+ public LongAdder authCallCount = new LongAdder();
+
+ @Override
+ public void initialize(ServiceConfiguration config) throws IOException {
+ }
+
+ @Override
+ public String getAuthMethodName() {
+ return null;
+ }
+
+ @Override
+ public void close() throws IOException {
+ }
+
+ @Override
+ public CompletableFuture authenticateAsync(AuthenticationDataSource authData) {
+ authCallCount.increment();
+ return CompletableFuture.completedFuture(authData.getCommandData());
+ }
+
+ public int getAuthCallCount() {
+ return authCallCount.intValue();
+ }
+ }
+
+ @Test
+ public void verifyAuthenticateAsyncIsCalledExactlyOnceAndSetsRole() throws Exception {
+ CountingAuthenticationProvider provider = new CountingAuthenticationProvider();
+ AuthData authData = AuthData.of("role".getBytes());
+ OneStageAuthenticationState authState = new OneStageAuthenticationState(authData, null, null, provider);
+ assertEquals(provider.getAuthCallCount(), 0, "Auth count should not increase yet");
+ AuthData challenge = authState.authenticateAsync(authData).get();
+ assertNull(challenge);
+ assertEquals(provider.getAuthCallCount(), 1, "Call authenticate only once");
+ assertEquals(authState.getAuthRole(), "role");
+ AuthenticationDataSource firstAuthenticationDataSource = authState.getAuthDataSource();
+ assertTrue(firstAuthenticationDataSource instanceof AuthenticationDataCommand);
+
+ // Verify subsequent call to authenticate does not change data
+ AuthData secondChallenge = authState.authenticateAsync(AuthData.of("admin".getBytes())).get();
+ assertNull(secondChallenge);
+ assertEquals(authState.getAuthRole(), "role");
+ AuthenticationDataSource secondAuthenticationDataSource = authState.getAuthDataSource();
+ assertSame(secondAuthenticationDataSource, firstAuthenticationDataSource);
+ assertEquals(provider.getAuthCallCount(), 1, "Call authenticate only once, even later.");
+ }
+
+ @SuppressWarnings("deprecation")
+ @Test
+ public void verifyAuthenticateIsCalledExactlyOnceAndSetsRole() throws Exception {
+ CountingAuthenticationProvider provider = new CountingAuthenticationProvider();
+ AuthData authData = AuthData.of("role".getBytes());
+ OneStageAuthenticationState authState = new OneStageAuthenticationState(authData, null, null, provider);
+ assertEquals(provider.getAuthCallCount(), 0, "Auth count should not increase yet");
+ assertFalse(authState.isComplete());
+ AuthData challenge = authState.authenticate(authData);
+ assertNull(challenge);
+ assertTrue(authState.isComplete());
+ assertEquals(provider.getAuthCallCount(), 1, "Call authenticate only once");
+ assertEquals(authState.getAuthRole(), "role");
+ AuthenticationDataSource firstAuthenticationDataSource = authState.getAuthDataSource();
+ assertTrue(firstAuthenticationDataSource instanceof AuthenticationDataCommand);
+
+ // Verify subsequent call to authenticate does not change data
+ AuthData secondChallenge = authState.authenticate(AuthData.of("admin".getBytes()));
+ assertNull(secondChallenge);
+ assertEquals(authState.getAuthRole(), "role");
+ AuthenticationDataSource secondAuthenticationDataSource = authState.getAuthDataSource();
+ assertSame(secondAuthenticationDataSource, firstAuthenticationDataSource);
+ assertEquals(provider.getAuthCallCount(), 1, "Call authenticate only once, even later.");
+ }
+
+ @Test
+ public void verifyGetAuthRoleBeforeAuthenticateFails() {
+ CountingAuthenticationProvider provider = new CountingAuthenticationProvider();
+ AuthData authData = AuthData.of("role".getBytes());
+ OneStageAuthenticationState authState = new OneStageAuthenticationState(authData, null, null, provider);
+ assertThrows(AuthenticationException.class, authState::getAuthRole);
+ assertNull(authState.getAuthDataSource());
+ }
+
+ @Test
+ public void verifyHttpAuthConstructorInitializesAuthDataSourceAndDoesNotAuthenticateData() {
+ HttpServletRequest request = mock(HttpServletRequest.class);
+ when(request.getRemoteAddr()).thenReturn("localhost");
+ when(request.getRemotePort()).thenReturn(8080);
+ CountingAuthenticationProvider provider = new CountingAuthenticationProvider();
+ OneStageAuthenticationState authState = new OneStageAuthenticationState(request, provider);
+ assertNotNull(authState.getAuthDataSource());
+ assertEquals(provider.getAuthCallCount(), 0);
+ }
+}