diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java index 8b174eac468e..211e609ba4c0 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/BytesChannelContext.java @@ -21,19 +21,12 @@ import java.io.IOException; import java.util.function.Consumer; -import java.util.function.Predicate; public class BytesChannelContext extends SocketChannelContext { public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, - ReadWriteHandler handler, InboundChannelBuffer channelBuffer) { - this(channel, selector, exceptionHandler, handler, channelBuffer, ALWAYS_ALLOW_CHANNEL); - } - - public BytesChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, - ReadWriteHandler handler, InboundChannelBuffer channelBuffer, - Predicate allowChannelPredicate) { - super(channel, selector, exceptionHandler, handler, channelBuffer, allowChannelPredicate); + NioChannelHandler handler, InboundChannelBuffer channelBuffer) { + super(channel, selector, exceptionHandler, handler, channelBuffer); } @Override diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java index 07333aa2eebc..48d83d216924 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/BytesWriteHandler.java @@ -24,7 +24,7 @@ import java.util.List; import java.util.function.BiConsumer; -public abstract class BytesWriteHandler implements ReadWriteHandler { +public abstract class BytesWriteHandler implements NioChannelHandler { private static final List EMPTY_LIST = Collections.emptyList(); @@ -48,6 +48,11 @@ public List pollFlushOperations() { return EMPTY_LIST; } + @Override + public boolean closeNow() { + return false; + } + @Override public void close() {} } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/DelegatingHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/DelegatingHandler.java new file mode 100644 index 000000000000..d928b0bf9349 --- /dev/null +++ b/libs/nio/src/main/java/org/elasticsearch/nio/DelegatingHandler.java @@ -0,0 +1,68 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.nio; + +import java.io.IOException; +import java.util.List; +import java.util.function.BiConsumer; + +public abstract class DelegatingHandler implements NioChannelHandler { + + private NioChannelHandler delegate; + + public DelegatingHandler(NioChannelHandler delegate) { + this.delegate = delegate; + } + + @Override + public void channelRegistered() { + this.delegate.channelRegistered(); + } + + @Override + public WriteOperation createWriteOperation(SocketChannelContext context, Object message, BiConsumer listener) { + return delegate.createWriteOperation(context, message, listener); + } + + @Override + public List writeToBytes(WriteOperation writeOperation) { + return delegate.writeToBytes(writeOperation); + } + + @Override + public List pollFlushOperations() { + return delegate.pollFlushOperations(); + } + + @Override + public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException { + return delegate.consumeReads(channelBuffer); + } + + @Override + public boolean closeNow() { + return delegate.closeNow(); + } + + @Override + public void close() throws IOException { + delegate.close(); + } +} diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/ReadWriteHandler.java b/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java similarity index 90% rename from libs/nio/src/main/java/org/elasticsearch/nio/ReadWriteHandler.java rename to libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java index 92b276ad2d6d..61bda9a45076 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/ReadWriteHandler.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/NioChannelHandler.java @@ -24,9 +24,9 @@ import java.util.function.BiConsumer; /** - * Implements the application specific logic for handling inbound and outbound messages for a channel. + * Implements the application specific logic for handling channel operations. */ -public interface ReadWriteHandler { +public interface NioChannelHandler { /** * This method is called when the channel is registered with its selector. @@ -72,5 +72,12 @@ public interface ReadWriteHandler { */ int consumeReads(InboundChannelBuffer channelBuffer) throws IOException; + /** + * This method indicates if the underlying channel should be closed. + * + * @return if the channel should be closed + */ + boolean closeNow(); + void close() throws IOException; } diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java index 22d85472126c..21de98e096c0 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/SocketChannelContext.java @@ -32,7 +32,6 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Predicate; /** * This context should implement the specific logic for a channel. When a channel receives a notification @@ -45,13 +44,10 @@ */ public abstract class SocketChannelContext extends ChannelContext { - protected static final Predicate ALWAYS_ALLOW_CHANNEL = (c) -> true; - protected final NioSocketChannel channel; protected final InboundChannelBuffer channelBuffer; protected final AtomicBoolean isClosing = new AtomicBoolean(false); - private final ReadWriteHandler readWriteHandler; - private final Predicate allowChannelPredicate; + private final NioChannelHandler readWriteHandler; private final NioSelector selector; private final CompletableContext connectContext = new CompletableContext<>(); private final LinkedList pendingFlushes = new LinkedList<>(); @@ -59,14 +55,12 @@ public abstract class SocketChannelContext extends ChannelContext private Exception connectException; protected SocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, - ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer, - Predicate allowChannelPredicate) { + NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) { super(channel.getRawChannel(), exceptionHandler); this.selector = selector; this.channel = channel; this.readWriteHandler = readWriteHandler; this.channelBuffer = channelBuffer; - this.allowChannelPredicate = allowChannelPredicate; } @Override @@ -171,9 +165,6 @@ protected FlushOperation getPendingFlush() { protected void register() throws IOException { super.register(); readWriteHandler.channelRegistered(); - if (allowChannelPredicate.test(channel) == false) { - closeNow = true; - } } @Override @@ -233,7 +224,7 @@ public boolean readyForFlush() { public abstract boolean selectorShouldClose(); protected boolean closeNow() { - return closeNow; + return closeNow || readWriteHandler.closeNow(); } protected void setCloseNow() { diff --git a/libs/nio/src/main/java/org/elasticsearch/nio/WriteOperation.java b/libs/nio/src/main/java/org/elasticsearch/nio/WriteOperation.java index 3d17519be7e1..b5f60bd28a1d 100644 --- a/libs/nio/src/main/java/org/elasticsearch/nio/WriteOperation.java +++ b/libs/nio/src/main/java/org/elasticsearch/nio/WriteOperation.java @@ -23,7 +23,7 @@ /** * This is a basic write operation that can be queued with a channel. The only requirements of a write * operation is that is has a listener and a reference to its channel. The actual conversion of the write - * operation implementation to bytes will be performed by the {@link ReadWriteHandler}. + * operation implementation to bytes will be performed by the {@link NioChannelHandler}. */ public interface WriteOperation { diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java index 1b42e8be60d7..578890b152ff 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/EventHandlerTests.java @@ -44,7 +44,7 @@ public class EventHandlerTests extends ESTestCase { private Consumer channelExceptionHandler; private Consumer genericExceptionHandler; - private ReadWriteHandler readWriteHandler; + private NioChannelHandler readWriteHandler; private EventHandler handler; private DoNotRegisterSocketContext context; private DoNotRegisterServerContext serverContext; @@ -56,7 +56,7 @@ public class EventHandlerTests extends ESTestCase { public void setUpHandler() throws IOException { channelExceptionHandler = mock(Consumer.class); genericExceptionHandler = mock(Consumer.class); - readWriteHandler = mock(ReadWriteHandler.class); + readWriteHandler = mock(NioChannelHandler.class); channelFactory = mock(ChannelFactory.class); NioSelector selector = mock(NioSelector.class); ArrayList selectors = new ArrayList<>(); @@ -260,7 +260,7 @@ private class DoNotRegisterSocketContext extends BytesChannelContext { DoNotRegisterSocketContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, - ReadWriteHandler handler) { + NioChannelHandler handler) { super(channel, selector, exceptionHandler, handler, InboundChannelBuffer.allocatingInstance()); } diff --git a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java index 0040f70df85a..c0c203f728fd 100644 --- a/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java +++ b/libs/nio/src/test/java/org/elasticsearch/nio/SocketChannelContextTests.java @@ -35,7 +35,6 @@ import java.util.function.BiConsumer; import java.util.function.Consumer; import java.util.function.IntFunction; -import java.util.function.Predicate; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; @@ -54,7 +53,7 @@ public class SocketChannelContextTests extends ESTestCase { private NioSocketChannel channel; private BiConsumer listener; private NioSelector selector; - private ReadWriteHandler readWriteHandler; + private NioChannelHandler readWriteHandler; private ByteBuffer ioBuffer = ByteBuffer.allocate(1024); @SuppressWarnings("unchecked") @@ -68,7 +67,7 @@ public void setup() throws Exception { when(channel.getRawChannel()).thenReturn(rawChannel); exceptionHandler = mock(Consumer.class); selector = mock(NioSelector.class); - readWriteHandler = mock(ReadWriteHandler.class); + readWriteHandler = mock(NioChannelHandler.class); InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); @@ -102,22 +101,6 @@ public void testSignalWhenPeerClosed() throws IOException { assertTrue(context.closeNow()); } - public void testValidateInRegisterCanSucceed() throws IOException { - InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, (c) -> true); - assertFalse(context.closeNow()); - context.register(); - assertFalse(context.closeNow()); - } - - public void testValidateInRegisterCanFail() throws IOException { - InboundChannelBuffer channelBuffer = InboundChannelBuffer.allocatingInstance(); - context = new TestSocketChannelContext(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, (c) -> false); - assertFalse(context.closeNow()); - context.register(); - assertTrue(context.closeNow()); - } - public void testConnectSucceeds() throws IOException { AtomicBoolean listenerCalled = new AtomicBoolean(false); when(rawChannel.finishConnect()).thenReturn(false, true); @@ -394,14 +377,8 @@ public void testFlushBuffersHandlesIOExceptionSecondTimeThroughLoop() throws IOE private static class TestSocketChannelContext extends SocketChannelContext { private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, - ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer) { - this(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, ALWAYS_ALLOW_CHANNEL); - } - - private TestSocketChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, - ReadWriteHandler readWriteHandler, InboundChannelBuffer channelBuffer, - Predicate allowChannelPredicate) { - super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate); + NioChannelHandler readWriteHandler, InboundChannelBuffer channelBuffer) { + super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); } @Override diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java index 7a4fbfe42aef..c603e20ffc93 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java @@ -38,7 +38,7 @@ import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.ReadWriteHandler; +import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.TaskScheduler; import org.elasticsearch.nio.WriteOperation; @@ -50,7 +50,7 @@ import java.util.function.BiConsumer; import java.util.function.LongSupplier; -public class HttpReadWriteHandler implements ReadWriteHandler { +public class HttpReadWriteHandler implements NioChannelHandler { private final NettyAdaptor adaptor; private final NioHttpChannel nioHttpChannel; @@ -140,6 +140,11 @@ public List pollFlushOperations() { return copiedOperations; } + @Override + public boolean closeNow() { + return false; + } + @Override public void close() throws IOException { try { diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java index 634ea7b44af7..ed55007f3ba6 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpClient.java @@ -49,7 +49,7 @@ import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.ReadWriteHandler; +import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.WriteOperation; import org.elasticsearch.tasks.Task; @@ -207,7 +207,7 @@ public NioServerSocketChannel createServerChannel(NioSelector selector, ServerSo } } - private static class HttpClientHandler implements ReadWriteHandler { + private static class HttpClientHandler implements NioChannelHandler { private final NettyAdaptor adaptor; private final CountDownLatch latch; @@ -277,6 +277,11 @@ public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException { return bytesConsumed; } + @Override + public boolean closeNow() { + return false; + } + @Override public void close() throws IOException { try { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java index afb13ceff2ed..12f6b67d6724 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilter.java @@ -5,28 +5,49 @@ */ package org.elasticsearch.xpack.security.transport.nio; -import org.elasticsearch.common.Nullable; -import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.DelegatingHandler; +import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.xpack.security.transport.filter.IPFilter; -import java.util.function.Predicate; +import java.io.IOException; +import java.net.InetSocketAddress; -public final class NioIPFilter implements Predicate { +public final class NioIPFilter extends DelegatingHandler { + private final InetSocketAddress remoteAddress; private final IPFilter filter; private final String profile; + private boolean denied = false; - NioIPFilter(@Nullable IPFilter filter, String profile) { + NioIPFilter(NioChannelHandler delegate, InetSocketAddress remoteAddress, IPFilter filter, String profile) { + super(delegate); + this.remoteAddress = remoteAddress; this.filter = filter; this.profile = profile; } @Override - public boolean test(NioSocketChannel nioChannel) { - if (filter != null) { - return filter.accept(profile, nioChannel.getRemoteAddress()); + public void channelRegistered() { + if (filter.accept(profile, remoteAddress)) { + super.channelRegistered(); } else { - return true; + denied = true; } } + + @Override + public int consumeReads(InboundChannelBuffer channelBuffer) throws IOException { + if (denied) { + // Do not consume any reads if channel is disallowed + return 0; + } else { + return super.consumeReads(channelBuffer); + } + } + + @Override + public boolean closeNow() { + return denied; + } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java index 5f7037d57e3f..0ea50e844b61 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SSLChannelContext.java @@ -9,9 +9,9 @@ import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.WriteOperation; @@ -23,12 +23,11 @@ import java.util.concurrent.TimeUnit; import java.util.function.BiConsumer; import java.util.function.Consumer; -import java.util.function.Predicate; /** * Provides a TLS/SSL read/write layer over a channel. This context will use a {@link SSLDriver} to handshake * with the peer channel. Once the handshake is complete, any data from the peer channel will be decrypted - * before being passed to the {@link ReadWriteHandler}. Outbound data will be encrypted before being flushed + * before being passed to the {@link NioChannelHandler}. Outbound data will be encrypted before being flushed * to the channel. */ public final class SSLChannelContext extends SocketChannelContext { @@ -43,15 +42,14 @@ public final class SSLChannelContext extends SocketChannelContext { private Runnable closeTimeoutCanceller = DEFAULT_TIMEOUT_CANCELLER; SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, - ReadWriteHandler readWriteHandler, InboundChannelBuffer applicationBuffer) { + NioChannelHandler readWriteHandler, InboundChannelBuffer applicationBuffer) { this(channel, selector, exceptionHandler, sslDriver, readWriteHandler, InboundChannelBuffer.allocatingInstance(), - applicationBuffer, ALWAYS_ALLOW_CHANNEL); + applicationBuffer); } SSLChannelContext(NioSocketChannel channel, NioSelector selector, Consumer exceptionHandler, SSLDriver sslDriver, - ReadWriteHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer, - Predicate allowChannelPredicate) { - super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer, allowChannelPredicate); + NioChannelHandler readWriteHandler, InboundChannelBuffer networkReadBuffer, InboundChannelBuffer channelBuffer) { + super(channel, selector, exceptionHandler, readWriteHandler, channelBuffer); this.sslDriver = sslDriver; this.networkReadBuffer = networkReadBuffer; } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java index ddf465f81d90..bf476e5b7346 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioHttpServerTransport.java @@ -19,6 +19,7 @@ import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; @@ -44,7 +45,6 @@ public class SecurityNioHttpServerTransport extends NioHttpServerTransport { private final SecurityHttpExceptionHandler securityExceptionHandler; private final IPFilter ipFilter; - private final NioIPFilter nioIpFilter; private final SSLService sslService; private final SSLConfiguration sslConfiguration; private final boolean sslEnabled; @@ -56,7 +56,6 @@ public SecurityNioHttpServerTransport(Settings settings, NetworkService networkS super(settings, networkService, bigArrays, pageCacheRecycler, threadPool, xContentRegistry, dispatcher, nioGroupFactory); this.securityExceptionHandler = new SecurityHttpExceptionHandler(logger, lifecycle, (c, e) -> super.onException(c, e)); this.ipFilter = ipFilter; - this.nioIpFilter = new NioIPFilter(ipFilter, IPFilter.HTTP_PROFILE_NAME); this.sslEnabled = HTTP_SSL_ENABLED.get(settings); this.sslService = sslService; if (sslEnabled) { @@ -91,6 +90,13 @@ public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) NioHttpChannel httpChannel = new NioHttpChannel(channel); HttpReadWriteHandler httpHandler = new HttpReadWriteHandler(httpChannel,SecurityNioHttpServerTransport.this, handlingSettings, corsConfig, selector.getTaskScheduler(), threadPool::relativeTimeInNanos); + final NioChannelHandler handler; + if (ipFilter != null) { + handler = new NioIPFilter(httpHandler, httpChannel.getRemoteAddress(), ipFilter, IPFilter.HTTP_PROFILE_NAME); + } else { + handler = httpHandler; + } + InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator); Consumer exceptionHandler = (e) -> securityExceptionHandler.accept(httpChannel, e); @@ -107,10 +113,10 @@ public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) } SSLDriver sslDriver = new SSLDriver(sslEngine, pageAllocator, false); InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); - context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, httpHandler, networkBuffer, - applicationBuffer, nioIpFilter); + context = new SSLChannelContext(httpChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer, + applicationBuffer); } else { - context = new BytesChannelContext(httpChannel, selector, exceptionHandler, httpHandler, networkBuffer, nioIpFilter); + context = new BytesChannelContext(httpChannel, selector, exceptionHandler, handler, networkBuffer); } httpChannel.setContext(context); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java index cf32809333e7..8d22d15612e7 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/nio/SecurityNioTransport.java @@ -18,6 +18,7 @@ import org.elasticsearch.nio.BytesChannelContext; import org.elasticsearch.nio.ChannelFactory; import org.elasticsearch.nio.InboundChannelBuffer; +import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; @@ -65,19 +66,19 @@ public class SecurityNioTransport extends NioTransport { private static final Logger logger = LogManager.getLogger(SecurityNioTransport.class); private final SecurityTransportExceptionHandler exceptionHandler; - private final IPFilter authenticator; + private final IPFilter ipFilter; private final SSLService sslService; private final Map profileConfiguration; private final boolean sslEnabled; public SecurityNioTransport(Settings settings, Version version, ThreadPool threadPool, NetworkService networkService, PageCacheRecycler pageCacheRecycler, NamedWriteableRegistry namedWriteableRegistry, - CircuitBreakerService circuitBreakerService, @Nullable final IPFilter authenticator, + CircuitBreakerService circuitBreakerService, @Nullable final IPFilter ipFilter, SSLService sslService, NioGroupFactory groupFactory) { super(settings, version, threadPool, networkService, pageCacheRecycler, namedWriteableRegistry, circuitBreakerService, groupFactory); this.exceptionHandler = new SecurityTransportExceptionHandler(logger, lifecycle, (c, e) -> super.onException(c, e)); - this.authenticator = authenticator; + this.ipFilter = ipFilter; this.sslService = sslService; this.sslEnabled = XPackSettings.TRANSPORT_SSL_ENABLED.get(settings); if (sslEnabled) { @@ -92,8 +93,8 @@ public SecurityNioTransport(Settings settings, Version version, ThreadPool threa @Override protected void doStart() { super.doStart(); - if (authenticator != null) { - authenticator.setBoundTransportAddress(boundAddress(), profileBoundAddresses()); + if (ipFilter != null) { + ipFilter.setBoundTransportAddress(boundAddress(), profileBoundAddresses()); } } @@ -132,7 +133,6 @@ private class SecurityTcpChannelFactory extends TcpChannelFactory { private final String profileName; private final boolean isClient; - private final NioIPFilter ipFilter; private SecurityTcpChannelFactory(ProfileSettings profileSettings, boolean isClient) { this(new RawChannelFactory(profileSettings.tcpNoDelay, @@ -146,13 +146,18 @@ private SecurityTcpChannelFactory(RawChannelFactory rawChannelFactory, String pr super(rawChannelFactory); this.profileName = profileName; this.isClient = isClient; - this.ipFilter = new NioIPFilter(authenticator, profileName); } @Override public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { NioTcpChannel nioChannel = new NioTcpChannel(isClient == false, profileName, channel); TcpReadWriteHandler readWriteHandler = new TcpReadWriteHandler(nioChannel, SecurityNioTransport.this); + final NioChannelHandler handler; + if (ipFilter != null) { + handler = new NioIPFilter(readWriteHandler, nioChannel.getRemoteAddress(), ipFilter, profileName); + } else { + handler = readWriteHandler; + } InboundChannelBuffer networkBuffer = new InboundChannelBuffer(pageAllocator); Consumer exceptionHandler = (e) -> onException(nioChannel, e); @@ -160,10 +165,10 @@ public NioTcpChannel createChannel(NioSelector selector, SocketChannel channel) if (sslEnabled) { SSLDriver sslDriver = new SSLDriver(createSSLEngine(channel), pageAllocator, isClient); InboundChannelBuffer applicationBuffer = new InboundChannelBuffer(pageAllocator); - context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, readWriteHandler, networkBuffer, - applicationBuffer, ipFilter); + context = new SSLChannelContext(nioChannel, selector, exceptionHandler, sslDriver, handler, networkBuffer, + applicationBuffer); } else { - context = new BytesChannelContext(nioChannel, selector, exceptionHandler, readWriteHandler, networkBuffer, ipFilter); + context = new BytesChannelContext(nioChannel, selector, exceptionHandler, handler, networkBuffer); } nioChannel.setContext(context); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java index 3df00018af42..e7612c0c0d7f 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/transport/nio/NioIPFilterTests.java @@ -13,7 +13,7 @@ import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.license.XPackLicenseState; -import org.elasticsearch.nio.NioSocketChannel; +import org.elasticsearch.nio.NioChannelHandler; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.Transport; import org.elasticsearch.xpack.security.audit.AuditTrailService; @@ -26,13 +26,15 @@ import java.util.Collections; import java.util.HashSet; -import static org.hamcrest.Matchers.is; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class NioIPFilterTests extends ESTestCase { - private NioIPFilter nioIPFilter; + private IPFilter ipFilter; + private String profile; @Before public void init() throws Exception { @@ -59,7 +61,7 @@ public void init() throws Exception { XPackLicenseState licenseState = mock(XPackLicenseState.class); when(licenseState.isIpFilteringAllowed()).thenReturn(true); AuditTrailService auditTrailService = new AuditTrailService(Collections.emptyList(), licenseState); - IPFilter ipFilter = new IPFilter(settings, auditTrailService, clusterSettings, licenseState); + ipFilter = new IPFilter(settings, auditTrailService, clusterSettings, licenseState); ipFilter.setBoundTransportAddress(transport.boundAddress(), transport.profileBoundAddresses()); if (isHttpEnabled) { HttpServerTransport httpTransport = mock(HttpServerTransport.class); @@ -70,21 +72,27 @@ public void init() throws Exception { } if (isHttpEnabled) { - nioIPFilter = new NioIPFilter(ipFilter, IPFilter.HTTP_PROFILE_NAME); + profile = IPFilter.HTTP_PROFILE_NAME; } else { - nioIPFilter = new NioIPFilter(ipFilter, "default"); + profile = "default"; } } - public void testThatFilteringWorksByIp() throws Exception { + public void testThatFilterCanPass() throws Exception { InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("127.0.0.1"), 12345); - NioSocketChannel channel1 = mock(NioSocketChannel.class); - when(channel1.getRemoteAddress()).thenReturn(localhostAddr); - assertThat(nioIPFilter.test(channel1), is(true)); + NioChannelHandler delegate = mock(NioChannelHandler.class); + NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile); + nioIPFilter.channelRegistered(); + verify(delegate).channelRegistered(); + assertFalse(nioIPFilter.closeNow()); + } - InetSocketAddress remoteAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345); - NioSocketChannel channel2 = mock(NioSocketChannel.class); - when(channel2.getRemoteAddress()).thenReturn(remoteAddr); - assertThat(nioIPFilter.test(channel2), is(false)); + public void testThatFilterCanFail() throws Exception { + InetSocketAddress localhostAddr = new InetSocketAddress(InetAddresses.forString("10.0.0.8"), 12345); + NioChannelHandler delegate = mock(NioChannelHandler.class); + NioIPFilter nioIPFilter = new NioIPFilter(delegate, localhostAddr, ipFilter, profile); + nioIPFilter.channelRegistered(); + verify(delegate, times(0)).channelRegistered(); + assertTrue(nioIPFilter.closeNow()); } }