From 26e10536773498cfd0a4514256456795658dd6d8 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Tue, 14 Feb 2023 03:09:55 -0600 Subject: [PATCH] [fix][broker] Make authentication refresh threadsafe (#19506) Co-authored-by: Lari Hotari (cherry picked from commit 153e4d4cc3b56aaee224b0a68e0186c08125c975) (cherry picked from commit 161ec5aa20c4e0d9f82473e43e5ccdc7a113f236) --- .../service/PulsarChannelInitializer.java | 29 ----- .../pulsar/broker/service/ServerCnx.java | 109 +++++++++++------- .../pulsar/broker/service/ServerCnxTest.java | 14 ++- 3 files changed, 76 insertions(+), 76 deletions(-) diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/PulsarChannelInitializer.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/PulsarChannelInitializer.java index e75c518a50f02..e1057de54ccda 100644 --- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/PulsarChannelInitializer.java +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/PulsarChannelInitializer.java @@ -18,9 +18,6 @@ */ package org.apache.pulsar.broker.service; -import static org.apache.bookkeeper.util.SafeRunnable.safeRun; -import com.github.benmanes.caffeine.cache.Cache; -import com.github.benmanes.caffeine.cache.Caffeine; import com.google.common.annotations.VisibleForTesting; import io.netty.channel.ChannelInitializer; import io.netty.channel.socket.SocketChannel; @@ -29,8 +26,6 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslProvider; -import java.net.SocketAddress; -import java.util.concurrent.TimeUnit; import lombok.Builder; import lombok.Data; import lombok.extern.slf4j.Slf4j; @@ -56,14 +51,6 @@ public class PulsarChannelInitializer extends ChannelInitializer private final ServiceConfiguration brokerConf; private NettySSLContextAutoRefreshBuilder nettySSLContextAutoRefreshBuilder; - // This cache is used to maintain a list of active connections to iterate over them - // We keep weak references to have the cache to be auto cleaned up when the connections - // objects are GCed. - private final Cache connections = Caffeine.newBuilder() - .weakKeys() - .weakValues() - .build(); - /** * @param pulsar * An instance of {@link PulsarService} @@ -112,10 +99,6 @@ public PulsarChannelInitializer(PulsarService pulsar, PulsarChannelOptions opts) this.sslCtxRefresher = null; } this.brokerConf = pulsar.getConfiguration(); - - pulsar.getExecutor().scheduleAtFixedRate(safeRun(this::refreshAuthenticationCredentials), - pulsar.getConfig().getAuthenticationRefreshCheckSeconds(), - pulsar.getConfig().getAuthenticationRefreshCheckSeconds(), TimeUnit.SECONDS); } @Override @@ -145,18 +128,6 @@ protected void initChannel(SocketChannel ch) throws Exception { ch.pipeline().addLast("flowController", new FlowControlHandler()); ServerCnx cnx = newServerCnx(pulsar, listenerName); ch.pipeline().addLast("handler", cnx); - - connections.put(ch.remoteAddress(), cnx); - } - - private void refreshAuthenticationCredentials() { - connections.asMap().values().forEach(cnx -> { - try { - cnx.refreshAuthenticationCredentials(); - } catch (Throwable t) { - log.warn("[{}] Failed to refresh auth credentials", cnx.clientAddress()); - } - }); } @VisibleForTesting diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java index 851058dc811b4..954ab1d182f47 100644 --- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java +++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/ServerCnx.java @@ -37,6 +37,7 @@ import io.netty.handler.ssl.SslHandler; import io.netty.util.concurrent.FastThreadLocal; import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.ScheduledFuture; import io.prometheus.client.Gauge; import java.io.IOException; import java.net.InetSocketAddress; @@ -177,6 +178,7 @@ public class ServerCnx extends PulsarHandler implements TransportCnx { private AuthenticationState originalAuthState; private volatile AuthenticationDataSource originalAuthData; private boolean pendingAuthChallengeResponse = false; + private ScheduledFuture authRefreshTask; // Max number of pending requests per connections. If multiple producers are sharing the same connection the flow // control done by a single producer might not be enough to prevent write spikes on the broker. @@ -306,6 +308,9 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception { } cnxsPerThread.get().remove(this); + if (authRefreshTask != null) { + authRefreshTask.cancel(false); + } // Connection is gone, close the producers immediately producers.forEach((__, producerFuture) -> { @@ -656,15 +661,19 @@ private void doAuthentication(AuthData clientData, if (state != State.Connected) { // First time authentication is done - if (service.isAuthenticationEnabled() && service.isAuthorizationEnabled()) { - if (!service.getAuthorizationService() - .isValidOriginalPrincipal(this.authRole, originalPrincipal, remoteAddress)) { - state = State.Failed; - service.getPulsarStats().recordConnectionCreateFail(); - final ByteBuf msg = Commands.newError(-1, ServerError.AuthorizationError, "Invalid roles."); - ctx.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE); - return; + if (service.isAuthenticationEnabled()) { + if (service.isAuthorizationEnabled()) { + if (!service.getAuthorizationService() + .isValidOriginalPrincipal(this.authRole, originalPrincipal, remoteAddress)) { + state = State.Failed; + service.getPulsarStats().recordConnectionCreateFail(); + final ByteBuf msg = Commands.newError(-1, ServerError.AuthorizationError, + "Invalid roles."); + ctx.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE); + return; + } } + maybeScheduleAuthenticationCredentialsRefresh(); } completeConnect(clientProtocolVersion, clientVersion); } else { @@ -691,61 +700,75 @@ private void doAuthentication(AuthData clientData, } } - public void refreshAuthenticationCredentials() { - AuthenticationState authState = this.originalAuthState != null ? originalAuthState : this.authState; - + /** + * Method to initialize the {@link #authRefreshTask} task. + */ + private void maybeScheduleAuthenticationCredentialsRefresh() { + assert ctx.executor().inEventLoop(); + assert authRefreshTask == null; if (authState == null) { // Authentication is disabled or there's no local state to refresh return; - } else if (getState() != State.Connected || !isActive) { - // Connection is either still being established or already closed. + } + authRefreshTask = ctx.executor().scheduleAtFixedRate(this::refreshAuthenticationCredentials, + service.getPulsar().getConfig().getAuthenticationRefreshCheckSeconds(), + service.getPulsar().getConfig().getAuthenticationRefreshCheckSeconds(), + TimeUnit.SECONDS); + } + + private void refreshAuthenticationCredentials() { + assert ctx.executor().inEventLoop(); + AuthenticationState authState = this.originalAuthState != null ? originalAuthState : this.authState; + if (getState() == State.Failed) { + // Happens when an exception is thrown that causes this connection to close. return; } else if (!authState.isExpired()) { // Credentials are still valid. Nothing to do at this point return; } else if (originalPrincipal != null && originalAuthState == null) { + // This case is only checked when the authState is expired because we've reached a point where + // authentication needs to be refreshed, but the protocol does not support it unless the proxy forwards + // the originalAuthData. log.info( "[{}] Cannot revalidate user credential when using proxy and" + " not forwarding the credentials. Closing connection", remoteAddress); + ctx.close(); return; } - ctx.executor().execute(SafeRun.safeRun(() -> { - log.info("[{}] Refreshing authentication credentials for originalPrincipal {} and authRole {}", - remoteAddress, originalPrincipal, this.authRole); - - if (!supportsAuthenticationRefresh()) { - log.warn("[{}] Closing connection because client doesn't support auth credentials refresh", - remoteAddress); - ctx.close(); - return; - } + if (!supportsAuthenticationRefresh()) { + log.warn("[{}] Closing connection because client doesn't support auth credentials refresh", + remoteAddress); + ctx.close(); + return; + } - if (pendingAuthChallengeResponse) { - log.warn("[{}] Closing connection after timeout on refreshing auth credentials", - remoteAddress); - ctx.close(); - return; - } + if (pendingAuthChallengeResponse) { + log.warn("[{}] Closing connection after timeout on refreshing auth credentials", + remoteAddress); + ctx.close(); + return; + } - try { - AuthData brokerData = authState.refreshAuthentication(); + log.info("[{}] Refreshing authentication credentials for originalPrincipal {} and authRole {}", + remoteAddress, originalPrincipal, this.authRole); + try { + AuthData brokerData = authState.refreshAuthentication(); - ctx.writeAndFlush(Commands.newAuthChallenge(authMethod, brokerData, - getRemoteEndpointProtocolVersion())); - if (log.isDebugEnabled()) { - log.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.", - remoteAddress, authMethod); - } + ctx.writeAndFlush(Commands.newAuthChallenge(authMethod, brokerData, + getRemoteEndpointProtocolVersion())); + if (log.isDebugEnabled()) { + log.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.", + remoteAddress, authMethod); + } - pendingAuthChallengeResponse = true; + pendingAuthChallengeResponse = true; - } catch (AuthenticationException e) { - log.warn("[{}] Failed to refresh authentication: {}", remoteAddress, e); - ctx.close(); - } - })); + } catch (AuthenticationException e) { + log.warn("[{}] Failed to refresh authentication: {}", remoteAddress, e); + ctx.close(); + } } private static final byte[] emptyArray = new byte[0]; diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ServerCnxTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ServerCnxTest.java index a994a3adbadec..c39e1f5b7e4b2 100644 --- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ServerCnxTest.java +++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/ServerCnxTest.java @@ -496,10 +496,13 @@ public void testAuthChallengePrincipalChangeFails() throws Exception { when(brokerService.getAuthenticationService()).thenReturn(authenticationService); when(authenticationService.getAuthenticationProvider(authMethodName)).thenReturn(authenticationProvider); svcConfig.setAuthenticationEnabled(true); + svcConfig.setAuthenticationRefreshCheckSeconds(30); resetChannel(); assertTrue(channel.isActive()); assertEquals(serverCnx.getState(), State.Start); + // Don't want the keep alive task affecting which messages are handled + serverCnx.cancelKeepAliveTask(); ByteBuf clientCommand = Commands.newConnect(authMethodName, "pass.client", ""); channel.writeInbound(clientCommand); @@ -512,7 +515,7 @@ public void testAuthChallengePrincipalChangeFails() throws Exception { // Trigger the ServerCnx to check if authentication is expired (it is because of our special implementation) // and then force channel to run the task - serverCnx.refreshAuthenticationCredentials(); + channel.advanceTimeBy(30, TimeUnit.SECONDS); channel.runPendingTasks(); Object responseAuthChallenge1 = getResponse(); assertTrue(responseAuthChallenge1 instanceof CommandAuthChallenge); @@ -522,7 +525,7 @@ public void testAuthChallengePrincipalChangeFails() throws Exception { channel.writeInbound(authResponse1); // Trigger the ServerCnx to check if authentication is expired again - serverCnx.refreshAuthenticationCredentials(); + channel.advanceTimeBy(30, TimeUnit.SECONDS); assertTrue(channel.hasPendingTasks(), "This test assumes there are pending tasks to run."); channel.runPendingTasks(); Object responseAuthChallenge2 = getResponse(); @@ -548,10 +551,13 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception { svcConfig.setAuthenticationEnabled(true); svcConfig.setAuthenticateOriginalAuthData(true); svcConfig.setProxyRoles(Collections.singleton("pass.proxy")); + svcConfig.setAuthenticationRefreshCheckSeconds(30); resetChannel(); assertTrue(channel.isActive()); assertEquals(serverCnx.getState(), State.Start); + // Don't want the keep alive task affecting which messages are handled + serverCnx.cancelKeepAliveTask(); ByteBuf clientCommand = Commands.newConnect(authMethodName, "pass.proxy", 1, null, null, "pass.client", "pass.client", authMethodName); @@ -568,7 +574,7 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception { // Trigger the ServerCnx to check if authentication is expired (it is because of our special implementation) // and then force channel to run the task - serverCnx.refreshAuthenticationCredentials(); + channel.advanceTimeBy(30, TimeUnit.SECONDS); assertTrue(channel.hasPendingTasks(), "This test assumes there are pending tasks to run."); channel.runPendingTasks(); Object responseAuthChallenge1 = getResponse(); @@ -579,7 +585,7 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception { channel.writeInbound(authResponse1); // Trigger the ServerCnx to check if authentication is expired again - serverCnx.refreshAuthenticationCredentials(); + channel.advanceTimeBy(30, TimeUnit.SECONDS); channel.runPendingTasks(); Object responseAuthChallenge2 = getResponse(); assertTrue(responseAuthChallenge2 instanceof CommandAuthChallenge);