Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure local addresses aren't null #31440

Merged
merged 13 commits into from
Jun 21, 2018
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
Expand Down Expand Up @@ -99,6 +100,11 @@ private Socket internalCreateChannel(NioSelector selector, SocketChannel rawChan
Socket channel = createChannel(selector, rawChannel);
assert channel.getContext() != null : "channel context should have been set on channel";
return channel;
} catch (UncheckedIOException e) {
// This can happen if getRemoteAddress throws IOException.
IOException cause = e.getCause();
closeRawChannel(rawChannel, cause);
throw cause;
} catch (Exception e) {
closeRawChannel(rawChannel, e);
throw e;
Expand Down
13 changes: 2 additions & 11 deletions libs/nio/src/main/java/org/elasticsearch/nio/NioChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

package org.elasticsearch.nio;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.NetworkChannel;
import java.util.function.BiConsumer;
Expand All @@ -32,20 +31,10 @@
*/
public abstract class NioChannel {

private final InetSocketAddress localAddress;

NioChannel(NetworkChannel socketChannel) throws IOException {
this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress();
}

public boolean isOpen() {
return getContext().isOpen();
}

public InetSocketAddress getLocalAddress() {
return localAddress;
}

/**
* Adds a close listener to the channel. Multiple close listeners can be added. There is no guarantee
* about the order in which close listeners will be executed. If the channel is already closed, the
Expand All @@ -64,6 +53,8 @@ public void close() {
getContext().closeChannel();
}

public abstract InetSocketAddress getLocalAddress();

public abstract NetworkChannel getRawChannel();

public abstract ChannelContext<?> getContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@

package org.elasticsearch.nio;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.util.concurrent.atomic.AtomicBoolean;

public class NioServerSocketChannel extends NioChannel {

private final ServerSocketChannel socketChannel;
private final ServerSocketChannel serverSocketChannel;
private final AtomicBoolean contextSet = new AtomicBoolean(false);
private volatile InetSocketAddress localAddress;
private ServerChannelContext context;

public NioServerSocketChannel(ServerSocketChannel socketChannel) throws IOException {
super(socketChannel);
this.socketChannel = socketChannel;
public NioServerSocketChannel(ServerSocketChannel serverSocketChannel) {
this.serverSocketChannel = serverSocketChannel;
attemptToSetLocalAddress();
}

/**
Expand All @@ -48,9 +49,15 @@ public void setContext(ServerChannelContext context) {
}
}

@Override
public InetSocketAddress getLocalAddress() {
attemptToSetLocalAddress();
return localAddress;
}

@Override
public ServerSocketChannel getRawChannel() {
return socketChannel;
return serverSocketChannel;
}

@Override
Expand All @@ -64,4 +71,10 @@ public String toString() {
"localAddress=" + getLocalAddress() +
'}';
}

private void attemptToSetLocalAddress() {
if (localAddress == null) {
localAddress = (InetSocketAddress) serverSocketChannel.socket().getLocalSocketAddress();
}
}
}
21 changes: 17 additions & 4 deletions libs/nio/src/main/java/org/elasticsearch/nio/NioSocketChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,27 @@
package org.elasticsearch.nio;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.nio.channels.SocketChannel;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;

public class NioSocketChannel extends NioChannel {

private final InetSocketAddress remoteAddress;
private final AtomicBoolean contextSet = new AtomicBoolean(false);
private final SocketChannel socketChannel;
private final InetSocketAddress remoteAddress;
private volatile InetSocketAddress localAddress;
private SocketChannelContext context;

public NioSocketChannel(SocketChannel socketChannel) throws IOException {
super(socketChannel);
public NioSocketChannel(SocketChannel socketChannel) {
this.socketChannel = socketChannel;
this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress();
try {
this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

public void setContext(SocketChannelContext context) {
Expand All @@ -46,6 +51,14 @@ public void setContext(SocketChannelContext context) {
}
}

@Override
public InetSocketAddress getLocalAddress() {
if (localAddress == null) {
localAddress = (InetSocketAddress) socketChannel.socket().getLocalSocketAddress();
}
return localAddress;
}

@Override
public SocketChannel getRawChannel() {
return socketChannel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.junit.Before;

import java.io.IOException;
import java.net.ServerSocket;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
Expand Down Expand Up @@ -69,7 +70,9 @@ public void setUpHandler() throws IOException {
channel.setContext(context);
handler.handleRegistration(context);

NioServerSocketChannel serverChannel = new NioServerSocketChannel(mock(ServerSocketChannel.class));
ServerSocketChannel serverSocketChannel = mock(ServerSocketChannel.class);
when(serverSocketChannel.socket()).thenReturn(mock(ServerSocket.class));
NioServerSocketChannel serverChannel = new NioServerSocketChannel(serverSocketChannel);
serverContext = new DoNotRegisterServerContext(serverChannel, mock(NioSelector.class), mock(Consumer.class));
serverChannel.setContext(serverContext);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.nio.NioSocketChannel;

import java.io.IOException;
import java.nio.channels.SocketChannel;

public class NioHttpChannel extends NioSocketChannel implements HttpChannel {

NioHttpChannel(SocketChannel socketChannel) throws IOException {
NioHttpChannel(SocketChannel socketChannel) {
super(socketChannel);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class NioTcpChannel extends NioSocketChannel implements TcpChannel {

private final String profile;

public NioTcpChannel(String profile, SocketChannel socketChannel) throws IOException {
public NioTcpChannel(String profile, SocketChannel socketChannel) {
super(socketChannel);
this.profile = profile;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.transport.TcpServerChannel;

import java.io.IOException;
import java.nio.channels.ServerSocketChannel;

/**
Expand All @@ -34,12 +33,11 @@ public class NioTcpServerChannel extends NioServerSocketChannel implements TcpSe

private final String profile;

public NioTcpServerChannel(String profile, ServerSocketChannel socketChannel) throws IOException {
public NioTcpServerChannel(String profile, ServerSocketChannel socketChannel) {
super(socketChannel);
this.profile = profile;
}

@Override
public void close() {
getContext().closeChannel();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ public MockSocketChannel createChannel(NioSelector selector, SocketChannel chann

@Override
public MockServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException {
MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel, this, selector);
MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel);
Consumer<Exception> exceptionHandler = (e) -> logger.error(() ->
new ParameterizedMessage("exception from server channel caught on transport layer [{}]", channel), e);
ServerChannelContext context = new ServerChannelContext(nioServerChannel, this, selector, MockNioTransport.this::acceptChannel,
Expand Down Expand Up @@ -196,8 +196,7 @@ private static class MockServerChannel extends NioServerSocketChannel implements

private final String profile;

MockServerChannel(String profile, ServerSocketChannel channel, ChannelFactory<?, ?> channelFactory, NioSelector selector)
throws IOException {
MockServerChannel(String profile, ServerSocketChannel channel) {
super(channel);
this.profile = profile;
}
Expand All @@ -222,8 +221,7 @@ private static class MockSocketChannel extends NioSocketChannel implements TcpCh

private final String profile;

private MockSocketChannel(String profile, java.nio.channels.SocketChannel socketChannel, NioSelector selector)
throws IOException {
private MockSocketChannel(String profile, java.nio.channels.SocketChannel socketChannel, NioSelector selector) {
super(socketChannel);
this.profile = profile;
}
Expand Down