Skip to content

Commit

Permalink
Fixed ssl error handlers
Browse files Browse the repository at this point in the history
  • Loading branch information
chkead committed Jan 30, 2024
1 parent a98e309 commit 361b81c
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.hivemq.configuration.service.entity.Tls;
import com.hivemq.configuration.service.entity.TlsListener;
import com.hivemq.extension.sdk.api.annotations.NotNull;
import com.hivemq.mqtt.handler.connect.NoTlsHandshakeIdleHandler;
import com.hivemq.mqtt.handler.disconnect.MqttServerDisconnector;
import com.hivemq.security.exception.SslException;
import com.hivemq.security.ssl.SslClientCertificateHandler;
Expand All @@ -28,13 +27,6 @@
import com.hivemq.security.ssl.SslSniHandler;
import io.netty.channel.Channel;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleStateHandler;

import java.util.concurrent.TimeUnit;

import static com.hivemq.bootstrap.netty.ChannelHandlerNames.NEW_CONNECTION_IDLE_HANDLER;
import static com.hivemq.bootstrap.netty.ChannelHandlerNames.NO_TLS_HANDSHAKE_IDLE_EVENT_HANDLER;
import static com.hivemq.bootstrap.netty.ChannelHandlerNames.SSL_CLIENT_CERTIFICATE_HANDLER;
import static com.hivemq.bootstrap.netty.ChannelHandlerNames.SSL_EXCEPTION_HANDLER;
import static com.hivemq.bootstrap.netty.ChannelHandlerNames.SSL_HANDLER;
Expand Down Expand Up @@ -71,29 +63,12 @@ protected void addNoConnectIdleHandlerAfterTlsHandshake(@NotNull final Channel c

@Override
protected void addSpecialHandlers(@NotNull final Channel ch) throws SslException {
final int handshakeTimeout = tlsListener.getTls().getHandshakeTimeout();

final IdleStateHandler idleStateHandler = new IdleStateHandler(handshakeTimeout, 0, 0, TimeUnit.MILLISECONDS);
final MqttServerDisconnector mqttServerDisconnector = channelDependencies.getMqttServerDisconnector();
final NoTlsHandshakeIdleHandler noTlsHandshakeIdleHandler =
new NoTlsHandshakeIdleHandler(mqttServerDisconnector);
if (handshakeTimeout > 0) {
ch.pipeline().addLast(NEW_CONNECTION_IDLE_HANDLER, idleStateHandler);
ch.pipeline().addLast(NO_TLS_HANDSHAKE_IDLE_EVENT_HANDLER, noTlsHandshakeIdleHandler);
}

final Tls tls = tlsListener.getTls();
final SslContext sslContext = sslFactory.getSslContext(tls);
final SslHandler sslHandler = sslFactory.getSslHandler(ch, tls, sslContext);
sslHandler.handshakeFuture().addListener(future -> {
if (handshakeTimeout > 0) {
ch.pipeline().remove(idleStateHandler);
ch.pipeline().remove(noTlsHandshakeIdleHandler);
}
addNoConnectIdleHandlerAfterTlsHandshake(ch);
});
final MqttServerDisconnector mqttServerDisconnector = channelDependencies.getMqttServerDisconnector();

ch.pipeline().addFirst(SSL_HANDLER, new SslSniHandler(sslHandler, sslContext));
ch.pipeline().addFirst(SSL_HANDLER, new SslSniHandler(sslContext, sslFactory, mqttServerDisconnector, ch, tls,
this::addNoConnectIdleHandlerAfterTlsHandshake));
ch.pipeline().addAfter(SSL_HANDLER, SSL_EXCEPTION_HANDLER, new SslExceptionHandler(mqttServerDisconnector));
ch.pipeline()
.addAfter(SSL_EXCEPTION_HANDLER, SSL_PARAMETER_HANDLER, channelDependencies.getSslParameterHandler());
Expand Down
13 changes: 12 additions & 1 deletion src/main/java/com/hivemq/security/ssl/SslFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,20 @@ public SslFactory(final @NotNull SslContextStore sslContextStore) {
public @NotNull SslHandler getSslHandler(
final @NotNull Channel ch, final @NotNull Tls tls, final @NotNull SslContext sslContext)
throws SslException {

final SSLEngine sslEngine = sslContext.newEngine(ch.alloc());
return getSslHandler(sslEngine, tls);
}

public @NotNull SslHandler getSslHandler(
final @NotNull Channel ch, final @NotNull Tls tls, final @NotNull SslContext sslContext, final String hostname, final int port)
throws SslException {
final SSLEngine sslEngine = sslContext.newEngine(ch.alloc(), hostname, port);
return getSslHandler(sslEngine, tls);
}

public @NotNull SslHandler getSslHandler(
final @NotNull SSLEngine sslEngine, final @NotNull Tls tls)
throws SslException {
// if prefer server suites is null -> use default of the engine
final Boolean preferServerCipherSuites = tls.isPreferServerCipherSuites();
if (preferServerCipherSuites != null) {
Expand Down
61 changes: 47 additions & 14 deletions src/main/java/com/hivemq/security/ssl/SslSniHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,46 +16,70 @@
package com.hivemq.security.ssl;

import com.hivemq.bootstrap.ClientConnectionContext;
import com.hivemq.configuration.service.entity.Tls;
import com.hivemq.extension.sdk.api.annotations.NotNull;
import com.hivemq.extension.sdk.api.annotations.Nullable;
import com.hivemq.mqtt.handler.connect.NoTlsHandshakeIdleHandler;
import com.hivemq.mqtt.handler.disconnect.MqttServerDisconnector;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.timeout.IdleStateHandler;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.NoSuchElementException;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import static com.hivemq.bootstrap.netty.ChannelHandlerNames.NEW_CONNECTION_IDLE_HANDLER;
import static com.hivemq.bootstrap.netty.ChannelHandlerNames.NO_TLS_HANDSHAKE_IDLE_EVENT_HANDLER;
import static com.hivemq.bootstrap.netty.ChannelHandlerNames.SSL_HANDLER;

public class SslSniHandler extends SniHandler {

private static final Logger log = LoggerFactory.getLogger(SslSniHandler.class);

private final @NotNull SslHandler sslHandler;

private final @NotNull Tls tls;
private final @NotNull Channel ch;
private final MqttServerDisconnector mqttServerDisconnector;
private final Consumer<Channel> idleHandlerFunction;
private final @NotNull SslFactory sslFactory;
private final @NotNull HashMap<String, SslHandler> aliasSslHandlerMap = new HashMap<>();

public SslSniHandler(final @NotNull SslHandler sslHandler, final @NotNull SslContext sslContext) {
private final IdleStateHandler idleStateHandler;
private final NoTlsHandshakeIdleHandler noTlsHandshakeIdleHandler;

public SslSniHandler(
final @NotNull SslContext sslContext, final @NotNull SslFactory sslFactory, final MqttServerDisconnector mqttServerDisconnector,
final @NotNull Channel ch, final Tls tls,
final Consumer<Channel> idleHandlerFunction) {
super((input, promise) -> {
//This could be used to return a different SslContext depending on the provided hostname
//For now the same SslContext is returned independent of the provided hostname

log.trace("SSLContext with input {}, cipherSuites {} and attributes {}", input, sslContext.cipherSuites(), sslContext.attributes());
log.info("SSLContext with input {}, cipherSuites {} and attributes {}", input, sslContext.cipherSuites(), sslContext.attributes());
promise.setSuccess(sslContext);
return promise;
});
this.sslHandler = sslHandler;
}

@Override
protected Future<SslContext> lookup(ChannelHandlerContext ctx, String hostname) throws Exception {
log.trace("lookup ChannelHandlerContext ctx: {} hostname: {}", ctx, hostname);
return mapping.map(hostname, ctx.executor().<SslContext>newPromise());
this.tls = tls;
this.ch = ch;
this.mqttServerDisconnector = mqttServerDisconnector;
this.idleHandlerFunction = idleHandlerFunction;
this.sslFactory = sslFactory;

final int handshakeTimeout = tls.getHandshakeTimeout();
idleStateHandler = new IdleStateHandler(handshakeTimeout, 0, 0, TimeUnit.MILLISECONDS);
noTlsHandshakeIdleHandler = new NoTlsHandshakeIdleHandler(mqttServerDisconnector);

if (handshakeTimeout > 0) {
ch.pipeline().addLast(NEW_CONNECTION_IDLE_HANDLER, idleStateHandler);
ch.pipeline().addLast(NO_TLS_HANDSHAKE_IDLE_EVENT_HANDLER, noTlsHandshakeIdleHandler);
}
}

@Override
Expand All @@ -77,11 +101,20 @@ protected void replaceHandler(
SslHandler sslHandlerInstance = null;
try {
final int port = ClientConnectionContext.of(ctx.channel()).getConnectedListener().getPort();
log.trace("Replace ssl handler for hostname: {} and port: {}", hostname, port);
log.info("Replace ssl handler for hostname: {} and port: {}", hostname, port);
if (!aliasSslHandlerMap.containsKey(hostname)) {
aliasSslHandlerMap.put(hostname, sslContext.newHandler(ctx.alloc(), hostname, port));
aliasSslHandlerMap.put(hostname, sslFactory.getSslHandler(ch, tls, sslContext, hostname, port));
}
sslHandlerInstance = aliasSslHandlerMap.get(hostname);

sslHandlerInstance.handshakeFuture().addListener(future -> {
if (tls.getHandshakeTimeout() > 0) {
ch.pipeline().remove(idleStateHandler);
ch.pipeline().remove(noTlsHandshakeIdleHandler);
}
idleHandlerFunction.accept(ch);
});

ctx.pipeline().replace(this, SSL_HANDLER, sslHandlerInstance);
sslHandlerInstance = null;
} catch (final NoSuchElementException ignored) {
Expand Down
27 changes: 18 additions & 9 deletions src/test/java/com/hivemq/security/ssl/SslSniHandlerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,20 @@
import com.hivemq.bootstrap.ClientConnection;
import com.hivemq.bootstrap.ClientConnectionContext;
import com.hivemq.configuration.service.entity.TcpListener;
import io.netty.buffer.ByteBufAllocator;
import com.hivemq.configuration.service.entity.Tls;
import com.hivemq.mqtt.handler.disconnect.MqttServerDisconnector;
import io.netty.channel.Channel;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.junit.Test;
import org.mockito.Mockito;
import util.DummyClientConnection;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertSame;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand All @@ -47,14 +45,25 @@ public void test_replaceHandler() throws Exception {

final SslHandler sslHandler = mock(SslHandler.class);
final SslContext sslContext = mock(SslContext.class);
final SslFactory sslFactory = mock(SslFactory.class);
final Tls tls = mock(Tls.class);

final SslSniHandler sslSniHandler = new SslSniHandler(sslHandler, sslContext);
final Channel ch = mock(Channel.class);
final MqttServerDisconnector mqttServerDisconnector = mock(MqttServerDisconnector.class);

final SslSniHandler sslSniHandler = new SslSniHandler(sslContext,
sslFactory,
mqttServerDisconnector, ch, tls, (final Channel channel) -> {});
final Channel channel = new EmbeddedChannel(sslSniHandler);
final DummyClientConnection dummyClientConnection = new DummyClientConnection(channel, null, new TcpListener(8883, "localhost", "ssl"));


final Future<Channel> value = new DefaultPromise<>(GlobalEventExecutor.INSTANCE);
when(sslHandler.handshakeFuture()).thenReturn(value);
when(sslFactory.getSslHandler(ch, tls, sslContext, "abc.com", 8883)).thenReturn(sslHandler);

channel.attr(ClientConnectionContext.CHANNEL_ATTRIBUTE_NAME).set(dummyClientConnection);

when(sslContext.newHandler(channel.pipeline().firstContext().alloc(), "abc.com", 8883)).thenReturn(sslHandler);
sslSniHandler.replaceHandler(channel.pipeline().firstContext(), "abc.com", sslContext);

assertEquals("abc.com", ClientConnection.of(channel).getAuthSniHostname());
Expand Down

0 comments on commit 361b81c

Please sign in to comment.