Skip to content

Commit

Permalink
Decouple nio constructs from the tcp transport (#27484)
Browse files Browse the repository at this point in the history
This is related to #27260. Currently, basic nio constructs (nio
channels, the channel factories, selector event handlers, etc) implement
logic that is specific to the tcp transport. For example, NioChannel
implements the TcpChannel interface. These nio constructs at some point
will also need to support other protocols (ex: http).

This commit separates the TcpTransport logic from the nio building
blocks.
  • Loading branch information
Tim-Brooks authored Nov 22, 2017
1 parent 9fbbc46 commit ef34555
Show file tree
Hide file tree
Showing 25 changed files with 374 additions and 190 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.SelectionKeyUtils;

import java.io.IOException;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
Expand All @@ -37,15 +35,10 @@
public class AcceptorEventHandler extends EventHandler {

private final Supplier<SocketSelector> selectorSupplier;
private final Consumer<NioChannel> acceptedChannelCallback;
private final OpenChannels openChannels;

public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier<SocketSelector> selectorSupplier,
Consumer<NioChannel> acceptedChannelCallback) {
super(logger, openChannels);
this.openChannels = openChannels;
public AcceptorEventHandler(Logger logger, Supplier<SocketSelector> selectorSupplier) {
super(logger);
this.selectorSupplier = selectorSupplier;
this.acceptedChannelCallback = acceptedChannelCallback;
}

/**
Expand All @@ -56,7 +49,6 @@ public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier<S
*/
void serverChannelRegistered(NioServerSocketChannel nioServerSocketChannel) {
SelectionKeyUtils.setAcceptInterested(nioServerSocketChannel);
openChannels.serverChannelOpened(nioServerSocketChannel);
}

/**
Expand All @@ -79,8 +71,7 @@ void acceptChannel(NioServerSocketChannel nioServerChannel) throws IOException {
ChannelFactory channelFactory = nioServerChannel.getChannelFactory();
SocketSelector selector = selectorSupplier.get();
NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel, selector);
openChannels.acceptedChannelOpened(nioSocketChannel);
acceptedChannelCallback.accept(nioSocketChannel);
nioServerChannel.getAcceptContext().accept(nioSocketChannel);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@
public abstract class EventHandler {

protected final Logger logger;
private final OpenChannels openChannels;

public EventHandler(Logger logger, OpenChannels openChannels) {
public EventHandler(Logger logger) {
this.logger = logger;
this.openChannels = openChannels;
}

/**
Expand Down Expand Up @@ -71,7 +69,6 @@ void uncaughtException(Exception exception) {
* @param channel that should be closed
*/
void handleClose(NioChannel channel) {
openChannels.channelClosed(channel);
try {
channel.closeFromSelector();
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transports;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpChannelFactory;
import org.elasticsearch.transport.nio.channel.TcpNioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpReadContext;
import org.elasticsearch.transport.nio.channel.TcpWriteContext;

Expand Down Expand Up @@ -65,12 +67,12 @@ public class NioTransport extends TcpTransport {
public static final Setting<Integer> NIO_ACCEPTOR_COUNT =
intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope);

protected final OpenChannels openChannels = new OpenChannels(logger);
private final ConcurrentMap<String, ChannelFactory> profileToChannelFactory = newConcurrentMap();
private final OpenChannels openChannels = new OpenChannels(logger);
private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
private final ArrayList<AcceptingSelector> acceptors = new ArrayList<>();
private final ArrayList<SocketSelector> socketSelectors = new ArrayList<>();
private RoundRobinSelectorSupplier clientSelectorSupplier;
private ChannelFactory clientChannelFactory;
private TcpChannelFactory clientChannelFactory;
private int acceptorNumber;

public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
Expand All @@ -84,17 +86,21 @@ public long getNumOpenServerConnections() {
}

@Override
protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException {
ChannelFactory channelFactory = this.profileToChannelFactory.get(name);
protected TcpNioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException {
TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name);
AcceptingSelector selector = acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings));
return channelFactory.openNioServerSocketChannel(address, selector);
TcpNioServerSocketChannel serverChannel = channelFactory.openNioServerSocketChannel(address, selector);
openChannels.serverChannelOpened(serverChannel);
serverChannel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(serverChannel)));
return serverChannel;
}

@Override
protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
NioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get());
TcpNioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get());
openChannels.clientChannelOpened(channel);
channel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(channel)));
channel.addConnectListener(connectListener);
return channel;
}
Expand All @@ -119,14 +125,14 @@ protected void doStart() {

Consumer<NioSocketChannel> clientContextSetter = getContextSetter("client-socket");
clientSelectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
clientChannelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), clientContextSetter);
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
clientChannelFactory = new TcpChannelFactory(clientProfileSettings, clientContextSetter, getServerContextSetter());

if (NetworkService.NETWORK_SERVER.get(settings)) {
int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings);
for (int i = 0; i < acceptorCount; ++i) {
Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, openChannels, selectorSupplier,
this::serverAcceptedChannel);
AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, selectorSupplier);
AcceptingSelector acceptor = new AcceptingSelector(eventHandler);
acceptors.add(acceptor);
}
Expand All @@ -143,7 +149,8 @@ protected void doStart() {
for (ProfileSettings profileSettings : profileSettings) {
String profileName = profileSettings.profileName;
Consumer<NioSocketChannel> contextSetter = getContextSetter(profileName);
profileToChannelFactory.putIfAbsent(profileName, new ChannelFactory(profileSettings, contextSetter));
TcpChannelFactory factory = new TcpChannelFactory(profileSettings, contextSetter, getServerContextSetter());
profileToChannelFactory.putIfAbsent(profileName, factory);
bindServer(profileSettings);
}
}
Expand All @@ -169,14 +176,27 @@ protected void stopInternal() {
}

