Skip to content

Commit

Permalink
[fix][broker] Make authentication refresh threadsafe (apache#19506)
Browse files Browse the repository at this point in the history
Co-authored-by: Lari Hotari <lhotari@users.noreply.github.com>
(cherry picked from commit 153e4d4)
(cherry picked from commit 161ec5a)
  • Loading branch information
michaeljmarshall committed Feb 22, 2023
1 parent 14152fc commit 26e1053
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 76 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
*/
package org.apache.pulsar.broker.service;

import static org.apache.bookkeeper.util.SafeRunnable.safeRun;
import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine;
import com.google.common.annotations.VisibleForTesting;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
Expand All @@ -29,8 +26,6 @@
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslProvider;
import java.net.SocketAddress;
import java.util.concurrent.TimeUnit;
import lombok.Builder;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
Expand All @@ -56,14 +51,6 @@ public class PulsarChannelInitializer extends ChannelInitializer<SocketChannel>
private final ServiceConfiguration brokerConf;
private NettySSLContextAutoRefreshBuilder nettySSLContextAutoRefreshBuilder;

// This cache is used to maintain a list of active connections to iterate over them
// We keep weak references to have the cache to be auto cleaned up when the connections
// objects are GCed.
private final Cache<SocketAddress, ServerCnx> connections = Caffeine.newBuilder()
.weakKeys()
.weakValues()
.build();

/**
* @param pulsar
* An instance of {@link PulsarService}
Expand Down Expand Up @@ -112,10 +99,6 @@ public PulsarChannelInitializer(PulsarService pulsar, PulsarChannelOptions opts)
this.sslCtxRefresher = null;
}
this.brokerConf = pulsar.getConfiguration();

pulsar.getExecutor().scheduleAtFixedRate(safeRun(this::refreshAuthenticationCredentials),
pulsar.getConfig().getAuthenticationRefreshCheckSeconds(),
pulsar.getConfig().getAuthenticationRefreshCheckSeconds(), TimeUnit.SECONDS);
}

@Override
Expand Down Expand Up @@ -145,18 +128,6 @@ protected void initChannel(SocketChannel ch) throws Exception {
ch.pipeline().addLast("flowController", new FlowControlHandler());
ServerCnx cnx = newServerCnx(pulsar, listenerName);
ch.pipeline().addLast("handler", cnx);

connections.put(ch.remoteAddress(), cnx);
}

private void refreshAuthenticationCredentials() {
connections.asMap().values().forEach(cnx -> {
try {
cnx.refreshAuthenticationCredentials();
} catch (Throwable t) {
log.warn("[{}] Failed to refresh auth credentials", cnx.clientAddress());
}
});
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.FastThreadLocal;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ScheduledFuture;
import io.prometheus.client.Gauge;
import java.io.IOException;
import java.net.InetSocketAddress;
Expand Down Expand Up @@ -177,6 +178,7 @@ public class ServerCnx extends PulsarHandler implements TransportCnx {
private AuthenticationState originalAuthState;
private volatile AuthenticationDataSource originalAuthData;
private boolean pendingAuthChallengeResponse = false;
private ScheduledFuture<?> authRefreshTask;

// Max number of pending requests per connections. If multiple producers are sharing the same connection the flow
// control done by a single producer might not be enough to prevent write spikes on the broker.
Expand Down Expand Up @@ -306,6 +308,9 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
}

cnxsPerThread.get().remove(this);
if (authRefreshTask != null) {
authRefreshTask.cancel(false);
}

// Connection is gone, close the producers immediately
producers.forEach((__, producerFuture) -> {
Expand Down Expand Up @@ -656,15 +661,19 @@ private void doAuthentication(AuthData clientData,

if (state != State.Connected) {
// First time authentication is done
if (service.isAuthenticationEnabled() && service.isAuthorizationEnabled()) {
if (!service.getAuthorizationService()
.isValidOriginalPrincipal(this.authRole, originalPrincipal, remoteAddress)) {
state = State.Failed;
service.getPulsarStats().recordConnectionCreateFail();
final ByteBuf msg = Commands.newError(-1, ServerError.AuthorizationError, "Invalid roles.");
ctx.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE);
return;
if (service.isAuthenticationEnabled()) {
if (service.isAuthorizationEnabled()) {
if (!service.getAuthorizationService()
.isValidOriginalPrincipal(this.authRole, originalPrincipal, remoteAddress)) {
state = State.Failed;
service.getPulsarStats().recordConnectionCreateFail();
final ByteBuf msg = Commands.newError(-1, ServerError.AuthorizationError,
"Invalid roles.");
ctx.writeAndFlush(msg).addListener(ChannelFutureListener.CLOSE);
return;
}
}
maybeScheduleAuthenticationCredentialsRefresh();
}
completeConnect(clientProtocolVersion, clientVersion);
} else {
Expand All @@ -691,61 +700,75 @@ private void doAuthentication(AuthData clientData,
}
}

public void refreshAuthenticationCredentials() {
AuthenticationState authState = this.originalAuthState != null ? originalAuthState : this.authState;

/**
* Method to initialize the {@link #authRefreshTask} task.
*/
private void maybeScheduleAuthenticationCredentialsRefresh() {
assert ctx.executor().inEventLoop();
assert authRefreshTask == null;
if (authState == null) {
// Authentication is disabled or there's no local state to refresh
return;
} else if (getState() != State.Connected || !isActive) {
// Connection is either still being established or already closed.
}
authRefreshTask = ctx.executor().scheduleAtFixedRate(this::refreshAuthenticationCredentials,
service.getPulsar().getConfig().getAuthenticationRefreshCheckSeconds(),
service.getPulsar().getConfig().getAuthenticationRefreshCheckSeconds(),
TimeUnit.SECONDS);
}

private void refreshAuthenticationCredentials() {
assert ctx.executor().inEventLoop();
AuthenticationState authState = this.originalAuthState != null ? originalAuthState : this.authState;
if (getState() == State.Failed) {
// Happens when an exception is thrown that causes this connection to close.
return;
} else if (!authState.isExpired()) {
// Credentials are still valid. Nothing to do at this point
return;
} else if (originalPrincipal != null && originalAuthState == null) {
// This case is only checked when the authState is expired because we've reached a point where
// authentication needs to be refreshed, but the protocol does not support it unless the proxy forwards
// the originalAuthData.
log.info(
"[{}] Cannot revalidate user credential when using proxy and"
+ " not forwarding the credentials. Closing connection",
remoteAddress);
ctx.close();
return;
}

ctx.executor().execute(SafeRun.safeRun(() -> {
log.info("[{}] Refreshing authentication credentials for originalPrincipal {} and authRole {}",
remoteAddress, originalPrincipal, this.authRole);

if (!supportsAuthenticationRefresh()) {
log.warn("[{}] Closing connection because client doesn't support auth credentials refresh",
remoteAddress);
ctx.close();
return;
}
if (!supportsAuthenticationRefresh()) {
log.warn("[{}] Closing connection because client doesn't support auth credentials refresh",
remoteAddress);
ctx.close();
return;
}

if (pendingAuthChallengeResponse) {
log.warn("[{}] Closing connection after timeout on refreshing auth credentials",
remoteAddress);
ctx.close();
return;
}
if (pendingAuthChallengeResponse) {
log.warn("[{}] Closing connection after timeout on refreshing auth credentials",
remoteAddress);
ctx.close();
return;
}

try {
AuthData brokerData = authState.refreshAuthentication();
log.info("[{}] Refreshing authentication credentials for originalPrincipal {} and authRole {}",
remoteAddress, originalPrincipal, this.authRole);
try {
AuthData brokerData = authState.refreshAuthentication();

ctx.writeAndFlush(Commands.newAuthChallenge(authMethod, brokerData,
getRemoteEndpointProtocolVersion()));
if (log.isDebugEnabled()) {
log.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.",
remoteAddress, authMethod);
}
ctx.writeAndFlush(Commands.newAuthChallenge(authMethod, brokerData,
getRemoteEndpointProtocolVersion()));
if (log.isDebugEnabled()) {
log.debug("[{}] Sent auth challenge to client to refresh credentials with method: {}.",
remoteAddress, authMethod);
}

pendingAuthChallengeResponse = true;
pendingAuthChallengeResponse = true;

} catch (AuthenticationException e) {
log.warn("[{}] Failed to refresh authentication: {}", remoteAddress, e);
ctx.close();
}
}));
} catch (AuthenticationException e) {
log.warn("[{}] Failed to refresh authentication: {}", remoteAddress, e);
ctx.close();
}
}

