diff --git a/src/main/java/io/lettuce/core/BaseRedisAuthenticationHandler.java b/src/main/java/io/lettuce/core/BaseRedisAuthenticationHandler.java new file mode 100644 index 000000000..fa9f9eb84 --- /dev/null +++ b/src/main/java/io/lettuce/core/BaseRedisAuthenticationHandler.java @@ -0,0 +1,117 @@ +package io.lettuce.core; + +import io.lettuce.core.codec.StringCodec; +import io.lettuce.core.protocol.AsyncCommand; +import io.lettuce.core.protocol.RedisCommand; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import reactor.core.Disposable; +import reactor.core.publisher.Flux; + +import java.nio.CharBuffer; +import java.util.concurrent.atomic.AtomicReference; + +public abstract class BaseRedisAuthenticationHandler> { + + private static final InternalLogger log = InternalLoggerFactory.getInstance(BaseRedisAuthenticationHandler.class); + + protected final T connection; + + private final RedisCommandBuilder commandBuilder = new RedisCommandBuilder<>(StringCodec.UTF8); + + private final AtomicReference credentialsSubscription = new AtomicReference<>(); + + public BaseRedisAuthenticationHandler(T connection) { + this.connection = connection; + } + + /** + * Subscribes to the provided `Flux` of credentials if the given `RedisCredentialsProvider` supports streaming credentials. + *

+ * This method subscribes to a stream of credentials provided by the `StreamingCredentialsProvider`. Each time new + * credentials are received, the client is reauthenticated. If the connection is not supported, the method returns without + * subscribing. + *

+ * The previous subscription, if any, is disposed of before setting the new subscription. + * + * @param credentialsProvider the credentials provider to subscribe to + */ + public void subscribe(RedisCredentialsProvider credentialsProvider) { + if (credentialsProvider == null) { + return; + } + + if (credentialsProvider instanceof StreamingCredentialsProvider) { + if (!isSupportedConnection()) { + return; + } + + Flux credentialsFlux = ((StreamingCredentialsProvider) credentialsProvider).credentials(); + + Disposable subscription = credentialsFlux.subscribe(this::onNext, this::onError, this::complete); + + Disposable oldSubscription = credentialsSubscription.getAndSet(subscription); + if (oldSubscription != null && !oldSubscription.isDisposed()) { + oldSubscription.dispose(); + } + } + } + + /** + * Unsubscribes from the current credentials stream. + */ + public void unsubscribe() { + Disposable subscription = credentialsSubscription.getAndSet(null); + if (subscription != null && !subscription.isDisposed()) { + subscription.dispose(); + } + } + + protected void complete() { + log.debug("Credentials stream completed"); + } + + protected void onNext(RedisCredentials credentials) { + reauthenticate(credentials); + } + + protected void onError(Throwable e) { + log.error("Credentials renew failed.", e); + } + + /** + * Performs re-authentication with the provided credentials. + * + * @param credentials the new credentials + */ + private void reauthenticate(RedisCredentials credentials) { + CharSequence password = CharBuffer.wrap(credentials.getPassword()); + + AsyncCommand authCmd; + if (credentials.hasUsername()) { + authCmd = new AsyncCommand<>(commandBuilder.auth(credentials.getUsername(), password)); + } else { + authCmd = new AsyncCommand<>(commandBuilder.auth(password)); + } + + dispatchAuth(authCmd).exceptionally(throwable -> { + log.error("Re-authentication {} failed.", credentials.hasUsername() ? "with username" : "without username", + throwable); + return null; + }); + } + + protected boolean isSupportedConnection() { + return true; + } + + private AsyncCommand dispatchAuth(RedisCommand authCommand) { + AsyncCommand asyncCommand = new AsyncCommand<>(authCommand); + RedisCommand dispatched = connection.getChannelWriter().write(asyncCommand); + if (dispatched instanceof AsyncCommand) { + return (AsyncCommand) dispatched; + } + return asyncCommand; + } + +} diff --git a/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java b/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java new file mode 100644 index 000000000..cdf57e987 --- /dev/null +++ b/src/main/java/io/lettuce/core/RedisAuthenticationHandler.java @@ -0,0 +1,44 @@ +/* + * Copyright 2019-Present, Redis Ltd. and Contributors + * All rights reserved. + * + * Licensed under the MIT License. + * + * This file contains contributions from third-party contributors + * licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.lettuce.core; + +import io.lettuce.core.protocol.ProtocolVersion; +import io.lettuce.core.pubsub.StatefulRedisPubSubConnection; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +class RedisAuthenticationHandler extends BaseRedisAuthenticationHandler> { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(RedisAuthenticationHandler.class); + + public RedisAuthenticationHandler(StatefulRedisConnectionImpl connection) { + super(connection); + } + + protected boolean isSupportedConnection() { + if (connection instanceof StatefulRedisPubSubConnection + && ProtocolVersion.RESP2 == connection.getConnectionState().getNegotiatedProtocolVersion()) { + logger.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection."); + return false; + } + return true; + } + +} diff --git a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java index 14ba7b570..b385c6c36 100644 --- a/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java +++ b/src/main/java/io/lettuce/core/StatefulRedisConnectionImpl.java @@ -67,6 +67,8 @@ public class StatefulRedisConnectionImpl extends RedisChannelHandler private final PushHandler pushHandler; + private final RedisAuthenticationHandler authHandler; + private final Mono parser; protected MultiOutput multi; @@ -104,6 +106,8 @@ public StatefulRedisConnectionImpl(RedisChannelWriter writer, PushHandler pushHa this.async = newRedisAsyncCommandsImpl(); this.sync = newRedisSyncCommandsImpl(); this.reactive = newRedisReactiveCommandsImpl(); + + this.authHandler = new RedisAuthenticationHandler(this); } public RedisCodec getCodec() { @@ -315,4 +319,16 @@ public ConnectionState getConnectionState() { return state; } + @Override + public void activated() { + super.activated(); + authHandler.subscribe(state.getCredentialsProvider()); + } + + @Override + public void deactivated() { + authHandler.unsubscribe(); + super.deactivated(); + } + } diff --git a/src/main/java/io/lettuce/core/StreamingCredentialsProvider.java b/src/main/java/io/lettuce/core/StreamingCredentialsProvider.java new file mode 100644 index 000000000..08ab89850 --- /dev/null +++ b/src/main/java/io/lettuce/core/StreamingCredentialsProvider.java @@ -0,0 +1,15 @@ +package io.lettuce.core; + +import reactor.core.publisher.Flux; + +public interface StreamingCredentialsProvider extends RedisCredentialsProvider { + + /** + * Returns a {@link Flux} emitting {@link RedisCredentials} that can be used to authorize a Redis connection. This + * credential provider supports streaming credentials, meaning that it can emit multiple credentials over time. + * + * @return + */ + Flux credentials(); + +} diff --git a/src/main/java/io/lettuce/core/cluster/RedisClusterAuthenticationHandler.java b/src/main/java/io/lettuce/core/cluster/RedisClusterAuthenticationHandler.java new file mode 100644 index 000000000..1e7f52405 --- /dev/null +++ b/src/main/java/io/lettuce/core/cluster/RedisClusterAuthenticationHandler.java @@ -0,0 +1,45 @@ +/* + * Copyright 2019-Present, Redis Ltd. and Contributors + * All rights reserved. + * + * Licensed under the MIT License. + * + * This file contains contributions from third-party contributors + * licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.lettuce.core.cluster; + +import io.lettuce.core.BaseRedisAuthenticationHandler; +import io.lettuce.core.cluster.pubsub.StatefulRedisClusterPubSubConnection; +import io.lettuce.core.protocol.ProtocolVersion; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +class RedisClusterAuthenticationHandler extends BaseRedisAuthenticationHandler> { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(RedisClusterAuthenticationHandler.class); + + public RedisClusterAuthenticationHandler(StatefulRedisClusterConnectionImpl connection) { + super(connection); + } + + protected boolean isSupportedConnection() { + if (connection instanceof StatefulRedisClusterPubSubConnection + && ProtocolVersion.RESP2 == connection.getConnectionState().getNegotiatedProtocolVersion()) { + logger.warn("Renewable credentials are not supported with RESP2 protocol on a pub/sub connection."); + return false; + } + return true; + } + +} diff --git a/src/main/java/io/lettuce/core/cluster/StatefulRedisClusterConnectionImpl.java b/src/main/java/io/lettuce/core/cluster/StatefulRedisClusterConnectionImpl.java index c84193491..109c0da39 100644 --- a/src/main/java/io/lettuce/core/cluster/StatefulRedisClusterConnectionImpl.java +++ b/src/main/java/io/lettuce/core/cluster/StatefulRedisClusterConnectionImpl.java @@ -89,6 +89,8 @@ public class StatefulRedisClusterConnectionImpl extends RedisChannelHandle private volatile Partitions partitions; + private final RedisClusterAuthenticationHandler authHandler; + /** * Initialize a new connection. * @@ -123,6 +125,8 @@ public StatefulRedisClusterConnectionImpl(RedisChannelWriter writer, ClusterPush this.async = newRedisAdvancedClusterAsyncCommandsImpl(); this.sync = newRedisAdvancedClusterCommandsImpl(); this.reactive = newRedisAdvancedClusterReactiveCommandsImpl(); + + this.authHandler = new RedisClusterAuthenticationHandler(this); } protected RedisAdvancedClusterReactiveCommandsImpl newRedisAdvancedClusterReactiveCommandsImpl() { @@ -230,6 +234,12 @@ public void activated() { super.activated(); async.clusterMyId().thenAccept(connectionState::setNodeId); + authHandler.subscribe(connectionState.getCredentialsProvider()); + } + + @Override + public void deactivated() { + authHandler.unsubscribe(); } ClusterDistributionChannelWriter getClusterDistributionChannelWriter() { diff --git a/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java b/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java index 864a2103b..93d502747 100644 --- a/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java +++ b/src/test/java/io/lettuce/core/AuthenticationIntegrationTests.java @@ -5,6 +5,10 @@ import javax.inject.Inject; +import io.lettuce.core.event.command.CommandListener; +import io.lettuce.core.event.command.CommandSucceededEvent; +import io.lettuce.core.protocol.RedisCommand; +import org.awaitility.Awaitility; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Tag; import org.junit.jupiter.api.Test; @@ -19,12 +23,20 @@ import io.lettuce.test.WithPassword; import io.lettuce.test.condition.EnabledOnCommand; import io.lettuce.test.settings.TestSettings; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.publisher.Sinks; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; /** * Integration test for authentication. * * @author Mark Paluch + * @author Ivo Gaydajiev */ @Tag(INTEGRATION_TEST) @ExtendWith(LettuceExtension.class) @@ -71,4 +83,116 @@ void ownCredentialProvider(RedisClient client) { }); } + // Simulate test user credential rotation, and verify that re-authentication is successful + @Test + @Inject + void renewableCredentialProvider(RedisClient client) { + + // Thread-safe list to capture intercepted commands + List> interceptedCommands = Collections.synchronizedList(new ArrayList<>()); + + // CommandListener to track successful commands + CommandListener commandListener = new CommandListener() { + + @Override + public void commandSucceeded(CommandSucceededEvent event) { + interceptedCommands.add(event.getCommand()); + } + + }; + + // Add CommandListener to the client + client.addListener(commandListener); + + // Configure client options + client.setOptions( + ClientOptions.builder().disconnectedBehavior(ClientOptions.DisconnectedBehavior.REJECT_COMMANDS).build()); + + // Connection for managing test user credential rotation + StatefulRedisConnection adminConnection = client.connect(); + + String testUser = "streaming_cred_test_user"; + char[] initialPassword = "token_1".toCharArray(); + char[] updatedPassword = "token_2".toCharArray(); + + // Streaming credentials provider to simulate token emission + RenewableRedisCredentialsProvider credentialsProvider = new RenewableRedisCredentialsProvider(); + + // Build RedisURI with streaming credentials provider + RedisURI uri = RedisURI.builder().withHost(TestSettings.host()).withPort(TestSettings.port()) + .withClientName("streaming_cred_test").withAuthentication(credentialsProvider) + .withTimeout(Duration.ofSeconds(1)).build(); + + // Create test user and set initial credentials + createTestUser(adminConnection, testUser, initialPassword); + credentialsProvider.emitToken(new StaticRedisCredentials(testUser, initialPassword)); + + // Establish connection using the streaming credentials provider + StatefulRedisConnection userConnection = client.connect(StringCodec.UTF8, uri); + + // Verify initial authentication + assertThat(userConnection.sync().aclWhoami()).isEqualTo(testUser); + + // Update test user credentials and emit updated credentials + updateTestUser(adminConnection, testUser, updatedPassword); + credentialsProvider.emitToken(new StaticRedisCredentials(testUser, updatedPassword)); + + // Wait for the `AUTH` command with updated credentials + Awaitility.await().atMost(Duration.ofSeconds(1)).until(() -> interceptedCommands.stream() + .anyMatch(command -> isAuthCommandWithCredentials(command, testUser, updatedPassword))); + + // Verify re-authentication and connection functionality + assertThat(userConnection.sync().ping()).isEqualTo("PONG"); + assertThat(userConnection.sync().aclWhoami()).isEqualTo(testUser); + + // Clean up + adminConnection.close(); + userConnection.close(); + } + + private void createTestUser(StatefulRedisConnection connection, String username, char[] password) { + AclSetuserArgs args = AclSetuserArgs.Builder.on().allCommands().allChannels().allKeys().nopass() + .addPassword(String.valueOf(password)); + connection.sync().aclSetuser(username, args); + } + + private void updateTestUser(StatefulRedisConnection connection, String username, char[] newPassword) { + AclSetuserArgs args = AclSetuserArgs.Builder.on().allCommands().allChannels().allKeys().nopass() + .addPassword(String.valueOf(newPassword)); + connection.sync().aclSetuser(username, args); + } + + private boolean isAuthCommandWithCredentials(RedisCommand command, String username, char[] password) { + if (command.getType() == CommandType.AUTH) { + CommandArgs args = command.getArgs(); + return args.toCommandString().contains(username) && args.toCommandString().contains(String.valueOf(password)); + } + return false; + } + + static class RenewableRedisCredentialsProvider implements StreamingCredentialsProvider { + + private final Sinks.Many credentialsSink = Sinks.many().replay().latest(); + + @Override + public Mono resolveCredentials() { + + return credentialsSink.asFlux().next(); + } + + public Flux credentials() { + + return credentialsSink.asFlux().onBackpressureLatest(); // Provide a continuous stream of credentials + } + + public void shutdown() { + credentialsSink.tryEmitComplete(); + } + + public void emitToken(RedisCredentials credentials) { + credentialsSink.tryEmitNext(credentials); + } + + } + } diff --git a/src/test/java/io/lettuce/core/BaseRedisAuthenticationHandlerTest.java b/src/test/java/io/lettuce/core/BaseRedisAuthenticationHandlerTest.java new file mode 100644 index 000000000..ba5b087b1 --- /dev/null +++ b/src/test/java/io/lettuce/core/BaseRedisAuthenticationHandlerTest.java @@ -0,0 +1,111 @@ +package io.lettuce.core; + +import io.lettuce.core.protocol.CommandType; +import io.lettuce.core.protocol.RedisCommand; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Sinks; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +class BaseRedisAuthenticationHandlerTest { + + private BaseRedisAuthenticationHandler> handler; + + private RedisChannelHandler connection; + + private RedisChannelWriter channelWriter; + + private StreamingCredentialsProvider streamingCredentialsProvider; + + private Sinks.Many sink; + + @BeforeEach + void setUp() { + + connection = mock(RedisChannelHandler.class); + channelWriter = mock(RedisChannelWriter.class); + when(connection.getChannelWriter()).thenReturn(channelWriter); + streamingCredentialsProvider = mock(StreamingCredentialsProvider.class); + sink = Sinks.many().replay().latest(); + Flux credentialsFlux = sink.asFlux(); + when(streamingCredentialsProvider.credentials()).thenReturn(credentialsFlux); + handler = new BaseRedisAuthenticationHandler>(connection) { + + @Override + protected boolean isSupportedConnection() { + return true; + } + + }; + } + + @SuppressWarnings("unchecked") + @Test + void subscribeWithStreamingCredentialsProviderInvokesReauth() { + + // Subscribe to the provider + handler.subscribe(streamingCredentialsProvider); + sink.tryEmitNext(RedisCredentials.just("newuser", "newpassword")); + + // Ensure credentials() method was invoked + verify(streamingCredentialsProvider).credentials(); + + // Verify that write() is invoked once + verify(channelWriter, times(1)).write(any(RedisCommand.class)); + + ArgumentCaptor> captor = ArgumentCaptor.forClass(RedisCommand.class); + verify(channelWriter).write(captor.capture()); + + RedisCommand capturedCommand = captor.getValue(); + assertThat(capturedCommand.getType()).isEqualTo(CommandType.AUTH); + assertThat(capturedCommand.getArgs().toCommandString()).contains("newuser"); + assertThat(capturedCommand.getArgs().toCommandString()).contains("newpassword"); + } + + @Test + void shouldHandleErrorInCredentialsStream() { + Sinks.Many sink = Sinks.many().replay().latest(); + Flux credentialsFlux = sink.asFlux(); + StreamingCredentialsProvider credentialsProvider = mock(StreamingCredentialsProvider.class); + when(credentialsProvider.credentials()).thenReturn(credentialsFlux); + + // Subscribe to the provider and simulate an error + handler.subscribe(credentialsProvider); + sink.tryEmitError(new RuntimeException("Test error")); + + verify(connection.getChannelWriter(), times(0)).write(any(RedisCommand.class)); // No command should be sent + } + + @Test + void shouldNotSubscribeIfConnectionIsNotSupported() { + Sinks.Many sink = Sinks.many().replay().latest(); + Flux credentialsFlux = sink.asFlux(); + StreamingCredentialsProvider credentialsProvider = mock(StreamingCredentialsProvider.class); + when(credentialsProvider.credentials()).thenReturn(credentialsFlux); + + BaseRedisAuthenticationHandler handler = new BaseRedisAuthenticationHandler>(connection) { + + @Override + protected boolean isSupportedConnection() { + // Simulate : Pub/Sub connections are not supported with RESP2 + return false; + } + + }; + + // Subscribe to the provider (it should not subscribe due to unsupported connection) + handler.subscribe(credentialsProvider); + + // Ensure credentials() was not called + verify(credentialsProvider, times(0)).credentials(); + } + +} diff --git a/src/test/java/io/lettuce/core/RedisAuthenticationHandlerTest.java b/src/test/java/io/lettuce/core/RedisAuthenticationHandlerTest.java new file mode 100644 index 000000000..04ca9405f --- /dev/null +++ b/src/test/java/io/lettuce/core/RedisAuthenticationHandlerTest.java @@ -0,0 +1,52 @@ +package io.lettuce.core; + +import io.lettuce.core.protocol.ProtocolVersion; +import io.lettuce.core.pubsub.StatefulRedisPubSubConnection; +import io.lettuce.core.pubsub.StatefulRedisPubSubConnectionImpl; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +public class RedisAuthenticationHandlerTest { + + @Test + void testIsSupportedConnectionWithRESP2ProtocolOnPubSubConnection() { + StatefulRedisPubSubConnectionImpl connection = mock(StatefulRedisPubSubConnectionImpl.class, + withSettings().extraInterfaces(StatefulRedisPubSubConnection.class)); + + ConnectionState connectionState = mock(ConnectionState.class); + when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP2); + when(connection.getConnectionState()).thenReturn(connectionState); + RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection); + + assertFalse(handler.isSupportedConnection()); + } + + @Test + void testIsSupportedConnectionWithNonPubSubConnection() { + StatefulRedisConnectionImpl connection = mock(StatefulRedisConnectionImpl.class); + ConnectionState connectionState = mock(ConnectionState.class); + when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP2); + when(connection.getConnectionState()).thenReturn(connectionState); + RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection); + + assertTrue(handler.isSupportedConnection()); + } + + @Test + void testIsSupportedConnectionWithRESP3ProtocolOnPubSubConnection() { + + StatefulRedisPubSubConnectionImpl connection = mock(StatefulRedisPubSubConnectionImpl.class); + ConnectionState connectionState = mock(ConnectionState.class); + when(connectionState.getNegotiatedProtocolVersion()).thenReturn(ProtocolVersion.RESP3); + when(connection.getConnectionState()).thenReturn(connectionState); + RedisAuthenticationHandler handler = new RedisAuthenticationHandler(connection); + + assertTrue(handler.isSupportedConnection()); + } + +}