Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple nio constructs from the tcp transport #27484

Merged
merged 9 commits into from
Nov 22, 2017
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.SelectionKeyUtils;

import java.io.IOException;
import java.util.function.Consumer;
import java.util.function.Supplier;

/**
Expand All @@ -37,15 +35,10 @@
public class AcceptorEventHandler extends EventHandler {

private final Supplier<SocketSelector> selectorSupplier;
private final Consumer<NioChannel> acceptedChannelCallback;
private final OpenChannels openChannels;

public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier<SocketSelector> selectorSupplier,
Consumer<NioChannel> acceptedChannelCallback) {
super(logger, openChannels);
this.openChannels = openChannels;
public AcceptorEventHandler(Logger logger, Supplier<SocketSelector> selectorSupplier) {
super(logger);
this.selectorSupplier = selectorSupplier;
this.acceptedChannelCallback = acceptedChannelCallback;
}

/**
Expand All @@ -56,7 +49,6 @@ public AcceptorEventHandler(Logger logger, OpenChannels openChannels, Supplier<S
*/
void serverChannelRegistered(NioServerSocketChannel nioServerSocketChannel) {
SelectionKeyUtils.setAcceptInterested(nioServerSocketChannel);
openChannels.serverChannelOpened(nioServerSocketChannel);
}