protected SocketEventHandler getSocketEventHandler() {
return new SocketEventHandler(logger, this::exceptionCaught, openChannels);
return new SocketEventHandler(logger);
}

final void exceptionCaught(NioSocketChannel channel, Exception exception) {
onException(channel, exception);
onException((TcpNioSocketChannel) channel, exception);
}

private Consumer<NioSocketChannel> getContextSetter(String profileName) {
return (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c));
return (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c),
this::exceptionCaught);
}

private void acceptChannel(NioSocketChannel channel) {
TcpNioSocketChannel tcpChannel = (TcpNioSocketChannel) channel;
openChannels.acceptedChannelOpened(tcpChannel);
tcpChannel.addCloseListener(ActionListener.wrap(() -> openChannels.channelClosed(channel)));
serverAcceptedChannel(tcpChannel);

}

private Consumer<NioServerSocketChannel> getServerContextSetter() {
return (c) -> c.setAcceptContext(this::acceptChannel);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpNioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel;

import java.util.ArrayList;
import java.util.HashSet;
Expand All @@ -38,17 +40,17 @@
public class OpenChannels implements Releasable {

// TODO: Maybe set concurrency levels?
private final ConcurrentMap<NioSocketChannel, Long> openClientChannels = newConcurrentMap();
private final ConcurrentMap<NioSocketChannel, Long> openAcceptedChannels = newConcurrentMap();
private final ConcurrentMap<NioServerSocketChannel, Long> openServerChannels = newConcurrentMap();
private final ConcurrentMap<TcpNioSocketChannel, Long> openClientChannels = newConcurrentMap();
private final ConcurrentMap<TcpNioSocketChannel, Long> openAcceptedChannels = newConcurrentMap();
private final ConcurrentMap<TcpNioServerSocketChannel, Long> openServerChannels = newConcurrentMap();

private final Logger logger;

public OpenChannels(Logger logger) {
this.logger = logger;
}

public void serverChannelOpened(NioServerSocketChannel channel) {
public void serverChannelOpened(TcpNioServerSocketChannel channel) {
boolean added = openServerChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("server channel opened: {}", channel);
Expand All @@ -59,7 +61,7 @@ public long serverChannelsCount() {
return openServerChannels.size();
}

public void acceptedChannelOpened(NioSocketChannel channel) {
public void acceptedChannelOpened(TcpNioSocketChannel channel) {
boolean added = openAcceptedChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("accepted channel opened: {}", channel);
Expand All @@ -70,14 +72,14 @@ public HashSet<NioSocketChannel> getAcceptedChannels() {
return new HashSet<>(openAcceptedChannels.keySet());
}

public void clientChannelOpened(NioSocketChannel channel) {
public void clientChannelOpened(TcpNioSocketChannel channel) {
boolean added = openClientChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("client channel opened: {}", channel);
}
}

public Map<NioSocketChannel, Long> getClientChannels() {
public Map<TcpNioSocketChannel, Long> getClientChannels() {
return openClientChannels;
}

Expand Down Expand Up @@ -105,7 +107,7 @@ public void closeServerChannels() {

@Override
public void close() {
Stream<NioChannel> channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream());
Stream<TcpChannel> channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream());
TcpChannel.closeChannels(channels.collect(Collectors.toList()), true);

openClientChannels.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,16 @@
import org.elasticsearch.transport.nio.channel.WriteContext;

import java.io.IOException;
import java.util.function.BiConsumer;

/**
* Event handler designed to handle events from non-server sockets
*/
public class SocketEventHandler extends EventHandler {

private final BiConsumer<NioSocketChannel, Exception> exceptionHandler;
private final Logger logger;

public SocketEventHandler(Logger logger, BiConsumer<NioSocketChannel, Exception> exceptionHandler, OpenChannels openChannels) {
super(logger, openChannels);
this.exceptionHandler = exceptionHandler;
public SocketEventHandler(Logger logger) {
super(logger);
this.logger = logger;
}

Expand Down Expand Up @@ -150,6 +147,6 @@ void genericChannelException(NioChannel channel, Exception exception) {
}

private void exceptionCaught(NioSocketChannel channel, Exception e) {
exceptionHandler.accept(channel, e);
channel.getExceptionContext().accept(channel, e);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel;

import java.io.IOException;

Expand All @@ -34,7 +35,7 @@ public TcpReadHandler(String profile, NioTransport transport) {
this.transport = transport;
}

public void handleMessage(BytesReference reference, NioSocketChannel channel, int messageBytesLength) {
public void handleMessage(BytesReference reference, TcpNioSocketChannel channel, int messageBytesLength) {
try {
transport.messageReceived(reference, channel, profile, channel.getRemoteAddress(), messageBytesLength);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,18 @@ public S getRawChannel() {
return socketChannel;
}

@Override
public void addCloseListener(ActionListener<Void> listener) {
closeContext.whenComplete(ActionListener.toBiConsumer(listener));
}

// Package visibility for testing
void setSelectionKey(SelectionKey selectionKey) {
this.selectionKey = selectionKey;
}

// Package visibility for testing

void closeRawChannel() throws IOException {
socketChannel.close();
}

@Override
public void addCloseListener(ActionListener<Void> listener) {
closeContext.whenComplete(ActionListener.toBiConsumer(listener));
}

@Override
public void setSoLinger(int value) throws IOException {
if (isOpen()) {
socketChannel.setOption(StandardSocketOptions.SO_LINGER, value);
}
}
}
Loading

0 comments on commit ef34555

Please sign in to comment.