Skip to content

Commit

Permalink
Abstracting outbound side of transport (#13293)
Browse files Browse the repository at this point in the history
* Abstracting outbound side of transport

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Making outbound handler protocol dependent via inbound handler

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Fixing precommit

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Addressing comments

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Fixing precommit

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Cleaning up code

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Addressing comments

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Cleaning up

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Addressing comments

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Abstracting InboundHandlerTests

Signed-off-by: Vacha Shah <vachshah@amazon.com>

* Abstracting TransportLoggerTests

Signed-off-by: Vacha Shah <vachshah@amazon.com>

---------

Signed-off-by: Vacha Shah <vachshah@amazon.com>
  • Loading branch information
VachaShah authored May 13, 2024
1 parent 079cef5 commit 14f1c43
Show file tree
Hide file tree
Showing 28 changed files with 1,061 additions and 645 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x]
### Added
- Add useCompoundFile index setting ([#13478](https://github.com/opensearch-project/OpenSearch/pull/13478))
- Make outbound side of transport protocol dependent ([#13293](https://github.com/opensearch-project/OpenSearch/pull/13293))

### Dependencies
- Bump `com.github.spullara.mustache.java:compiler` from 0.9.10 to 0.9.13 ([#13329](https://github.com/opensearch-project/OpenSearch/pull/13329), [#13559](https://github.com/opensearch-project/OpenSearch/pull/13559))
Expand Down
13 changes: 13 additions & 0 deletions server/src/main/java/org/opensearch/transport/InboundHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@

package org.opensearch.transport;

import org.opensearch.Version;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.BigArrays;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.telemetry.tracing.Tracer;
import org.opensearch.threadpool.ThreadPool;
Expand All @@ -57,7 +59,12 @@ public class InboundHandler {
private final Map<String, ProtocolMessageHandler> protocolMessageHandlers;

InboundHandler(
String nodeName,
Version version,
String[] features,
StatsTracker statsTracker,
ThreadPool threadPool,
BigArrays bigArrays,
OutboundHandler outboundHandler,
NamedWriteableRegistry namedWriteableRegistry,
TransportHandshaker handshaker,
Expand All @@ -70,7 +77,12 @@ public class InboundHandler {
this.protocolMessageHandlers = Map.of(
NativeInboundMessage.NATIVE_PROTOCOL,
new NativeMessageHandler(
nodeName,
version,
features,
statsTracker,
threadPool,
bigArrays,
outboundHandler,
namedWriteableRegistry,
handshaker,
Expand All @@ -83,6 +95,7 @@ public class InboundHandler {
}

void setMessageListener(TransportMessageListener listener) {
protocolMessageHandlers.values().forEach(handler -> handler.setMessageListener(listener));
if (messageListener == TransportMessageListener.NOOP_LISTENER) {
messageListener = listener;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.util.BytesRef;
import org.opensearch.Version;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.common.io.stream.ByteBufferStreamInput;
Expand All @@ -52,6 +53,7 @@
import org.opensearch.telemetry.tracing.channels.TraceableTcpTransportChannel;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.nativeprotocol.NativeInboundMessage;
import org.opensearch.transport.nativeprotocol.NativeOutboundHandler;

import java.io.EOFException;
import java.io.IOException;
Expand All @@ -72,7 +74,7 @@ public class NativeMessageHandler implements ProtocolMessageHandler {
private static final Logger logger = LogManager.getLogger(NativeMessageHandler.class);

private final ThreadPool threadPool;
private final OutboundHandler outboundHandler;
private final NativeOutboundHandler outboundHandler;
private final NamedWriteableRegistry namedWriteableRegistry;
private final TransportHandshaker handshaker;
private final TransportKeepAlive keepAlive;
Expand All @@ -82,7 +84,12 @@ public class NativeMessageHandler implements ProtocolMessageHandler {
private final Tracer tracer;

NativeMessageHandler(
String nodeName,
Version version,
String[] features,
StatsTracker statsTracker,
ThreadPool threadPool,
BigArrays bigArrays,
OutboundHandler outboundHandler,
NamedWriteableRegistry namedWriteableRegistry,
TransportHandshaker handshaker,
Expand All @@ -92,7 +99,7 @@ public class NativeMessageHandler implements ProtocolMessageHandler {
TransportKeepAlive keepAlive
) {
this.threadPool = threadPool;
this.outboundHandler = outboundHandler;
this.outboundHandler = new NativeOutboundHandler(nodeName, version, features, statsTracker, threadPool, bigArrays, outboundHandler);
this.namedWriteableRegistry = namedWriteableRegistry;
this.handshaker = handshaker;
this.requestHandlers = requestHandlers;
Expand Down Expand Up @@ -492,4 +499,9 @@ public void onFailure(Exception e) {
}
}

@Override
public void setMessageListener(TransportMessageListener listener) {
outboundHandler.setMessageListener(listener);
}

}
171 changes: 13 additions & 158 deletions server/src/main/java/org/opensearch/transport/OutboundHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,164 +35,47 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.CheckedSupplier;
import org.opensearch.common.io.stream.ReleasableBytesStreamOutput;
import org.opensearch.common.lease.Releasable;
import org.opensearch.common.lease.Releasables;
import org.opensearch.common.network.CloseableChannel;
import org.opensearch.common.transport.NetworkExceptionHelper;
import org.opensearch.common.util.BigArrays;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.util.io.IOUtils;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.action.NotifyOnceListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.transport.TransportAddress;
import org.opensearch.core.transport.TransportResponse;
import org.opensearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.Set;

/**
* Outbound data handler
*
* @opensearch.internal
*/
final class OutboundHandler {
public final class OutboundHandler {

private static final Logger logger = LogManager.getLogger(OutboundHandler.class);

private final String nodeName;
private final Version version;
private final String[] features;
private final StatsTracker statsTracker;
private final ThreadPool threadPool;
private final BigArrays bigArrays;
private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;

OutboundHandler(
String nodeName,
Version version,
String[] features,
StatsTracker statsTracker,
ThreadPool threadPool,
BigArrays bigArrays
) {
this.nodeName = nodeName;
this.version = version;
this.features = features;
public OutboundHandler(StatsTracker statsTracker, ThreadPool threadPool) {
this.statsTracker = statsTracker;
this.threadPool = threadPool;
this.bigArrays = bigArrays;
}

void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener<Void> listener) {
SendContext sendContext = new SendContext(channel, () -> bytes, listener);
SendContext sendContext = new SendContext(statsTracker, channel, () -> bytes, listener);
try {
internalSend(channel, sendContext);
sendBytes(channel, sendContext);
} catch (IOException e) {
// This should not happen as the bytes are already serialized
throw new AssertionError(e);
}
}

/**
* Sends the request to the given channel. This method should be used to send {@link TransportRequest}
* objects back to the caller.
*/
void sendRequest(
final DiscoveryNode node,
final TcpChannel channel,
final long requestId,
final String action,
final TransportRequest request,
final TransportRequestOptions options,
final Version channelVersion,
final boolean compressRequest,
final boolean isHandshake
) throws IOException, TransportException {
Version version = Version.min(this.version, channelVersion);
OutboundMessage.Request message = new OutboundMessage.Request(
threadPool.getThreadContext(),
features,
request,
version,
action,
requestId,
isHandshake,
compressRequest
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onRequestSent(node, requestId, action, request, options));
sendMessage(channel, message, listener);
}

/**
* Sends the response to the given channel. This method should be used to send {@link TransportResponse}
* objects back to the caller.
*
* @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses
*/
void sendResponse(
final Version nodeVersion,
final Set<String> features,
final TcpChannel channel,
final long requestId,
final String action,
final TransportResponse response,
final boolean compress,
final boolean isHandshake
) throws IOException {
Version version = Version.min(this.version, nodeVersion);
OutboundMessage.Response message = new OutboundMessage.Response(
threadPool.getThreadContext(),
features,
response,
version,
requestId,
isHandshake,
compress
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
sendMessage(channel, message, listener);
}

/**
* Sends back an error response to the caller via the given channel
*/
void sendErrorResponse(
final Version nodeVersion,
final Set<String> features,
final TcpChannel channel,
final long requestId,
final String action,
final Exception error
) throws IOException {
Version version = Version.min(this.version, nodeVersion);
TransportAddress address = new TransportAddress(channel.getLocalAddress());
RemoteTransportException tx = new RemoteTransportException(nodeName, address, action, error);
OutboundMessage.Response message = new OutboundMessage.Response(
threadPool.getThreadContext(),
features,
tx,
version,
requestId,
false,
false
);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, error));
sendMessage(channel, message, listener);
}

private void sendMessage(TcpChannel channel, OutboundMessage networkMessage, ActionListener<Void> listener) throws IOException {
MessageSerializer serializer = new MessageSerializer(networkMessage, bigArrays);
SendContext sendContext = new SendContext(channel, serializer, listener, serializer);
internalSend(channel, sendContext);
}

private void internalSend(TcpChannel channel, SendContext sendContext) throws IOException {
public void sendBytes(TcpChannel channel, SendContext sendContext) throws IOException {
channel.getChannelStats().markAccessed(threadPool.relativeTimeInMillis());
BytesReference reference = sendContext.get();
// stash thread context so that channel event loop is not polluted by thread context
Expand All @@ -205,59 +88,30 @@ private void internalSend(TcpChannel channel, SendContext sendContext) throws IO
}
}

void setMessageListener(TransportMessageListener listener) {
if (messageListener == TransportMessageListener.NOOP_LISTENER) {
messageListener = listener;
} else {
throw new IllegalStateException("Cannot set message listener twice");
}
}

/**
* Internal message serializer
*
* @opensearch.internal
*/
private static class MessageSerializer implements CheckedSupplier<BytesReference, IOException>, Releasable {

private final OutboundMessage message;
private final BigArrays bigArrays;
private volatile ReleasableBytesStreamOutput bytesStreamOutput;

private MessageSerializer(OutboundMessage message, BigArrays bigArrays) {
this.message = message;
this.bigArrays = bigArrays;
}

@Override
public BytesReference get() throws IOException {
bytesStreamOutput = new ReleasableBytesStreamOutput(bigArrays);
return message.serialize(bytesStreamOutput);
}

@Override
public void close() {
IOUtils.closeWhileHandlingException(bytesStreamOutput);
}
}

private class SendContext extends NotifyOnceListener<Void> implements CheckedSupplier<BytesReference, IOException> {

public static class SendContext extends NotifyOnceListener<Void> implements CheckedSupplier<BytesReference, IOException> {
private final StatsTracker statsTracker;
private final TcpChannel channel;
private final CheckedSupplier<BytesReference, IOException> messageSupplier;
private final ActionListener<Void> listener;
private final Releasable optionalReleasable;
private long messageSize = -1;

private SendContext(
SendContext(
StatsTracker statsTracker,
TcpChannel channel,
CheckedSupplier<BytesReference, IOException> messageSupplier,
ActionListener<Void> listener
) {
this(channel, messageSupplier, listener, null);
this(statsTracker, channel, messageSupplier, listener, null);
}

private SendContext(
public SendContext(
StatsTracker statsTracker,
TcpChannel channel,
CheckedSupplier<BytesReference, IOException> messageSupplier,
ActionListener<Void> listener,
Expand All @@ -267,6 +121,7 @@ private SendContext(
this.messageSupplier = messageSupplier;
this.listener = listener;
this.optionalReleasable = optionalReleasable;
this.statsTracker = statsTracker;
}

public BytesReference get() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,25 @@
*/
public interface ProtocolMessageHandler {

/**
* Handles the message received on the channel.
* @param channel the channel on which the message was received
* @param message the message received
* @param startTime the start time
* @param slowLogThresholdMs the threshold for slow logs
* @param messageListener the message listener
*/
public void messageReceived(
TcpChannel channel,
ProtocolInboundMessage message,
long startTime,
long slowLogThresholdMs,
TransportMessageListener messageListener
) throws IOException;

/**
* Sets the message listener to be used by the handler.
* @param listener the message listener
*/
public void setMessageListener(TransportMessageListener listener);
}
Loading

0 comments on commit 14f1c43

Please sign in to comment.