/**
Expand All @@ -79,8 +71,7 @@ void acceptChannel(NioServerSocketChannel nioServerChannel) throws IOException {
ChannelFactory channelFactory = nioServerChannel.getChannelFactory();
SocketSelector selector = selectorSupplier.get();
NioSocketChannel nioSocketChannel = channelFactory.acceptNioChannel(nioServerChannel, selector);
openChannels.acceptedChannelOpened(nioSocketChannel);
acceptedChannelCallback.accept(nioSocketChannel);
nioServerChannel.getAcceptContext().accept(nioSocketChannel);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,9 @@
public abstract class EventHandler {

protected final Logger logger;
private final OpenChannels openChannels;

public EventHandler(Logger logger, OpenChannels openChannels) {
public EventHandler(Logger logger) {
this.logger = logger;
this.openChannels = openChannels;
}

/**
Expand Down Expand Up @@ -71,7 +69,6 @@ void uncaughtException(Exception exception) {
* @param channel that should be closed
*/
void handleClose(NioChannel channel) {
openChannels.channelClosed(channel);
try {
channel.closeFromSelector();
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TcpTransport;
import org.elasticsearch.transport.Transports;
import org.elasticsearch.transport.nio.channel.ChannelFactory;
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpChannelFactory;
import org.elasticsearch.transport.nio.channel.TcpNioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpReadContext;
import org.elasticsearch.transport.nio.channel.TcpWriteContext;

Expand Down Expand Up @@ -65,12 +67,12 @@ public class NioTransport extends TcpTransport {
public static final Setting<Integer> NIO_ACCEPTOR_COUNT =
intSetting("transport.nio.acceptor_count", 1, 1, Setting.Property.NodeScope);

protected final OpenChannels openChannels = new OpenChannels(logger);
private final ConcurrentMap<String, ChannelFactory> profileToChannelFactory = newConcurrentMap();
private final OpenChannels openChannels = new OpenChannels(logger);
private final ConcurrentMap<String, TcpChannelFactory> profileToChannelFactory = newConcurrentMap();
private final ArrayList<AcceptingSelector> acceptors = new ArrayList<>();
private final ArrayList<SocketSelector> socketSelectors = new ArrayList<>();
private RoundRobinSelectorSupplier clientSelectorSupplier;
private ChannelFactory clientChannelFactory;
private TcpChannelFactory clientChannelFactory;
private int acceptorNumber;

public NioTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays,
Expand All @@ -84,17 +86,21 @@ public long getNumOpenServerConnections() {
}

@Override
protected NioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException {
ChannelFactory channelFactory = this.profileToChannelFactory.get(name);
protected TcpNioServerSocketChannel bind(String name, InetSocketAddress address) throws IOException {
TcpChannelFactory channelFactory = this.profileToChannelFactory.get(name);
AcceptingSelector selector = acceptors.get(++acceptorNumber % NioTransport.NIO_ACCEPTOR_COUNT.get(settings));
return channelFactory.openNioServerSocketChannel(address, selector);
TcpNioServerSocketChannel serverChannel = channelFactory.openNioServerSocketChannel(address, selector);
openChannels.serverChannelOpened(serverChannel);
serverChannel.addCloseListener(new RemoveClosedChannelListener(serverChannel));
return serverChannel;
}

@Override
protected NioChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
protected TcpNioSocketChannel initiateChannel(DiscoveryNode node, TimeValue connectTimeout, ActionListener<Void> connectListener)
throws IOException {
NioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get());
TcpNioSocketChannel channel = clientChannelFactory.openNioChannel(node.getAddress().address(), clientSelectorSupplier.get());
openChannels.clientChannelOpened(channel);
channel.addCloseListener(new RemoveClosedChannelListener(channel));
channel.addConnectListener(connectListener);
return channel;
}
Expand All @@ -119,14 +125,14 @@ protected void doStart() {

Consumer<NioSocketChannel> clientContextSetter = getContextSetter("client-socket");
clientSelectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
clientChannelFactory = new ChannelFactory(new ProfileSettings(settings, "default"), clientContextSetter);
ProfileSettings clientProfileSettings = new ProfileSettings(settings, "default");
clientChannelFactory = new TcpChannelFactory(clientProfileSettings, clientContextSetter, getServerContextSetter());

if (NetworkService.NETWORK_SERVER.get(settings)) {
int acceptorCount = NioTransport.NIO_ACCEPTOR_COUNT.get(settings);
for (int i = 0; i < acceptorCount; ++i) {
Supplier<SocketSelector> selectorSupplier = new RoundRobinSelectorSupplier(socketSelectors);
AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, openChannels, selectorSupplier,
this::serverAcceptedChannel);
AcceptorEventHandler eventHandler = new AcceptorEventHandler(logger, selectorSupplier);
AcceptingSelector acceptor = new AcceptingSelector(eventHandler);
acceptors.add(acceptor);
}
Expand All @@ -143,7 +149,8 @@ protected void doStart() {
for (ProfileSettings profileSettings : profileSettings) {
String profileName = profileSettings.profileName;
Consumer<NioSocketChannel> contextSetter = getContextSetter(profileName);
profileToChannelFactory.putIfAbsent(profileName, new ChannelFactory(profileSettings, contextSetter));
TcpChannelFactory factory = new TcpChannelFactory(profileSettings, contextSetter, getServerContextSetter());
profileToChannelFactory.putIfAbsent(profileName, factory);
bindServer(profileSettings);
}
}
Expand All @@ -169,14 +176,48 @@ protected void stopInternal() {
}

protected SocketEventHandler getSocketEventHandler() {
return new SocketEventHandler(logger, this::exceptionCaught, openChannels);
return new SocketEventHandler(logger);
}

final void exceptionCaught(NioSocketChannel channel, Exception exception) {
onException(channel, exception);
onException((TcpNioSocketChannel) channel, exception);
}

private Consumer<NioSocketChannel> getContextSetter(String profileName) {
return (c) -> c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c));
return (c) -> {
c.setContexts(new TcpReadContext(c, new TcpReadHandler(profileName,this)), new TcpWriteContext(c));
c.setExceptionHandler(this::exceptionCaught);
};
}

private void acceptChannel(NioSocketChannel channel) {
TcpNioSocketChannel tcpChannel = (TcpNioSocketChannel) channel;
openChannels.acceptedChannelOpened(tcpChannel);
tcpChannel.addCloseListener(new RemoveClosedChannelListener(channel));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rather use ActionListener.wrap(v -> openChannels.channelClosed(channel), e -> openChannels.channelClosed(channel)) then we don't need to define this extra class

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. And better than that, I can used the Runnable version ActionListener.wrap(() -> openChannels.channelClosed(channel))

serverAcceptedChannel(tcpChannel);

}

private Consumer<NioServerSocketChannel> getServerContextSetter() {
return (c) -> c.setAcceptContext(this::acceptChannel);
}

private class RemoveClosedChannelListener implements ActionListener<Void> {

private final NioChannel channel;

private RemoveClosedChannelListener(NioChannel channel) {
this.channel = channel;
}

@Override
public void onResponse(Void aVoid) {
openChannels.channelClosed(channel);
}

@Override
public void onFailure(Exception e) {
openChannels.channelClosed(channel);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import org.elasticsearch.transport.nio.channel.NioChannel;
import org.elasticsearch.transport.nio.channel.NioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpNioServerSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel;

import java.util.ArrayList;
import java.util.HashSet;
Expand All @@ -38,17 +40,17 @@
public class OpenChannels implements Releasable {

// TODO: Maybe set concurrency levels?
private final ConcurrentMap<NioSocketChannel, Long> openClientChannels = newConcurrentMap();
private final ConcurrentMap<NioSocketChannel, Long> openAcceptedChannels = newConcurrentMap();
private final ConcurrentMap<NioServerSocketChannel, Long> openServerChannels = newConcurrentMap();
private final ConcurrentMap<TcpNioSocketChannel, Long> openClientChannels = newConcurrentMap();
private final ConcurrentMap<TcpNioSocketChannel, Long> openAcceptedChannels = newConcurrentMap();
private final ConcurrentMap<TcpNioServerSocketChannel, Long> openServerChannels = newConcurrentMap();

private final Logger logger;

public OpenChannels(Logger logger) {
this.logger = logger;
}

public void serverChannelOpened(NioServerSocketChannel channel) {
public void serverChannelOpened(TcpNioServerSocketChannel channel) {
boolean added = openServerChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("server channel opened: {}", channel);
Expand All @@ -59,7 +61,7 @@ public long serverChannelsCount() {
return openServerChannels.size();
}

public void acceptedChannelOpened(NioSocketChannel channel) {
public void acceptedChannelOpened(TcpNioSocketChannel channel) {
boolean added = openAcceptedChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("accepted channel opened: {}", channel);
Expand All @@ -70,14 +72,14 @@ public HashSet<NioSocketChannel> getAcceptedChannels() {
return new HashSet<>(openAcceptedChannels.keySet());
}

public void clientChannelOpened(NioSocketChannel channel) {
public void clientChannelOpened(TcpNioSocketChannel channel) {
boolean added = openClientChannels.putIfAbsent(channel, System.nanoTime()) == null;
if (added && logger.isTraceEnabled()) {
logger.trace("client channel opened: {}", channel);
}
}

public Map<NioSocketChannel, Long> getClientChannels() {
public Map<TcpNioSocketChannel, Long> getClientChannels() {
return openClientChannels;
}

Expand Down Expand Up @@ -105,7 +107,7 @@ public void closeServerChannels() {

@Override
public void close() {
Stream<NioChannel> channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream());
Stream<TcpChannel> channels = Stream.concat(openClientChannels.keySet().stream(), openAcceptedChannels.keySet().stream());
TcpChannel.closeChannels(channels.collect(Collectors.toList()), true);

openClientChannels.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,16 @@
import org.elasticsearch.transport.nio.channel.WriteContext;

import java.io.IOException;
import java.util.function.BiConsumer;

/**
* Event handler designed to handle events from non-server sockets
*/
public class SocketEventHandler extends EventHandler {

private final BiConsumer<NioSocketChannel, Exception> exceptionHandler;
private final Logger logger;

public SocketEventHandler(Logger logger, BiConsumer<NioSocketChannel, Exception> exceptionHandler, OpenChannels openChannels) {
super(logger, openChannels);
this.exceptionHandler = exceptionHandler;
public SocketEventHandler(Logger logger) {
super(logger);
this.logger = logger;
}

Expand Down Expand Up @@ -150,6 +147,6 @@ void genericChannelException(NioChannel channel, Exception exception) {
}

private void exceptionCaught(NioSocketChannel channel, Exception e) {
exceptionHandler.accept(channel, e);
channel.getExceptionHandler().accept(channel, e);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.transport.nio.channel.NioSocketChannel;
import org.elasticsearch.transport.nio.channel.TcpNioSocketChannel;

import java.io.IOException;

Expand All @@ -34,7 +35,7 @@ public TcpReadHandler(String profile, NioTransport transport) {
this.transport = transport;
}

public void handleMessage(BytesReference reference, NioSocketChannel channel, int messageBytesLength) {
public void handleMessage(BytesReference reference, TcpNioSocketChannel channel, int messageBytesLength) {
try {
transport.messageReceived(reference, channel, profile, channel.getRemoteAddress(), messageBytesLength);
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,25 +137,18 @@ public S getRawChannel() {
return socketChannel;
}

@Override
public void addCloseListener(ActionListener<Void> listener) {
closeContext.whenComplete(ActionListener.toBiConsumer(listener));
}

// Package visibility for testing
void setSelectionKey(SelectionKey selectionKey) {
this.selectionKey = selectionKey;
}

// Package visibility for testing

void closeRawChannel() throws IOException {
socketChannel.close();
}

@Override
public void addCloseListener(ActionListener<Void> listener) {
closeContext.whenComplete(ActionListener.toBiConsumer(listener));
}

@Override
public void setSoLinger(int value) throws IOException {
if (isOpen()) {
socketChannel.setOption(StandardSocketOptions.SO_LINGER, value);
}
}
}
Loading