diff --git a/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationState.java b/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationState.java index 6e5b889ff3c9c..2f8c79a79191e 100644 --- a/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationState.java +++ b/pulsar-broker-common/src/main/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationState.java @@ -21,39 +21,60 @@ import static java.nio.charset.StandardCharsets.UTF_8; import java.net.SocketAddress; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; import javax.naming.AuthenticationException; import javax.net.ssl.SSLSession; import javax.servlet.http.HttpServletRequest; import org.apache.pulsar.common.api.AuthData; /** - * Interface for authentication state. - * - * It tell broker whether the authentication is completed or not, - * if completed, what is the AuthRole is. + * A class to track single stage authentication. This class assumes that: + * 1. {@link #authenticateAsync(AuthData)} is called once and when the {@link CompletableFuture} completes, + * authentication is complete. + * 2. Authentication does not expire, so {@link #isExpired()} always returns false. + *

+ * See {@link AuthenticationState} for Pulsar's contract on how this interface is used by Pulsar. */ public class OneStageAuthenticationState implements AuthenticationState { - private final AuthenticationDataSource authenticationDataSource; - private final String authRole; + private AuthenticationDataSource authenticationDataSource; + private final SocketAddress remoteAddress; + private final SSLSession sslSession; + private final AuthenticationProvider provider; + private volatile String authRole; + + /** + * Constructor for a {@link OneStageAuthenticationState} where there is no authentication performed during + * initialization. + * @param remoteAddress - remoteAddress associated with the {@link AuthenticationState} + * @param sslSession - sslSession associated with the {@link AuthenticationState} + * @param provider - {@link AuthenticationProvider} to use to verify {@link AuthData} + */ public OneStageAuthenticationState(AuthData authData, SocketAddress remoteAddress, SSLSession sslSession, - AuthenticationProvider provider) throws AuthenticationException { - this.authenticationDataSource = new AuthenticationDataCommand( - new String(authData.getBytes(), UTF_8), remoteAddress, sslSession); - this.authRole = provider.authenticate(authenticationDataSource); + AuthenticationProvider provider) { + this.provider = provider; + this.remoteAddress = remoteAddress; + this.sslSession = sslSession; } - public OneStageAuthenticationState(HttpServletRequest request, AuthenticationProvider provider) - throws AuthenticationException { + public OneStageAuthenticationState(HttpServletRequest request, AuthenticationProvider provider) { + // Must initialize this here for backwards compatibility with http authentication this.authenticationDataSource = new AuthenticationDataHttps(request); - this.authRole = provider.authenticate(authenticationDataSource); + this.provider = provider; + // These are not used when invoking this constructor. + this.remoteAddress = null; + this.sslSession = null; } @Override - public String getAuthRole() { + public String getAuthRole() throws AuthenticationException { + if (authRole == null) { + throw new AuthenticationException("Must authenticate before calling getAuthRole"); + } return authRole; } @@ -62,13 +83,47 @@ public AuthenticationDataSource getAuthDataSource() { return authenticationDataSource; } + /** + * Warning: this method is not intended to be called concurrently. + */ + @Override + public CompletableFuture authenticateAsync(AuthData authData) { + if (authRole != null) { + // Authentication is already completed + return CompletableFuture.completedFuture(null); + } + this.authenticationDataSource = new AuthenticationDataCommand( + new String(authData.getBytes(), UTF_8), remoteAddress, sslSession); + + return provider + .authenticateAsync(authenticationDataSource) + .thenApply(role -> { + this.authRole = role; + // Single stage authentication always returns null + return null; + }); + } + + /** + * @deprecated use {@link #authenticateAsync(AuthData)} + */ + @Deprecated @Override - public AuthData authenticate(AuthData authData) { - return null; + public AuthData authenticate(AuthData authData) throws AuthenticationException { + try { + return authenticateAsync(authData).get(); + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } } + /** + * @deprecated rely on result from {@link #authenticateAsync(AuthData)}. For more information, see the Javadoc + * for {@link AuthenticationState#isComplete()}. + */ + @Deprecated @Override public boolean isComplete() { - return true; + return authRole != null; } } diff --git a/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationStateTest.java b/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationStateTest.java new file mode 100644 index 0000000000000..7ec1222c651c3 --- /dev/null +++ b/pulsar-broker-common/src/test/java/org/apache/pulsar/broker/authentication/OneStageAuthenticationStateTest.java @@ -0,0 +1,135 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.pulsar.broker.authentication; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertNull; +import static org.testng.Assert.assertSame; +import static org.testng.Assert.assertThrows; +import static org.testng.Assert.assertTrue; +import org.apache.pulsar.broker.ServiceConfiguration; +import org.apache.pulsar.common.api.AuthData; +import org.testng.annotations.Test; +import javax.naming.AuthenticationException; +import javax.servlet.http.HttpServletRequest; +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.LongAdder; + +public class OneStageAuthenticationStateTest { + + public static class CountingAuthenticationProvider implements AuthenticationProvider { + public LongAdder authCallCount = new LongAdder(); + + @Override + public void initialize(ServiceConfiguration config) throws IOException { + } + + @Override + public String getAuthMethodName() { + return null; + } + + @Override + public void close() throws IOException { + } + + @Override + public CompletableFuture authenticateAsync(AuthenticationDataSource authData) { + authCallCount.increment(); + return CompletableFuture.completedFuture(authData.getCommandData()); + } + + public int getAuthCallCount() { + return authCallCount.intValue(); + } + } + + @Test + public void verifyAuthenticateAsyncIsCalledExactlyOnceAndSetsRole() throws Exception { + CountingAuthenticationProvider provider = new CountingAuthenticationProvider(); + AuthData authData = AuthData.of("role".getBytes()); + OneStageAuthenticationState authState = new OneStageAuthenticationState(authData, null, null, provider); + assertEquals(provider.getAuthCallCount(), 0, "Auth count should not increase yet"); + AuthData challenge = authState.authenticateAsync(authData).get(); + assertNull(challenge); + assertEquals(provider.getAuthCallCount(), 1, "Call authenticate only once"); + assertEquals(authState.getAuthRole(), "role"); + AuthenticationDataSource firstAuthenticationDataSource = authState.getAuthDataSource(); + assertTrue(firstAuthenticationDataSource instanceof AuthenticationDataCommand); + + // Verify subsequent call to authenticate does not change data + AuthData secondChallenge = authState.authenticateAsync(AuthData.of("admin".getBytes())).get(); + assertNull(secondChallenge); + assertEquals(authState.getAuthRole(), "role"); + AuthenticationDataSource secondAuthenticationDataSource = authState.getAuthDataSource(); + assertSame(secondAuthenticationDataSource, firstAuthenticationDataSource); + assertEquals(provider.getAuthCallCount(), 1, "Call authenticate only once, even later."); + } + + @SuppressWarnings("deprecation") + @Test + public void verifyAuthenticateIsCalledExactlyOnceAndSetsRole() throws Exception { + CountingAuthenticationProvider provider = new CountingAuthenticationProvider(); + AuthData authData = AuthData.of("role".getBytes()); + OneStageAuthenticationState authState = new OneStageAuthenticationState(authData, null, null, provider); + assertEquals(provider.getAuthCallCount(), 0, "Auth count should not increase yet"); + assertFalse(authState.isComplete()); + AuthData challenge = authState.authenticate(authData); + assertNull(challenge); + assertTrue(authState.isComplete()); + assertEquals(provider.getAuthCallCount(), 1, "Call authenticate only once"); + assertEquals(authState.getAuthRole(), "role"); + AuthenticationDataSource firstAuthenticationDataSource = authState.getAuthDataSource(); + assertTrue(firstAuthenticationDataSource instanceof AuthenticationDataCommand); + + // Verify subsequent call to authenticate does not change data + AuthData secondChallenge = authState.authenticate(AuthData.of("admin".getBytes())); + assertNull(secondChallenge); + assertEquals(authState.getAuthRole(), "role"); + AuthenticationDataSource secondAuthenticationDataSource = authState.getAuthDataSource(); + assertSame(secondAuthenticationDataSource, firstAuthenticationDataSource); + assertEquals(provider.getAuthCallCount(), 1, "Call authenticate only once, even later."); + } + + @Test + public void verifyGetAuthRoleBeforeAuthenticateFails() { + CountingAuthenticationProvider provider = new CountingAuthenticationProvider(); + AuthData authData = AuthData.of("role".getBytes()); + OneStageAuthenticationState authState = new OneStageAuthenticationState(authData, null, null, provider); + assertThrows(AuthenticationException.class, authState::getAuthRole); + assertNull(authState.getAuthDataSource()); + } + + @Test + public void verifyHttpAuthConstructorInitializesAuthDataSourceAndDoesNotAuthenticateData() { + HttpServletRequest request = mock(HttpServletRequest.class); + when(request.getRemoteAddr()).thenReturn("localhost"); + when(request.getRemotePort()).thenReturn(8080); + CountingAuthenticationProvider provider = new CountingAuthenticationProvider(); + OneStageAuthenticationState authState = new OneStageAuthenticationState(request, provider); + assertNotNull(authState.getAuthDataSource()); + assertEquals(provider.getAuthCallCount(), 0); + } +}