Skip to content

Commit

Permalink
Ensure local addresses aren't null (#31440)
Browse files Browse the repository at this point in the history
Currently we set local addresses on the creation time of a NioChannel.
However, this may return null as the local address may not have been
set yet. An example is the local address has not been set on a client
channel as the connection process is not yet complete.

This PR modifies the getter to set the local field if it is currently null.
  • Loading branch information
Tim-Brooks authored Jun 21, 2018
1 parent 00283a6 commit 86423f9
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.Closeable;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
Expand Down Expand Up @@ -99,6 +100,11 @@ private Socket internalCreateChannel(NioSelector selector, SocketChannel rawChan
Socket channel = createChannel(selector, rawChannel);
assert channel.getContext() != null : "channel context should have been set on channel";
return channel;
} catch (UncheckedIOException e) {
// This can happen if getRemoteAddress throws IOException.
IOException cause = e.getCause();
closeRawChannel(rawChannel, cause);
throw cause;
} catch (Exception e) {
closeRawChannel(rawChannel, e);
throw e;
Expand Down
13 changes: 2 additions & 11 deletions libs/nio/src/main/java/org/elasticsearch/nio/NioChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

package org.elasticsearch.nio;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.NetworkChannel;
import java.util.function.BiConsumer;
Expand All @@ -32,20 +31,10 @@
*/
public abstract class NioChannel {

private final InetSocketAddress localAddress;

NioChannel(NetworkChannel socketChannel) throws IOException {
this.localAddress = (InetSocketAddress) socketChannel.getLocalAddress();
}

public boolean isOpen() {
return getContext().isOpen();
}

public InetSocketAddress getLocalAddress() {
return localAddress;
}

/**
* Adds a close listener to the channel. Multiple close listeners can be added. There is no guarantee
* about the order in which close listeners will be executed. If the channel is already closed, the
Expand All @@ -64,6 +53,8 @@ public void close() {
getContext().closeChannel();
}

public abstract InetSocketAddress getLocalAddress();

public abstract NetworkChannel getRawChannel();

public abstract ChannelContext<?> getContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@

package org.elasticsearch.nio;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.channels.ServerSocketChannel;
import java.util.concurrent.atomic.AtomicBoolean;

public class NioServerSocketChannel extends NioChannel {

private final ServerSocketChannel socketChannel;
private final ServerSocketChannel serverSocketChannel;
private final AtomicBoolean contextSet = new AtomicBoolean(false);
private volatile InetSocketAddress localAddress;
private ServerChannelContext context;

public NioServerSocketChannel(ServerSocketChannel socketChannel) throws IOException {
super(socketChannel);
this.socketChannel = socketChannel;
public NioServerSocketChannel(ServerSocketChannel serverSocketChannel) {
this.serverSocketChannel = serverSocketChannel;
attemptToSetLocalAddress();
}

/**
Expand All @@ -48,9 +49,15 @@ public void setContext(ServerChannelContext context) {
}
}

@Override
public InetSocketAddress getLocalAddress() {
attemptToSetLocalAddress();
return localAddress;
}

@Override
public ServerSocketChannel getRawChannel() {
return socketChannel;
return serverSocketChannel;
}

@Override
Expand All @@ -64,4 +71,10 @@ public String toString() {
"localAddress=" + getLocalAddress() +
'}';
}

private void attemptToSetLocalAddress() {
if (localAddress == null) {
localAddress = (InetSocketAddress) serverSocketChannel.socket().getLocalSocketAddress();
}
}
}
21 changes: 17 additions & 4 deletions libs/nio/src/main/java/org/elasticsearch/nio/NioSocketChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,27 @@
package org.elasticsearch.nio;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.nio.channels.SocketChannel;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.BiConsumer;

public class NioSocketChannel extends NioChannel {

private final InetSocketAddress remoteAddress;
private final AtomicBoolean contextSet = new AtomicBoolean(false);
private final SocketChannel socketChannel;
private final InetSocketAddress remoteAddress;
private volatile InetSocketAddress localAddress;
private SocketChannelContext context;

public NioSocketChannel(SocketChannel socketChannel) throws IOException {
super(socketChannel);
public NioSocketChannel(SocketChannel socketChannel) {
this.socketChannel = socketChannel;
this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress();
try {
this.remoteAddress = (InetSocketAddress) socketChannel.getRemoteAddress();
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

public void setContext(SocketChannelContext context) {
Expand All @@ -46,6 +51,14 @@ public void setContext(SocketChannelContext context) {
}
}

@Override
public InetSocketAddress getLocalAddress() {
if (localAddress == null) {
localAddress = (InetSocketAddress) socketChannel.socket().getLocalSocketAddress();
}
return localAddress;
}

@Override
public SocketChannel getRawChannel() {
return socketChannel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.junit.Before;

import java.io.IOException;
import java.net.ServerSocket;
import java.nio.channels.CancelledKeyException;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
Expand Down Expand Up @@ -69,7 +70,9 @@ public void setUpHandler() throws IOException {
channel.setContext(context);
handler.handleRegistration(context);

NioServerSocketChannel serverChannel = new NioServerSocketChannel(mock(ServerSocketChannel.class));
ServerSocketChannel serverSocketChannel = mock(ServerSocketChannel.class);
when(serverSocketChannel.socket()).thenReturn(mock(ServerSocket.class));
NioServerSocketChannel serverChannel = new NioServerSocketChannel(serverSocketChannel);
serverContext = new DoNotRegisterServerContext(serverChannel, mock(NioSelector.class), mock(Consumer.class));
serverChannel.setContext(serverContext);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,11 @@
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.nio.NioSocketChannel;

import java.io.IOException;
import java.nio.channels.SocketChannel;

public class NioHttpChannel extends NioSocketChannel implements HttpChannel {

NioHttpChannel(SocketChannel socketChannel) throws IOException {
NioHttpChannel(SocketChannel socketChannel) {
super(socketChannel);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class NioTcpChannel extends NioSocketChannel implements TcpChannel {

private final String profile;

public NioTcpChannel(String profile, SocketChannel socketChannel) throws IOException {
public NioTcpChannel(String profile, SocketChannel socketChannel) {
super(socketChannel);
this.profile = profile;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.elasticsearch.nio.NioServerSocketChannel;
import org.elasticsearch.transport.TcpServerChannel;

import java.io.IOException;
import java.nio.channels.ServerSocketChannel;

/**
Expand All @@ -34,12 +33,11 @@ public class NioTcpServerChannel extends NioServerSocketChannel implements TcpSe

private final String profile;

public NioTcpServerChannel(String profile, ServerSocketChannel socketChannel) throws IOException {
public NioTcpServerChannel(String profile, ServerSocketChannel socketChannel) {
super(socketChannel);
this.profile = profile;
}

@Override
public void close() {
getContext().closeChannel();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ public MockSocketChannel createChannel(NioSelector selector, SocketChannel chann

@Override
public MockServerChannel createServerChannel(NioSelector selector, ServerSocketChannel channel) throws IOException {
MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel, this, selector);
MockServerChannel nioServerChannel = new MockServerChannel(profileName, channel);
Consumer<Exception> exceptionHandler = (e) -> logger.error(() ->
new ParameterizedMessage("exception from server channel caught on transport layer [{}]", channel), e);
ServerChannelContext context = new ServerChannelContext(nioServerChannel, this, selector, MockNioTransport.this::acceptChannel,
Expand Down Expand Up @@ -196,8 +196,7 @@ private static class MockServerChannel extends NioServerSocketChannel implements

private final String profile;

MockServerChannel(String profile, ServerSocketChannel channel, ChannelFactory<?, ?> channelFactory, NioSelector selector)
throws IOException {
MockServerChannel(String profile, ServerSocketChannel channel) {
super(channel);
this.profile = profile;
}
Expand All @@ -222,8 +221,7 @@ private static class MockSocketChannel extends NioSocketChannel implements TcpCh

private final String profile;

private MockSocketChannel(String profile, java.nio.channels.SocketChannel socketChannel, NioSelector selector)
throws IOException {
private MockSocketChannel(String profile, java.nio.channels.SocketChannel socketChannel, NioSelector selector) {
super(socketChannel);
this.profile = profile;
}
Expand Down

0 comments on commit 86423f9

Please sign in to comment.