private static final byte[] emptyArray = new byte[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -496,10 +496,13 @@ public void testAuthChallengePrincipalChangeFails() throws Exception {
when(brokerService.getAuthenticationService()).thenReturn(authenticationService);
when(authenticationService.getAuthenticationProvider(authMethodName)).thenReturn(authenticationProvider);
svcConfig.setAuthenticationEnabled(true);
svcConfig.setAuthenticationRefreshCheckSeconds(30);

resetChannel();
assertTrue(channel.isActive());
assertEquals(serverCnx.getState(), State.Start);
// Don't want the keep alive task affecting which messages are handled
serverCnx.cancelKeepAliveTask();

ByteBuf clientCommand = Commands.newConnect(authMethodName, "pass.client", "");
channel.writeInbound(clientCommand);
Expand All @@ -512,7 +515,7 @@ public void testAuthChallengePrincipalChangeFails() throws Exception {

// Trigger the ServerCnx to check if authentication is expired (it is because of our special implementation)
// and then force channel to run the task
serverCnx.refreshAuthenticationCredentials();
channel.advanceTimeBy(30, TimeUnit.SECONDS);
channel.runPendingTasks();
Object responseAuthChallenge1 = getResponse();
assertTrue(responseAuthChallenge1 instanceof CommandAuthChallenge);
Expand All @@ -522,7 +525,7 @@ public void testAuthChallengePrincipalChangeFails() throws Exception {
channel.writeInbound(authResponse1);

// Trigger the ServerCnx to check if authentication is expired again
serverCnx.refreshAuthenticationCredentials();
channel.advanceTimeBy(30, TimeUnit.SECONDS);
assertTrue(channel.hasPendingTasks(), "This test assumes there are pending tasks to run.");
channel.runPendingTasks();
Object responseAuthChallenge2 = getResponse();
Expand All @@ -548,10 +551,13 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception {
svcConfig.setAuthenticationEnabled(true);
svcConfig.setAuthenticateOriginalAuthData(true);
svcConfig.setProxyRoles(Collections.singleton("pass.proxy"));
svcConfig.setAuthenticationRefreshCheckSeconds(30);

resetChannel();
assertTrue(channel.isActive());
assertEquals(serverCnx.getState(), State.Start);
// Don't want the keep alive task affecting which messages are handled
serverCnx.cancelKeepAliveTask();

ByteBuf clientCommand = Commands.newConnect(authMethodName, "pass.proxy", 1, null,
null, "pass.client", "pass.client", authMethodName);
Expand All @@ -568,7 +574,7 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception {

// Trigger the ServerCnx to check if authentication is expired (it is because of our special implementation)
// and then force channel to run the task
serverCnx.refreshAuthenticationCredentials();
channel.advanceTimeBy(30, TimeUnit.SECONDS);
assertTrue(channel.hasPendingTasks(), "This test assumes there are pending tasks to run.");
channel.runPendingTasks();
Object responseAuthChallenge1 = getResponse();
Expand All @@ -579,7 +585,7 @@ public void testAuthChallengeOriginalPrincipalChangeFails() throws Exception {
channel.writeInbound(authResponse1);

// Trigger the ServerCnx to check if authentication is expired again
serverCnx.refreshAuthenticationCredentials();
channel.advanceTimeBy(30, TimeUnit.SECONDS);
channel.runPendingTasks();
Object responseAuthChallenge2 = getResponse();
assertTrue(responseAuthChallenge2 instanceof CommandAuthChallenge);
Expand Down

0 comments on commit 26e1053

Please sign in to comment.