From fcf1e41e429b10e03c5cf9b8551636df7519b4c5 Mon Sep 17 00:00:00 2001 From: Tim Brooks Date: Thu, 14 Jun 2018 15:10:02 -0600 Subject: [PATCH] Extract common http logic to server (#31311) This is related to #28898. With the addition of the http nio transport, we now have two different modules that provide http transports. Currently most of the http logic lives at the module level. However, some of this logic can live in server. In particular, some of the setting of headers, cors, and pipelining. This commit begins this moving in that direction by introducing lower level abstraction (HttpChannel, HttpRequest, and HttpResonse) that is implemented by the modules. The higher level rest request and rest channel work can live entirely in server. --- .../http/netty4/Netty4HttpChannel.java | 258 +------- .../netty4/Netty4HttpPipeliningHandler.java | 2 +- .../http/netty4/Netty4HttpRequest.java | 154 ++--- .../http/netty4/Netty4HttpRequestHandler.java | 129 +--- .../http/netty4/Netty4HttpResponse.java | 100 ++- .../netty4/Netty4HttpServerTransport.java | 54 +- .../http/netty4/cors/Netty4CorsHandler.java | 10 + .../transport/netty4/Netty4Transport.java | 2 +- .../transport/netty4/NettyTcpChannel.java | 7 +- .../http/netty4/Netty4CorsTests.java | 148 +++++ .../http/netty4/Netty4HttpChannelTests.java | 616 ------------------ .../Netty4HttpPipeliningHandlerTests.java | 26 +- .../Netty4HttpServerPipeliningTests.java | 19 +- .../Netty4HttpServerTransportTests.java | 34 - .../http/nio/HttpReadWriteHandler.java | 121 +--- .../http/nio/NioHttpChannel.java | 243 +------ .../http/nio/NioHttpPipeliningHandler.java | 2 +- .../http/nio/NioHttpRequest.java | 105 ++- .../http/nio/NioHttpResponse.java | 97 ++- .../http/nio/NioHttpServerTransport.java | 27 +- .../http/nio/cors/NioCorsHandler.java | 10 + .../http/nio/HttpReadWriteHandlerTests.java | 225 +++++-- .../http/nio/NioHttpChannelTests.java | 349 ---------- .../nio/NioHttpPipeliningHandlerTests.java | 26 +- .../http/nio/NioHttpServerTransportTests.java | 34 - .../http/AbstractHttpServerTransport.java | 103 ++- .../http/DefaultRestChannel.java | 172 +++++ .../org/elasticsearch/http/HttpChannel.java | 58 ++ .../http/HttpPipelinedMessage.java | 21 +- .../http/HttpPipelinedRequest.java | 10 +- .../org/elasticsearch/http/HttpRequest.java | 65 ++ .../org/elasticsearch/http/HttpResponse.java | 32 + .../rest/AbstractRestChannel.java | 2 +- .../elasticsearch/rest/RestController.java | 5 +- .../org/elasticsearch/rest/RestRequest.java | 102 +-- .../org/elasticsearch/rest/RestResponse.java | 14 +- .../AbstractHttpServerTransportTests.java | 93 +++ .../http/DefaultRestChannelTests.java | 444 +++++++++++++ .../rest/BytesRestResponseTests.java | 24 +- .../rest/RestControllerTests.java | 96 +-- .../elasticsearch/rest/RestRequestTests.java | 113 ++-- .../test/rest/FakeRestRequest.java | 141 +++- .../core/security/rest/RestRequestFilter.java | 26 +- .../security/audit/index/IndexAuditTrail.java | 15 +- .../audit/logfile/LoggingAuditTrail.java | 10 +- .../xpack/security/rest/RemoteHostHeader.java | 2 +- .../security/rest/SecurityRestFilter.java | 11 +- .../SecurityNetty4HttpServerTransport.java | 2 +- .../audit/index/IndexAuditTrailTests.java | 5 +- .../security/rest/RestRequestFilterTests.java | 2 +- .../rest/SecurityRestFilterTests.java | 2 + 51 files changed, 2101 insertions(+), 2267 deletions(-) create mode 100644 modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java delete mode 100644 modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java delete mode 100644 plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java create mode 100644 server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java create mode 100644 server/src/main/java/org/elasticsearch/http/HttpChannel.java create mode 100644 server/src/main/java/org/elasticsearch/http/HttpRequest.java create mode 100644 server/src/main/java/org/elasticsearch/http/HttpResponse.java create mode 100644 server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java index cb31d44454452..473985d21091b 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpChannel.java @@ -19,252 +19,58 @@ package org.elasticsearch.http.netty4; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; import io.netty.channel.Channel; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelPromise; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.handler.codec.http.cookie.ServerCookieDecoder; -import io.netty.handler.codec.http.cookie.ServerCookieEncoder; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; -import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.net.InetSocketAddress; -final class Netty4HttpChannel extends AbstractRestChannel { +public class Netty4HttpChannel implements HttpChannel { - private final Netty4HttpServerTransport transport; private final Channel channel; - private final FullHttpRequest nettyRequest; - private final int sequence; - private final ThreadContext threadContext; - private final HttpHandlingSettings handlingSettings; - /** - * @param transport The corresponding NettyHttpServerTransport where this channel belongs to. - * @param request The request that is handled by this channel. - * @param sequence The pipelining sequence number for this request - * @param handlingSettings true if error messages should include stack traces. - * @param threadContext the thread context for the channel - */ - Netty4HttpChannel(Netty4HttpServerTransport transport, Netty4HttpRequest request, int sequence, HttpHandlingSettings handlingSettings, - ThreadContext threadContext) { - super(request, handlingSettings.getDetailedErrorsEnabled()); - this.transport = transport; - this.channel = request.getChannel(); - this.nettyRequest = request.request(); - this.sequence = sequence; - this.threadContext = threadContext; - this.handlingSettings = handlingSettings; + Netty4HttpChannel(Channel channel) { + this.channel = channel; } @Override - protected BytesStreamOutput newBytesOutput() { - return new ReleasableBytesStreamOutput(transport.bigArrays); - } - - @Override - public void sendResponse(RestResponse response) { - // if the response object was created upstream, then use it; - // otherwise, create a new one - ByteBuf buffer = Netty4Utils.toByteBuf(response.content()); - final FullHttpResponse resp; - if (HttpMethod.HEAD.equals(nettyRequest.method())) { - resp = newResponse(Unpooled.EMPTY_BUFFER); - } else { - resp = newResponse(buffer); - } - resp.setStatus(getStatus(response.status())); - - Netty4CorsHandler.setCorsResponseHeaders(nettyRequest, resp, transport.getCorsConfig()); - - String opaque = nettyRequest.headers().get("X-Opaque-Id"); - if (opaque != null) { - setHeaderField(resp, "X-Opaque-Id", opaque); - } - - // Add all custom headers - addCustomHeaders(resp, response.getHeaders()); - addCustomHeaders(resp, threadContext.getResponseHeaders()); - - BytesReference content = response.content(); - boolean releaseContent = content instanceof Releasable; - boolean releaseBytesStreamOutput = bytesOutputOrNull() instanceof ReleasableBytesStreamOutput; - try { - // If our response doesn't specify a content-type header, set one - setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false); - // If our response has no content-length, calculate and set one - setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false); - - addCookies(resp); - - final ChannelPromise promise = channel.newPromise(); - - if (releaseContent) { - promise.addListener(f -> ((Releasable) content).close()); - } - - if (releaseBytesStreamOutput) { - promise.addListener(f -> bytesOutputOrNull().close()); - } - - if (isCloseConnection()) { - promise.addListener(ChannelFutureListener.CLOSE); - } - - Netty4HttpResponse newResponse = new Netty4HttpResponse(sequence, resp); - - channel.writeAndFlush(newResponse, promise); - releaseContent = false; - releaseBytesStreamOutput = false; - } finally { - if (releaseContent) { - ((Releasable) content).close(); - } - if (releaseBytesStreamOutput) { - bytesOutputOrNull().close(); - } - } - } - - private void setHeaderField(HttpResponse resp, String headerField, String value) { - setHeaderField(resp, headerField, value, true); - } - - private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) { - if (override || !resp.headers().contains(headerField)) { - resp.headers().add(headerField, value); - } - } - - private void addCookies(HttpResponse resp) { - if (handlingSettings.isResetCookies()) { - String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE); - if (cookieString != null) { - Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); - if (!cookies.isEmpty()) { - // Reset the cookies if necessary. - resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies)); - } - } - } - } - - private void addCustomHeaders(HttpResponse response, Map> customHeaders) { - if (customHeaders != null) { - for (Map.Entry> headerEntry : customHeaders.entrySet()) { - for (String headerValue : headerEntry.getValue()) { - setHeaderField(response, headerEntry.getKey(), headerValue); + public void sendResponse(HttpResponse response, ActionListener listener) { + ChannelPromise writePromise = channel.newPromise(); + writePromise.addListener(f -> { + if (f.isSuccess()) { + listener.onResponse(null); + } else { + final Throwable cause = f.cause(); + Netty4Utils.maybeDie(cause); + if (cause instanceof Error) { + listener.onFailure(new Exception(cause)); + } else { + listener.onFailure((Exception) cause); } } - } + }); + channel.writeAndFlush(response, writePromise); } - // Determine if the request protocol version is HTTP 1.0 - private boolean isHttp10() { - return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0); - } - - // Determine if the request connection should be closed on completion. - private boolean isCloseConnection() { - final boolean http10 = isHttp10(); - return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) || - (http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION))); + @Override + public InetSocketAddress getLocalAddress() { + return (InetSocketAddress) channel.localAddress(); } - // Create a new {@link HttpResponse} to transmit the response for the netty request. - private FullHttpResponse newResponse(ByteBuf buffer) { - final boolean http10 = isHttp10(); - final boolean close = isCloseConnection(); - // Build the response object. - final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize - final FullHttpResponse response; - if (http10) { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer); - if (!close) { - response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive"); - } - } else { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer); - } - return response; + @Override + public InetSocketAddress getRemoteAddress() { + return (InetSocketAddress) channel.remoteAddress(); } - private static Map MAP; - - static { - EnumMap map = new EnumMap<>(RestStatus.class); - map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); - map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); - map.put(RestStatus.OK, HttpResponseStatus.OK); - map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); - map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); - map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); - map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); - map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); - map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); - map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? - map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); - map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); - map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); - map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); - map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); - map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); - map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); - map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); - map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); - map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); - map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); - map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); - map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); - map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); - map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); - map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); - map.put(RestStatus.GONE, HttpResponseStatus.GONE); - map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); - map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); - map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); - map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); - map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); - map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); - map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); - map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); - map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); - map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); - map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); - map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); - map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); - map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); - MAP = Collections.unmodifiableMap(map); + @Override + public void close() { + channel.close(); } - private static HttpResponseStatus getStatus(RestStatus status) { - return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + public Channel getNettyChannel() { + return channel; } } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java index 12c2e9a685778..e6436ccea1a93 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandler.java @@ -66,7 +66,7 @@ public void write(final ChannelHandlerContext ctx, final Object msg, final Chann try { List> readyResponses = aggregator.write(response, promise); for (Tuple readyResponse : readyResponses) { - ctx.write(readyResponse.v1().getResponse(), readyResponse.v2()); + ctx.write(readyResponse.v1(), readyResponse.v2()); } success = true; } catch (IllegalStateException e) { diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java index 2ce6ffada67f0..ffabe5cbbe224 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequest.java @@ -19,17 +19,22 @@ package org.elasticsearch.http.netty4; -import io.netty.channel.Channel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.codec.http.cookie.ServerCookieDecoder; +import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpRequest; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.net.SocketAddress; import java.util.AbstractMap; import java.util.Collection; import java.util.Collections; @@ -38,53 +43,16 @@ import java.util.Set; import java.util.stream.Collectors; -public class Netty4HttpRequest extends RestRequest { - +public class Netty4HttpRequest implements HttpRequest { private final FullHttpRequest request; - private final Channel channel; private final BytesReference content; + private final HttpHeadersMap headers; + private final int sequence; - /** - * Construct a new request. - * - * @param xContentRegistry the content registry - * @param request the underlying request - * @param channel the channel for the request - * @throws BadParameterException if the parameters can not be decoded - * @throws ContentTypeHeaderException if the Content-Type header can not be parsed - */ - Netty4HttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request, Channel channel) { - super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers())); - this.request = request; - this.channel = channel; - if (request.content().isReadable()) { - this.content = Netty4Utils.toBytesReference(request.content()); - } else { - this.content = BytesArray.EMPTY; - } - } - - /** - * Construct a new request. In contrast to - * {@link Netty4HttpRequest#Netty4HttpRequest(NamedXContentRegistry, Map, String, FullHttpRequest, Channel)}, the URI is not decoded so - * this constructor will not throw a {@link BadParameterException}. - * - * @param xContentRegistry the content registry - * @param params the parameters for the request - * @param uri the path for the request - * @param request the underlying request - * @param channel the channel for the request - * @throws ContentTypeHeaderException if the Content-Type header can not be parsed - */ - Netty4HttpRequest( - final NamedXContentRegistry xContentRegistry, - final Map params, - final String uri, - final FullHttpRequest request, - final Channel channel) { - super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers())); + Netty4HttpRequest(FullHttpRequest request, int sequence) { this.request = request; - this.channel = channel; + headers = new HttpHeadersMap(request.headers()); + this.sequence = sequence; if (request.content().isReadable()) { this.content = Netty4Utils.toBytesReference(request.content()); } else { @@ -92,43 +60,39 @@ public class Netty4HttpRequest extends RestRequest { } } - public FullHttpRequest request() { - return this.request; - } - @Override - public Method method() { + public RestRequest.Method method() { HttpMethod httpMethod = request.method(); if (httpMethod == HttpMethod.GET) - return Method.GET; + return RestRequest.Method.GET; if (httpMethod == HttpMethod.POST) - return Method.POST; + return RestRequest.Method.POST; if (httpMethod == HttpMethod.PUT) - return Method.PUT; + return RestRequest.Method.PUT; if (httpMethod == HttpMethod.DELETE) - return Method.DELETE; + return RestRequest.Method.DELETE; if (httpMethod == HttpMethod.HEAD) { - return Method.HEAD; + return RestRequest.Method.HEAD; } if (httpMethod == HttpMethod.OPTIONS) { - return Method.OPTIONS; + return RestRequest.Method.OPTIONS; } if (httpMethod == HttpMethod.PATCH) { - return Method.PATCH; + return RestRequest.Method.PATCH; } if (httpMethod == HttpMethod.TRACE) { - return Method.TRACE; + return RestRequest.Method.TRACE; } if (httpMethod == HttpMethod.CONNECT) { - return Method.CONNECT; + return RestRequest.Method.CONNECT; } throw new IllegalArgumentException("Unexpected http method: " + httpMethod); @@ -140,39 +104,63 @@ public String uri() { } @Override - public boolean hasContent() { - return content.length() > 0; + public BytesReference content() { + return content; + } + + + @Override + public final Map> getHeaders() { + return headers; } @Override - public BytesReference content() { - return content; + public List strictCookies() { + String cookieString = request.headers().get(HttpHeaderNames.COOKIE); + if (cookieString != null) { + Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); + if (!cookies.isEmpty()) { + return ServerCookieEncoder.STRICT.encode(cookies); + } + } + return Collections.emptyList(); } - /** - * Returns the remote address where this rest request channel is "connected to". The - * returned {@link SocketAddress} is supposed to be down-cast into more - * concrete type such as {@link java.net.InetSocketAddress} to retrieve - * the detailed information. - */ @Override - public SocketAddress getRemoteAddress() { - return channel.remoteAddress(); + public HttpVersion protocolVersion() { + if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) { + return HttpRequest.HttpVersion.HTTP_1_0; + } else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) { + return HttpRequest.HttpVersion.HTTP_1_1; + } else { + throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion()); + } } - /** - * Returns the local address where this request channel is bound to. The returned - * {@link SocketAddress} is supposed to be down-cast into more concrete - * type such as {@link java.net.InetSocketAddress} to retrieve the detailed - * information. - */ @Override - public SocketAddress getLocalAddress() { - return channel.localAddress(); + public HttpRequest removeHeader(String header) { + HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); + headersWithoutContentTypeHeader.add(request.headers()); + headersWithoutContentTypeHeader.remove(header); + HttpHeaders trailingHeaders = new DefaultHttpHeaders(); + trailingHeaders.add(request.trailingHeaders()); + trailingHeaders.remove(header); + FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(), + request.content(), headersWithoutContentTypeHeader, trailingHeaders); + return new Netty4HttpRequest(requestWithoutHeader, sequence); + } + + @Override + public Netty4HttpResponse createResponse(RestStatus status, BytesReference content) { + return new Netty4HttpResponse(this, status, content); + } + + public FullHttpRequest nettyRequest() { + return request; } - public Channel getChannel() { - return channel; + int sequence() { + return sequence; } /** @@ -249,7 +237,7 @@ public Collection> values() { @Override public Set>> entrySet() { return httpHeaders.names().stream().map(k -> new AbstractMap.SimpleImmutableEntry<>(k, httpHeaders.getAll(k))) - .collect(Collectors.toSet()); + .collect(Collectors.toSet()); } } } diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java index c3a010226a408..4547a63a9a278 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpRequestHandler.java @@ -20,112 +20,51 @@ package org.elasticsearch.http.netty4; import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.HttpHeaders; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.http.HttpPipelinedRequest; -import org.elasticsearch.rest.RestRequest; import org.elasticsearch.transport.netty4.Netty4Utils; -import java.util.Collections; - @ChannelHandler.Sharable class Netty4HttpRequestHandler extends SimpleChannelInboundHandler> { private final Netty4HttpServerTransport serverTransport; - private final HttpHandlingSettings handlingSettings; - private final ThreadContext threadContext; - Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport, HttpHandlingSettings handlingSettings, - ThreadContext threadContext) { + Netty4HttpRequestHandler(Netty4HttpServerTransport serverTransport) { this.serverTransport = serverTransport; - this.handlingSettings = handlingSettings; - this.threadContext = threadContext; } @Override protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest msg) throws Exception { - final FullHttpRequest request = msg.getRequest(); + Netty4HttpChannel channel = ctx.channel().attr(Netty4HttpServerTransport.HTTP_CHANNEL_KEY).get(); + FullHttpRequest request = msg.getRequest(); try { + final FullHttpRequest copiedRequest = + new DefaultFullHttpRequest( + request.protocolVersion(), + request.method(), + request.uri(), + Unpooled.copiedBuffer(request.content()), + request.headers(), + request.trailingHeaders()); - final FullHttpRequest copy = - new DefaultFullHttpRequest( - request.protocolVersion(), - request.method(), - request.uri(), - Unpooled.copiedBuffer(request.content()), - request.headers(), - request.trailingHeaders()); - - Exception badRequestCause = null; - - /* - * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there - * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we - * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, - * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the - * underlying exception that caused us to treat the request as bad. - */ - final Netty4HttpRequest httpRequest; - { - Netty4HttpRequest innerHttpRequest; - try { - innerHttpRequest = new Netty4HttpRequest(serverTransport.xContentRegistry, copy, ctx.channel()); - } catch (final RestRequest.ContentTypeHeaderException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutContentTypeHeader(copy, ctx.channel(), badRequestCause); - } catch (final RestRequest.BadParameterException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutParameters(copy, ctx.channel()); - } - httpRequest = innerHttpRequest; - } - - /* - * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid - * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an - * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these - * parameter values. - */ - final Netty4HttpChannel channel; - { - Netty4HttpChannel innerChannel; - try { - innerChannel = - new Netty4HttpChannel(serverTransport, httpRequest, msg.getSequence(), handlingSettings, threadContext); - } catch (final IllegalArgumentException e) { - if (badRequestCause == null) { - badRequestCause = e; - } else { - badRequestCause.addSuppressed(e); - } - final Netty4HttpRequest innerRequest = - new Netty4HttpRequest( - serverTransport.xContentRegistry, - Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters - copy.uri(), - copy, - ctx.channel()); - innerChannel = - new Netty4HttpChannel(serverTransport, innerRequest, msg.getSequence(), handlingSettings, threadContext); - } - channel = innerChannel; - } + Netty4HttpRequest httpRequest = new Netty4HttpRequest(copiedRequest, msg.getSequence()); if (request.decoderResult().isFailure()) { - serverTransport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); - } else if (badRequestCause != null) { - serverTransport.dispatchBadRequest(httpRequest, channel, badRequestCause); + Throwable cause = request.decoderResult().cause(); + if (cause instanceof Error) { + ExceptionsHelper.dieOnError(cause); + serverTransport.incomingRequestError(httpRequest, channel, new Exception(cause)); + } else { + serverTransport.incomingRequestError(httpRequest, channel, (Exception) cause); + } } else { - serverTransport.dispatchRequest(httpRequest, channel); + serverTransport.incomingRequest(httpRequest, channel); } } finally { // As we have copied the buffer, we can release the request @@ -133,32 +72,6 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpPipelinedRequest MAP; + + static { + EnumMap map = new EnumMap<>(RestStatus.class); + map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); + map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); + map.put(RestStatus.OK, HttpResponseStatus.OK); + map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); + map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); + map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); + map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); + map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); + map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); + map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? + map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); + map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); + map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); + map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); + map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); + map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); + map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); + map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); + map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); + map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); + map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); + map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); + map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); + map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); + map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); + map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); + map.put(RestStatus.GONE, HttpResponseStatus.GONE); + map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); + map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); + map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); + map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); + map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); + map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); + map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); + map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); + map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); + map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); + map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); + map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); + map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); + map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); + MAP = Collections.unmodifiableMap(map); + } + + private static HttpResponseStatus getStatus(RestStatus status) { + return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + } + } + diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java index 0e18232e01cc7..6bfd8168dbe47 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/Netty4HttpServerTransport.java @@ -39,6 +39,7 @@ import io.netty.handler.codec.http.HttpResponseEncoder; import io.netty.handler.timeout.ReadTimeoutException; import io.netty.handler.timeout.ReadTimeoutHandler; +import io.netty.util.AttributeKey; import org.apache.logging.log4j.message.ParameterizedMessage; import org.apache.logging.log4j.util.Supplier; import org.elasticsearch.common.Strings; @@ -53,9 +54,7 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.http.AbstractHttpServerTransport; import org.elasticsearch.http.BindHttpException; import org.elasticsearch.http.HttpHandlingSettings; @@ -149,38 +148,29 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { public static final Setting SETTING_HTTP_NETTY_RECEIVE_PREDICTOR_SIZE = Setting.byteSizeSetting("http.netty.receive_predictor_size", new ByteSizeValue(64, ByteSizeUnit.KB), Property.NodeScope); - protected final BigArrays bigArrays; + private final ByteSizeValue maxInitialLineLength; + private final ByteSizeValue maxHeaderSize; + private final ByteSizeValue maxChunkSize; - protected final ByteSizeValue maxInitialLineLength; - protected final ByteSizeValue maxHeaderSize; - protected final ByteSizeValue maxChunkSize; + private final int workerCount; - protected final int workerCount; + private final int pipeliningMaxEvents; - protected final int pipeliningMaxEvents; + private final boolean tcpNoDelay; + private final boolean tcpKeepAlive; + private final boolean reuseAddress; - /** - * The registry used to construct parsers so they support {@link XContentParser#namedObject(Class, String, Object)}. - */ - protected final NamedXContentRegistry xContentRegistry; - - protected final boolean tcpNoDelay; - protected final boolean tcpKeepAlive; - protected final boolean reuseAddress; - - protected final ByteSizeValue tcpSendBufferSize; - protected final ByteSizeValue tcpReceiveBufferSize; - protected final RecvByteBufAllocator recvByteBufAllocator; + private final ByteSizeValue tcpSendBufferSize; + private final ByteSizeValue tcpReceiveBufferSize; + private final RecvByteBufAllocator recvByteBufAllocator; private final int readTimeoutMillis; - protected final int maxCompositeBufferComponents; + private final int maxCompositeBufferComponents; protected volatile ServerBootstrap serverBootstrap; protected final List serverChannels = new ArrayList<>(); - protected final HttpHandlingSettings httpHandlingSettings; - // package private for testing Netty4OpenChannelsHandler serverOpenChannels; @@ -189,16 +179,13 @@ public class Netty4HttpServerTransport extends AbstractHttpServerTransport { public Netty4HttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) { - super(settings, networkService, threadPool, dispatcher); + super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher); Netty4Utils.setAvailableProcessors(EsExecutors.PROCESSORS_SETTING.get(settings)); - this.bigArrays = bigArrays; - this.xContentRegistry = xContentRegistry; this.maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); this.maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); this.maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); this.pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); - this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); this.maxCompositeBufferComponents = SETTING_HTTP_NETTY_MAX_COMPOSITE_BUFFER_COMPONENTS.get(settings); this.workerCount = SETTING_HTTP_WORKER_COUNT.get(settings); @@ -398,26 +385,27 @@ protected void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throw } public ChannelHandler configureServerChannelHandler() { - return new HttpChannelHandler(this, httpHandlingSettings, threadPool.getThreadContext()); + return new HttpChannelHandler(this, handlingSettings); } + static final AttributeKey HTTP_CHANNEL_KEY = AttributeKey.newInstance("es-http-channel"); + protected static class HttpChannelHandler extends ChannelInitializer { private final Netty4HttpServerTransport transport; private final Netty4HttpRequestHandler requestHandler; private final HttpHandlingSettings handlingSettings; - protected HttpChannelHandler( - final Netty4HttpServerTransport transport, - final HttpHandlingSettings handlingSettings, - final ThreadContext threadContext) { + protected HttpChannelHandler(final Netty4HttpServerTransport transport, final HttpHandlingSettings handlingSettings) { this.transport = transport; this.handlingSettings = handlingSettings; - this.requestHandler = new Netty4HttpRequestHandler(transport, handlingSettings, threadContext); + this.requestHandler = new Netty4HttpRequestHandler(transport); } @Override protected void initChannel(Channel ch) throws Exception { + Netty4HttpChannel nettyTcpChannel = new Netty4HttpChannel(ch); + ch.attr(HTTP_CHANNEL_KEY).set(nettyTcpChannel); ch.pipeline().addLast("openChannels", transport.serverOpenChannels); ch.pipeline().addLast("read_timeout", new ReadTimeoutHandler(transport.readTimeoutMillis, TimeUnit.MILLISECONDS)); final HttpRequestDecoder decoder = new HttpRequestDecoder( diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java index 779eb4fe2e465..38d832d608051 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/http/netty4/cors/Netty4CorsHandler.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; @@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import org.elasticsearch.common.Strings; +import org.elasticsearch.http.netty4.Netty4HttpResponse; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -76,6 +78,14 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception ctx.fireChannelRead(msg); } + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assert msg instanceof Netty4HttpResponse : "Invalid message type: " + msg.getClass(); + Netty4HttpResponse response = (Netty4HttpResponse) msg; + setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config); + ctx.write(response, promise);; + } + public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, Netty4CorsConfig config) { if (!config.isCorsSupportEnabled()) { return; diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java index f4818a2e56752..466c4b68bfa4e 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/Netty4Transport.java @@ -333,10 +333,10 @@ protected void initChannel(Channel ch) throws Exception { addClosedExceptionLogger(ch); NettyTcpChannel nettyTcpChannel = new NettyTcpChannel(ch, name); ch.attr(CHANNEL_KEY).set(nettyTcpChannel); - serverAcceptedChannel(nettyTcpChannel); ch.pipeline().addLast("logging", new ESLoggingHandler()); ch.pipeline().addLast("size", new Netty4SizeHeaderFrameDecoder()); ch.pipeline().addLast("dispatcher", new Netty4MessageChannelHandler(Netty4Transport.this, name)); + serverAcceptedChannel(nettyTcpChannel); } @Override diff --git a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java index f650e757e7a62..89fabdcd763d1 100644 --- a/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java +++ b/modules/transport-netty4/src/main/java/org/elasticsearch/transport/netty4/NettyTcpChannel.java @@ -98,8 +98,11 @@ public void sendMessage(BytesReference reference, ActionListener listener) } else { final Throwable cause = f.cause(); Netty4Utils.maybeDie(cause); - assert cause instanceof Exception; - listener.onFailure((Exception) cause); + if (cause instanceof Error) { + listener.onFailure(new Exception(cause)); + } else { + listener.onFailure((Exception) cause); + } } }); channel.writeAndFlush(Netty4Utils.toByteBuf(reference), writePromise); diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java new file mode 100644 index 0000000000000..15a0850f64d38 --- /dev/null +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4CorsTests.java @@ -0,0 +1,148 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http.netty4; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; + +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +public class Netty4CorsTests extends ESTestCase { + + public void testCorsEnabledWithoutAllowOrigins() { + // Set up a HTTP transport with only the CORS enabled setting + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, "remote-host", "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + } + + public void testCorsEnabledWithAllowOrigins() { + final String originValue = "remote-host"; + // create a http transport with CORS enabled and allow origin configured + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testCorsAllowOriginWithSameHost() { + String originValue = "remote-host"; + String host = "remote-host"; + // create a http transport with CORS enabled + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, originValue, host); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = "http://" + originValue; + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue + ":5555"; + host = host + ":5555"; + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue.replace("http", "https"); + response = executeRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testThatStringLiteralWorksOnMatch() { + final String originValue = "remote-host"; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + } + + public void testThatAnyOriginWorks() { + final String originValue = Netty4CorsHandler.ANY_ORIGIN; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + HttpResponse response = executeRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); + } + + private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { + // construct request and send it over the transport layer + final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + if (originValue != null) { + httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); + } + httpRequest.headers().add(HttpHeaderNames.HOST, host); + EmbeddedChannel embeddedChannel = new EmbeddedChannel(); + embeddedChannel.pipeline().addLast(new Netty4CorsHandler(Netty4HttpServerTransport.buildCorsConfig(settings))); + Netty4HttpRequest nettyRequest = new Netty4HttpRequest(httpRequest, 0); + embeddedChannel.writeOutbound(nettyRequest.createResponse(RestStatus.OK, new BytesArray("content"))); + return embeddedChannel.readOutbound(); + } +} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java deleted file mode 100644 index 7c5b35a322996..0000000000000 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpChannelTests.java +++ /dev/null @@ -1,616 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.netty4; - -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.Channel; -import io.netty.channel.ChannelConfig; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelId; -import io.netty.channel.ChannelMetadata; -import io.netty.channel.ChannelPipeline; -import io.netty.channel.ChannelProgressivePromise; -import io.netty.channel.ChannelPromise; -import io.netty.channel.EventLoop; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.util.Attribute; -import io.netty.util.AttributeKey; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.bytes.ReleasablePagedBytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.lease.Releasables; -import org.elasticsearch.common.network.NetworkService; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.ByteArray; -import org.elasticsearch.common.util.MockBigArrays; -import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.HttpTransportSettings; -import org.elasticsearch.http.NullDispatcher; -import org.elasticsearch.http.netty4.cors.Netty4CorsHandler; -import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.rest.BytesRestResponse; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.netty4.Netty4Utils; -import org.junit.After; -import org.junit.Before; - -import java.io.IOException; -import java.io.UnsupportedEncodingException; -import java.net.SocketAddress; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; - -public class Netty4HttpChannelTests extends ESTestCase { - - private NetworkService networkService; - private ThreadPool threadPool; - private MockBigArrays bigArrays; - - @Before - public void setup() throws Exception { - networkService = new NetworkService(Collections.emptyList()); - threadPool = new TestThreadPool("test"); - bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - } - - @After - public void shutdown() throws Exception { - if (threadPool != null) { - threadPool.shutdownNow(); - } - } - - public void testResponse() { - final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host"); - assertThat(response.content(), equalTo(Netty4Utils.toByteBuf(new TestResponse().content()))); - } - - public void testCorsEnabledWithoutAllowOrigins() { - // Set up a HTTP transport with only the CORS enabled setting - Settings settings = Settings.builder() - .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, "remote-host", "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); - } - - public void testCorsEnabledWithAllowOrigins() { - final String originValue = "remote-host"; - // create a http transport with CORS enabled and allow origin configured - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testCorsAllowOriginWithSameHost() { - String originValue = "remote-host"; - String host = "remote-host"; - // create a http transport with CORS enabled - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, host); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = "http://" + originValue; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue + ":5555"; - host = host + ":5555"; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue.replace("http", "https"); - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testThatStringLiteralWorksOnMatch() { - final String originValue = "remote-host"; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") - .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); - } - - public void testThatAnyOriginWorks() { - final String originValue = Netty4CorsHandler.ANY_ORIGIN; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); - } - - public void testHeadersSet() { - Settings settings = Settings.builder().build(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), - new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote"); - final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - - // send a response - Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - TestResponse resp = new TestResponse(); - final String customHeader = "custom-header"; - final String customHeaderValue = "xyz"; - resp.addHeader(customHeader, customHeaderValue); - channel.sendResponse(resp); - - // inspect what was written - List writtenObjects = writeCapturingChannel.getWrittenObjects(); - assertThat(writtenObjects.size(), is(1)); - HttpResponse response = ((Netty4HttpResponse) writtenObjects.get(0)).getResponse(); - assertThat(response.headers().get("non-existent-header"), nullValue()); - assertThat(response.headers().get(customHeader), equalTo(customHeaderValue)); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length()))); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType())); - } - } - - public void testReleaseOnSendToClosedChannel() { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) { - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final TestResponse response = new TestResponse(bigArrays); - assertThat(response.content(), instanceOf(Releasable.class)); - embeddedChannel.close(); - channel.sendResponse(response); - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - } - - public void testReleaseOnSendToChannelAfterException() throws IOException { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, registry, new NullDispatcher())) { - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(registry, httpRequest, embeddedChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, - JsonXContent.contentBuilder().startObject().endObject()); - assertThat(response.content(), not(instanceOf(Releasable.class))); - - // ensure we have reserved bytes - if (randomBoolean()) { - BytesStreamOutput out = channel.bytesOutput(); - assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); - } else { - try (XContentBuilder builder = channel.newBuilder()) { - // do something builder - builder.startObject().endObject(); - } - } - - channel.sendResponse(response); - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - } - - public void testConnectionClose() throws Exception { - final Settings settings = Settings.builder().build(); - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest; - final boolean close = randomBoolean(); - if (randomBoolean()) { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); - } - } else { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/"); - if (!close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); - } - } - final EmbeddedChannel embeddedChannel = new EmbeddedChannel(); - final Netty4HttpRequest request = new Netty4HttpRequest(xContentRegistry(), httpRequest, embeddedChannel); - - // send a response, the channel close status should match - assertTrue(embeddedChannel.isOpen()); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - final Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - final TestResponse resp = new TestResponse(); - channel.sendResponse(resp); - assertThat(embeddedChannel.isOpen(), equalTo(!close)); - } - } - - private FullHttpResponse executeRequest(final Settings settings, final String host) { - return executeRequest(settings, null, host); - } - - private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { - // construct request and send it over the transport layer - try (Netty4HttpServerTransport httpServerTransport = - new Netty4HttpServerTransport(settings, networkService, bigArrays, threadPool, xContentRegistry(), - new NullDispatcher())) { - httpServerTransport.start(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (originValue != null) { - httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); - } - httpRequest.headers().add(HttpHeaderNames.HOST, host); - final WriteCapturingChannel writeCapturingChannel = new WriteCapturingChannel(); - final Netty4HttpRequest request = - new Netty4HttpRequest(xContentRegistry(), httpRequest, writeCapturingChannel); - HttpHandlingSettings handlingSettings = httpServerTransport.httpHandlingSettings; - - Netty4HttpChannel channel = - new Netty4HttpChannel(httpServerTransport, request, 1, handlingSettings, threadPool.getThreadContext()); - channel.sendResponse(new TestResponse()); - - // get the response - List writtenObjects = writeCapturingChannel.getWrittenObjects(); - assertThat(writtenObjects.size(), is(1)); - return ((Netty4HttpResponse) writtenObjects.get(0)).getResponse(); - } - } - - private static class WriteCapturingChannel implements Channel { - - private List writtenObjects = new ArrayList<>(); - - @Override - public ChannelId id() { - return null; - } - - @Override - public EventLoop eventLoop() { - return null; - } - - @Override - public Channel parent() { - return null; - } - - @Override - public ChannelConfig config() { - return null; - } - - @Override - public boolean isOpen() { - return false; - } - - @Override - public boolean isRegistered() { - return false; - } - - @Override - public boolean isActive() { - return false; - } - - @Override - public ChannelMetadata metadata() { - return null; - } - - @Override - public SocketAddress localAddress() { - return null; - } - - @Override - public SocketAddress remoteAddress() { - return null; - } - - @Override - public ChannelFuture closeFuture() { - return null; - } - - @Override - public boolean isWritable() { - return false; - } - - @Override - public long bytesBeforeUnwritable() { - return 0; - } - - @Override - public long bytesBeforeWritable() { - return 0; - } - - @Override - public Unsafe unsafe() { - return null; - } - - @Override - public ChannelPipeline pipeline() { - return null; - } - - @Override - public ByteBufAllocator alloc() { - return null; - } - - @Override - public Channel read() { - return null; - } - - @Override - public Channel flush() { - return null; - } - - @Override - public ChannelFuture bind(SocketAddress localAddress) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { - return null; - } - - @Override - public ChannelFuture disconnect() { - return null; - } - - @Override - public ChannelFuture close() { - return null; - } - - @Override - public ChannelFuture deregister() { - return null; - } - - @Override - public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture disconnect(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture close(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture deregister(ChannelPromise promise) { - return null; - } - - @Override - public ChannelFuture write(Object msg) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture write(Object msg, ChannelPromise promise) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelFuture writeAndFlush(Object msg) { - writtenObjects.add(msg); - return null; - } - - @Override - public ChannelPromise newPromise() { - return null; - } - - @Override - public ChannelProgressivePromise newProgressivePromise() { - return null; - } - - @Override - public ChannelFuture newSucceededFuture() { - return null; - } - - @Override - public ChannelFuture newFailedFuture(Throwable cause) { - return null; - } - - @Override - public ChannelPromise voidPromise() { - return null; - } - - @Override - public Attribute attr(AttributeKey key) { - return null; - } - - @Override - public boolean hasAttr(AttributeKey key) { - return false; - } - - @Override - public int compareTo(Channel o) { - return 0; - } - - List getWrittenObjects() { - return writtenObjects; - } - - } - - private static class TestResponse extends RestResponse { - - private final BytesReference reference; - - TestResponse() { - reference = Netty4Utils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8)); - } - - TestResponse(final BigArrays bigArrays) { - final byte[] bytes; - try { - bytes = "content".getBytes("UTF-8"); - } catch (final UnsupportedEncodingException e) { - throw new AssertionError(e); - } - final ByteArray bigArray = bigArrays.newByteArray(bytes.length); - bigArray.set(0, bytes, 0, bytes.length); - reference = new ReleasablePagedBytesReference(bigArrays, bigArray, bytes.length, Releasables.releaseOnce(bigArray)); - } - - @Override - public String contentType() { - return "text"; - } - - @Override - public BytesReference content() { - return reference; - } - - @Override - public RestStatus status() { - return RestStatus.OK; - } - - } - -} diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java index f6c5dfd5a50b2..8b3ba19fe0144 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpPipeliningHandlerTests.java @@ -19,15 +19,12 @@ package org.elasticsearch.http.netty4; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpMethod; @@ -35,7 +32,10 @@ import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.QueryStringDecoder; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.After; @@ -55,7 +55,6 @@ import java.util.stream.IntStream; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; -import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.hamcrest.core.Is.is; @@ -191,11 +190,11 @@ public void testPipeliningRequestsAreReleased() throws InterruptedException { ArrayList promises = new ArrayList<>(); for (int i = 1; i < requests.size(); ++i) { - final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK); ChannelPromise promise = embeddedChannel.newPromise(); promises.add(promise); - int sequence = requests.get(i).getSequence(); - Netty4HttpResponse resp = new Netty4HttpResponse(sequence, httpResponse); + HttpPipelinedRequest pipelinedRequest = requests.get(i); + Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + Netty4HttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); embeddedChannel.writeAndFlush(resp, promise); } @@ -233,10 +232,10 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) thro } - private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { + private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { @Override - protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { + protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { LastHttpContent request = pipelinedRequest.getRequest(); final QueryStringDecoder decoder; if (request instanceof FullHttpRequest) { @@ -246,9 +245,10 @@ protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedReques } final String uri = decoder.path().replace("/", ""); - final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content); - httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes()); + final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8)); + Netty4HttpRequest nioHttpRequest = new Netty4HttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + Netty4HttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content); + httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length())); final CountDownLatch waitingLatch = new CountDownLatch(1); waitingRequests.put(uri, waitingLatch); @@ -260,7 +260,7 @@ protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedReques waitingLatch.await(1000, TimeUnit.SECONDS); final ChannelPromise promise = ctx.newPromise(); eventLoopService.submit(() -> { - ctx.write(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.write(httpResponse, promise); finishingLatch.countDown(); }); } catch (InterruptedException e) { diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java index f2b28b909187b..3101f660d056e 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerPipeliningTests.java @@ -26,22 +26,20 @@ import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.MockBigArrays; import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.NullDispatcher; import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; @@ -120,7 +118,7 @@ class CustomNettyHttpServerTransport extends Netty4HttpServerTransport { @Override public ChannelHandler configureServerChannelHandler() { - return new CustomHttpChannelHandler(this, executorService, Netty4HttpServerPipeliningTests.this.threadPool.getThreadContext()); + return new CustomHttpChannelHandler(this, executorService); } @Override @@ -135,8 +133,8 @@ private class CustomHttpChannelHandler extends Netty4HttpServerTransport.HttpCha private final ExecutorService executorService; - CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService, ThreadContext threadContext) { - super(transport, transport.httpHandlingSettings, threadContext); + CustomHttpChannelHandler(Netty4HttpServerTransport transport, ExecutorService executorService) { + super(transport, transport.handlingSettings); this.executorService = executorService; } @@ -187,8 +185,9 @@ public void run() { final ByteBuf buffer = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, buffer); - httpResponse.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes()); + Netty4HttpRequest httpRequest = new Netty4HttpRequest(fullHttpRequest, pipelinedRequest.getSequence()); + Netty4HttpResponse response = httpRequest.createResponse(RestStatus.OK, new BytesArray(uri.getBytes(StandardCharsets.UTF_8))); + response.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes()); final boolean slow = uri.matches("/slow/\\d+"); if (slow) { @@ -202,7 +201,7 @@ public void run() { } final ChannelPromise promise = ctx.newPromise(); - ctx.writeAndFlush(new Netty4HttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.writeAndFlush(response, promise); } } diff --git a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java index 5b22409b92da0..bcf28506143bf 100644 --- a/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java +++ b/modules/transport-netty4/src/test/java/org/elasticsearch/http/netty4/Netty4HttpServerTransportTests.java @@ -291,40 +291,6 @@ public void dispatchBadRequest(final RestRequest request, assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); } - public void testDispatchDoesNotModifyThreadContext() throws InterruptedException { - final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { - - @Override - public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { - threadContext.putHeader("foo", "bar"); - threadContext.putTransient("bar", "baz"); - } - - @Override - public void dispatchBadRequest(final RestRequest request, - final RestChannel channel, - final ThreadContext threadContext, - final Throwable cause) { - threadContext.putHeader("foo_bad", "bar"); - threadContext.putTransient("bar_bad", "baz"); - } - - }; - - try (Netty4HttpServerTransport transport = - new Netty4HttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { - transport.start(); - - transport.dispatchRequest(null, null); - assertNull(threadPool.getThreadContext().getHeader("foo")); - assertNull(threadPool.getThreadContext().getTransient("bar")); - - transport.dispatchBadRequest(null, null, null); - assertNull(threadPool.getThreadContext().getHeader("foo_bad")); - assertNull(threadPool.getThreadContext().getTransient("bar_bad")); - } - } - public void testReadTimeout() throws Exception { final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java index 05f28e8254aa1..ea75c62dbbce2 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/HttpReadWriteHandler.java @@ -23,54 +23,38 @@ import io.netty.channel.ChannelHandler; import io.netty.handler.codec.ByteToMessageDecoder; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.HttpContentCompressor; import io.netty.handler.codec.http.HttpContentDecompressor; -import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestDecoder; import io.netty.handler.codec.http.HttpResponseEncoder; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpPipelinedRequest; import org.elasticsearch.http.nio.cors.NioCorsConfig; import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ReadWriteHandler; import org.elasticsearch.nio.SocketChannelContext; import org.elasticsearch.nio.WriteOperation; -import org.elasticsearch.rest.RestRequest; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.function.BiConsumer; - public class HttpReadWriteHandler implements ReadWriteHandler { private final NettyAdaptor adaptor; - private final NioSocketChannel nioChannel; + private final NioHttpChannel nioHttpChannel; private final NioHttpServerTransport transport; - private final HttpHandlingSettings settings; - private final NamedXContentRegistry xContentRegistry; - private final NioCorsConfig corsConfig; - private final ThreadContext threadContext; - - HttpReadWriteHandler(NioSocketChannel nioChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, - NamedXContentRegistry xContentRegistry, NioCorsConfig corsConfig, ThreadContext threadContext) { - this.nioChannel = nioChannel; + + HttpReadWriteHandler(NioHttpChannel nioHttpChannel, NioHttpServerTransport transport, HttpHandlingSettings settings, + NioCorsConfig corsConfig) { + this.nioHttpChannel = nioHttpChannel; this.transport = transport; - this.settings = settings; - this.xContentRegistry = xContentRegistry; - this.corsConfig = corsConfig; - this.threadContext = threadContext; List handlers = new ArrayList<>(5); HttpRequestDecoder decoder = new HttpRequestDecoder(settings.getMaxInitialLineLength(), settings.getMaxHeaderSize(), @@ -89,7 +73,7 @@ public class HttpReadWriteHandler implements ReadWriteHandler { handlers.add(new NioHttpPipeliningHandler(transport.getLogger(), settings.getPipeliningMaxEvents())); adaptor = new NettyAdaptor(handlers.toArray(new ChannelHandler[0])); - adaptor.addCloseListener((v, e) -> nioChannel.close()); + adaptor.addCloseListener((v, e) -> nioHttpChannel.close()); } @Override @@ -150,95 +134,22 @@ private void handleRequest(Object msg) { request.headers(), request.trailingHeaders()); - Exception badRequestCause = null; - - /* - * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there - * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we - * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, - * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the - * underlying exception that caused us to treat the request as bad. - */ - final NioHttpRequest httpRequest; - { - NioHttpRequest innerHttpRequest; - try { - innerHttpRequest = new NioHttpRequest(xContentRegistry, copiedRequest); - } catch (final RestRequest.ContentTypeHeaderException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutContentTypeHeader(copiedRequest, badRequestCause); - } catch (final RestRequest.BadParameterException e) { - badRequestCause = e; - innerHttpRequest = requestWithoutParameters(copiedRequest); - } - httpRequest = innerHttpRequest; - } - - /* - * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid - * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an - * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of - * these parameter values. - */ - final NioHttpChannel channel; - { - NioHttpChannel innerChannel; - int sequence = pipelinedRequest.getSequence(); - BigArrays bigArrays = transport.getBigArrays(); - try { - innerChannel = new NioHttpChannel(nioChannel, bigArrays, httpRequest, sequence, settings, corsConfig, threadContext); - } catch (final IllegalArgumentException e) { - if (badRequestCause == null) { - badRequestCause = e; - } else { - badRequestCause.addSuppressed(e); - } - final NioHttpRequest innerRequest = - new NioHttpRequest( - xContentRegistry, - Collections.emptyMap(), // we are going to dispatch the request as a bad request, drop all parameters - copiedRequest.uri(), - copiedRequest); - innerChannel = new NioHttpChannel(nioChannel, bigArrays, innerRequest, sequence, settings, corsConfig, threadContext); - } - channel = innerChannel; - } + NioHttpRequest httpRequest = new NioHttpRequest(copiedRequest, pipelinedRequest.getSequence()); if (request.decoderResult().isFailure()) { - transport.dispatchBadRequest(httpRequest, channel, request.decoderResult().cause()); - } else if (badRequestCause != null) { - transport.dispatchBadRequest(httpRequest, channel, badRequestCause); + Throwable cause = request.decoderResult().cause(); + if (cause instanceof Error) { + ExceptionsHelper.dieOnError(cause); + transport.incomingRequestError(httpRequest, nioHttpChannel, new Exception(cause)); + } else { + transport.incomingRequestError(httpRequest, nioHttpChannel, (Exception) cause); + } } else { - transport.dispatchRequest(httpRequest, channel); + transport.incomingRequest(httpRequest, nioHttpChannel); } } finally { // As we have copied the buffer, we can release the request request.release(); } } - - private NioHttpRequest requestWithoutContentTypeHeader(final FullHttpRequest request, final Exception badRequestCause) { - final HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); - headersWithoutContentTypeHeader.add(request.headers()); - headersWithoutContentTypeHeader.remove("Content-Type"); - final FullHttpRequest requestWithoutContentTypeHeader = - new DefaultFullHttpRequest( - request.protocolVersion(), - request.method(), - request.uri(), - request.content(), - headersWithoutContentTypeHeader, // remove the Content-Type header so as to not parse it again - request.trailingHeaders()); // Content-Type can not be a trailing header - try { - return new NioHttpRequest(xContentRegistry, requestWithoutContentTypeHeader); - } catch (final RestRequest.BadParameterException e) { - badRequestCause.addSuppressed(e); - return requestWithoutParameters(requestWithoutContentTypeHeader); - } - } - - private NioHttpRequest requestWithoutParameters(final FullHttpRequest request) { - // remove all parameters as at least one is incorrectly encoded - return new NioHttpRequest(xContentRegistry, Collections.emptyMap(), request.uri(), request); - } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java index 634421b34ea48..088f0e85dde23 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpChannel.java @@ -19,244 +19,21 @@ package org.elasticsearch.http.nio; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpResponseStatus; -import io.netty.handler.codec.http.HttpVersion; -import io.netty.handler.codec.http.cookie.Cookie; -import io.netty.handler.codec.http.cookie.ServerCookieDecoder; -import io.netty.handler.codec.http.cookie.ServerCookieEncoder; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.lease.Releasables; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.nio.cors.NioCorsConfig; -import org.elasticsearch.http.nio.cors.NioCorsHandler; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.rest.AbstractRestChannel; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import java.util.ArrayList; -import java.util.Collections; -import java.util.EnumMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.function.BiConsumer; +import java.io.IOException; +import java.nio.channels.SocketChannel; -public class NioHttpChannel extends AbstractRestChannel { +public class NioHttpChannel extends NioSocketChannel implements HttpChannel { - private final BigArrays bigArrays; - private final int sequence; - private final NioCorsConfig corsConfig; - private final ThreadContext threadContext; - private final FullHttpRequest nettyRequest; - private final NioSocketChannel nioChannel; - private final boolean resetCookies; - - NioHttpChannel(NioSocketChannel nioChannel, BigArrays bigArrays, NioHttpRequest request, int sequence, - HttpHandlingSettings settings, NioCorsConfig corsConfig, ThreadContext threadContext) { - super(request, settings.getDetailedErrorsEnabled()); - this.nioChannel = nioChannel; - this.bigArrays = bigArrays; - this.sequence = sequence; - this.corsConfig = corsConfig; - this.threadContext = threadContext; - this.nettyRequest = request.getRequest(); - this.resetCookies = settings.isResetCookies(); - } - - @Override - public void sendResponse(RestResponse response) { - // if the response object was created upstream, then use it; - // otherwise, create a new one - ByteBuf buffer = ByteBufUtils.toByteBuf(response.content()); - final FullHttpResponse resp; - if (HttpMethod.HEAD.equals(nettyRequest.method())) { - resp = newResponse(Unpooled.EMPTY_BUFFER); - } else { - resp = newResponse(buffer); - } - resp.setStatus(getStatus(response.status())); - - NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig); - - String opaque = nettyRequest.headers().get("X-Opaque-Id"); - if (opaque != null) { - setHeaderField(resp, "X-Opaque-Id", opaque); - } - - // Add all custom headers - addCustomHeaders(resp, response.getHeaders()); - addCustomHeaders(resp, threadContext.getResponseHeaders()); - - ArrayList toClose = new ArrayList<>(3); - - boolean success = false; - try { - // If our response doesn't specify a content-type header, set one - setHeaderField(resp, HttpHeaderNames.CONTENT_TYPE.toString(), response.contentType(), false); - // If our response has no content-length, calculate and set one - setHeaderField(resp, HttpHeaderNames.CONTENT_LENGTH.toString(), String.valueOf(buffer.readableBytes()), false); - - addCookies(resp); - - BytesReference content = response.content(); - if (content instanceof Releasable) { - toClose.add((Releasable) content); - } - BytesStreamOutput bytesStreamOutput = bytesOutputOrNull(); - if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) { - toClose.add((Releasable) bytesStreamOutput); - } - - if (isCloseConnection()) { - toClose.add(nioChannel::close); - } - - BiConsumer listener = (aVoid, ex) -> Releasables.close(toClose); - nioChannel.getContext().sendMessage(new NioHttpResponse(sequence, resp), listener); - success = true; - } finally { - if (success == false) { - Releasables.close(toClose); - } - } - } - - @Override - protected BytesStreamOutput newBytesOutput() { - return new ReleasableBytesStreamOutput(bigArrays); - } - - private void setHeaderField(HttpResponse resp, String headerField, String value) { - setHeaderField(resp, headerField, value, true); - } - - private void setHeaderField(HttpResponse resp, String headerField, String value, boolean override) { - if (override || !resp.headers().contains(headerField)) { - resp.headers().add(headerField, value); - } - } - - private void addCookies(HttpResponse resp) { - if (resetCookies) { - String cookieString = nettyRequest.headers().get(HttpHeaderNames.COOKIE); - if (cookieString != null) { - Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); - if (!cookies.isEmpty()) { - // Reset the cookies if necessary. - resp.headers().set(HttpHeaderNames.SET_COOKIE, ServerCookieEncoder.STRICT.encode(cookies)); - } - } - } - } - - private void addCustomHeaders(HttpResponse response, Map> customHeaders) { - if (customHeaders != null) { - for (Map.Entry> headerEntry : customHeaders.entrySet()) { - for (String headerValue : headerEntry.getValue()) { - setHeaderField(response, headerEntry.getKey(), headerValue); - } - } - } - } - - // Create a new {@link HttpResponse} to transmit the response for the netty request. - private FullHttpResponse newResponse(ByteBuf buffer) { - final boolean http10 = isHttp10(); - final boolean close = isCloseConnection(); - // Build the response object. - final HttpResponseStatus status = HttpResponseStatus.OK; // default to initialize - final FullHttpResponse response; - if (http10) { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, status, buffer); - if (!close) { - response.headers().add(HttpHeaderNames.CONNECTION, "Keep-Alive"); - } - } else { - response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, buffer); - } - return response; - } - - // Determine if the request protocol version is HTTP 1.0 - private boolean isHttp10() { - return nettyRequest.protocolVersion().equals(HttpVersion.HTTP_1_0); - } - - // Determine if the request connection should be closed on completion. - private boolean isCloseConnection() { - final boolean http10 = isHttp10(); - return HttpHeaderValues.CLOSE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION)) || - (http10 && !HttpHeaderValues.KEEP_ALIVE.contentEqualsIgnoreCase(nettyRequest.headers().get(HttpHeaderNames.CONNECTION))); - } - - private static Map MAP; - - static { - EnumMap map = new EnumMap<>(RestStatus.class); - map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); - map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); - map.put(RestStatus.OK, HttpResponseStatus.OK); - map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); - map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); - map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); - map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); - map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); - map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); - map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? - map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); - map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); - map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); - map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); - map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); - map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); - map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); - map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); - map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); - map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); - map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); - map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); - map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); - map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); - map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); - map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); - map.put(RestStatus.GONE, HttpResponseStatus.GONE); - map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); - map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); - map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); - map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); - map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); - map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); - map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); - map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); - map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); - map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); - map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); - map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); - map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); - map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); - map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); - MAP = Collections.unmodifiableMap(map); + NioHttpChannel(SocketChannel socketChannel) throws IOException { + super(socketChannel); } - private static HttpResponseStatus getStatus(RestStatus status) { - return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); + public void sendResponse(HttpResponse response, ActionListener listener) { + getContext().sendMessage(response, ActionListener.toBiConsumer(listener)); } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java index 1eb63364f995a..977092ddac0aa 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpPipeliningHandler.java @@ -68,7 +68,7 @@ public void write(final ChannelHandlerContext ctx, final Object msg, final Chann List> readyResponses = aggregator.write(response, listener); success = true; for (Tuple responseToWrite : readyResponses) { - ctx.write(responseToWrite.v1().getResponse(), responseToWrite.v2()); + ctx.write(responseToWrite.v1(), responseToWrite.v2()); } } catch (IllegalStateException e) { ctx.channel().close(); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java index 4dcd6ba19e06b..08937593f3ba6 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpRequest.java @@ -19,13 +19,20 @@ package org.elasticsearch.http.nio; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.cookie.Cookie; +import io.netty.handler.codec.http.cookie.ServerCookieDecoder; +import io.netty.handler.codec.http.cookie.ServerCookieEncoder; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpRequest; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import java.util.AbstractMap; import java.util.Collection; @@ -35,25 +42,17 @@ import java.util.Set; import java.util.stream.Collectors; -public class NioHttpRequest extends RestRequest { +public class NioHttpRequest implements HttpRequest { private final FullHttpRequest request; private final BytesReference content; + private final HttpHeadersMap headers; + private final int sequence; - NioHttpRequest(NamedXContentRegistry xContentRegistry, FullHttpRequest request) { - super(xContentRegistry, request.uri(), new HttpHeadersMap(request.headers())); - this.request = request; - if (request.content().isReadable()) { - this.content = ByteBufUtils.toBytesReference(request.content()); - } else { - this.content = BytesArray.EMPTY; - } - - } - - NioHttpRequest(NamedXContentRegistry xContentRegistry, Map params, String uri, FullHttpRequest request) { - super(xContentRegistry, params, uri, new HttpHeadersMap(request.headers())); + NioHttpRequest(FullHttpRequest request, int sequence) { this.request = request; + headers = new HttpHeadersMap(request.headers()); + this.sequence = sequence; if (request.content().isReadable()) { this.content = ByteBufUtils.toBytesReference(request.content()); } else { @@ -62,38 +61,38 @@ public class NioHttpRequest extends RestRequest { } @Override - public Method method() { + public RestRequest.Method method() { HttpMethod httpMethod = request.method(); if (httpMethod == HttpMethod.GET) - return Method.GET; + return RestRequest.Method.GET; if (httpMethod == HttpMethod.POST) - return Method.POST; + return RestRequest.Method.POST; if (httpMethod == HttpMethod.PUT) - return Method.PUT; + return RestRequest.Method.PUT; if (httpMethod == HttpMethod.DELETE) - return Method.DELETE; + return RestRequest.Method.DELETE; if (httpMethod == HttpMethod.HEAD) { - return Method.HEAD; + return RestRequest.Method.HEAD; } if (httpMethod == HttpMethod.OPTIONS) { - return Method.OPTIONS; + return RestRequest.Method.OPTIONS; } if (httpMethod == HttpMethod.PATCH) { - return Method.PATCH; + return RestRequest.Method.PATCH; } if (httpMethod == HttpMethod.TRACE) { - return Method.TRACE; + return RestRequest.Method.TRACE; } if (httpMethod == HttpMethod.CONNECT) { - return Method.CONNECT; + return RestRequest.Method.CONNECT; } throw new IllegalArgumentException("Unexpected http method: " + httpMethod); @@ -105,19 +104,65 @@ public String uri() { } @Override - public boolean hasContent() { - return content.length() > 0; + public BytesReference content() { + return content; } + @Override - public BytesReference content() { - return content; + public final Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + String cookieString = request.headers().get(HttpHeaderNames.COOKIE); + if (cookieString != null) { + Set cookies = ServerCookieDecoder.STRICT.decode(cookieString); + if (!cookies.isEmpty()) { + return ServerCookieEncoder.STRICT.encode(cookies); + } + } + return Collections.emptyList(); + } + + @Override + public HttpVersion protocolVersion() { + if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_0)) { + return HttpRequest.HttpVersion.HTTP_1_0; + } else if (request.protocolVersion().equals(io.netty.handler.codec.http.HttpVersion.HTTP_1_1)) { + return HttpRequest.HttpVersion.HTTP_1_1; + } else { + throw new IllegalArgumentException("Unexpected http protocol version: " + request.protocolVersion()); + } } - public FullHttpRequest getRequest() { + @Override + public HttpRequest removeHeader(String header) { + HttpHeaders headersWithoutContentTypeHeader = new DefaultHttpHeaders(); + headersWithoutContentTypeHeader.add(request.headers()); + headersWithoutContentTypeHeader.remove(header); + HttpHeaders trailingHeaders = new DefaultHttpHeaders(); + trailingHeaders.add(request.trailingHeaders()); + trailingHeaders.remove(header); + FullHttpRequest requestWithoutHeader = new DefaultFullHttpRequest(request.protocolVersion(), request.method(), request.uri(), + request.content(), headersWithoutContentTypeHeader, trailingHeaders); + return new NioHttpRequest(requestWithoutHeader, sequence); + } + + @Override + public NioHttpResponse createResponse(RestStatus status, BytesReference content) { + return new NioHttpResponse(this, status, content); + } + + public FullHttpRequest nettyRequest() { return request; } + int sequence() { + return sequence; + } + /** * A wrapper of {@link HttpHeaders} that implements a map to prevent copying unnecessarily. This class does not support modifications * and due to the underlying implementation, it performs case insensitive lookups of key to values. diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java index 4b634994b4557..24de843dcc82d 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpResponse.java @@ -19,19 +19,100 @@ package org.elasticsearch.http.nio; -import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedMessage; +import org.elasticsearch.http.HttpResponse; +import org.elasticsearch.rest.RestStatus; -public class NioHttpResponse extends HttpPipelinedMessage { +import java.util.Collections; +import java.util.EnumMap; +import java.util.Map; - private final FullHttpResponse response; +public class NioHttpResponse extends DefaultFullHttpResponse implements HttpResponse, HttpPipelinedMessage { - public NioHttpResponse(int sequence, FullHttpResponse response) { - super(sequence); - this.response = response; + private final int sequence; + private final NioHttpRequest request; + + NioHttpResponse(NioHttpRequest request, RestStatus status, BytesReference content) { + super(request.nettyRequest().protocolVersion(), getStatus(status), ByteBufUtils.toByteBuf(content)); + this.sequence = request.sequence(); + this.request = request; + } + + @Override + public void addHeader(String name, String value) { + headers().add(name, value); + } + + @Override + public boolean containsHeader(String name) { + return headers().contains(name); + } + + @Override + public int getSequence() { + return sequence; + } + + private static Map MAP; + + public NioHttpRequest getRequest() { + return request; + } + + static { + EnumMap map = new EnumMap<>(RestStatus.class); + map.put(RestStatus.CONTINUE, HttpResponseStatus.CONTINUE); + map.put(RestStatus.SWITCHING_PROTOCOLS, HttpResponseStatus.SWITCHING_PROTOCOLS); + map.put(RestStatus.OK, HttpResponseStatus.OK); + map.put(RestStatus.CREATED, HttpResponseStatus.CREATED); + map.put(RestStatus.ACCEPTED, HttpResponseStatus.ACCEPTED); + map.put(RestStatus.NON_AUTHORITATIVE_INFORMATION, HttpResponseStatus.NON_AUTHORITATIVE_INFORMATION); + map.put(RestStatus.NO_CONTENT, HttpResponseStatus.NO_CONTENT); + map.put(RestStatus.RESET_CONTENT, HttpResponseStatus.RESET_CONTENT); + map.put(RestStatus.PARTIAL_CONTENT, HttpResponseStatus.PARTIAL_CONTENT); + map.put(RestStatus.MULTI_STATUS, HttpResponseStatus.INTERNAL_SERVER_ERROR); // no status for this?? + map.put(RestStatus.MULTIPLE_CHOICES, HttpResponseStatus.MULTIPLE_CHOICES); + map.put(RestStatus.MOVED_PERMANENTLY, HttpResponseStatus.MOVED_PERMANENTLY); + map.put(RestStatus.FOUND, HttpResponseStatus.FOUND); + map.put(RestStatus.SEE_OTHER, HttpResponseStatus.SEE_OTHER); + map.put(RestStatus.NOT_MODIFIED, HttpResponseStatus.NOT_MODIFIED); + map.put(RestStatus.USE_PROXY, HttpResponseStatus.USE_PROXY); + map.put(RestStatus.TEMPORARY_REDIRECT, HttpResponseStatus.TEMPORARY_REDIRECT); + map.put(RestStatus.BAD_REQUEST, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.UNAUTHORIZED, HttpResponseStatus.UNAUTHORIZED); + map.put(RestStatus.PAYMENT_REQUIRED, HttpResponseStatus.PAYMENT_REQUIRED); + map.put(RestStatus.FORBIDDEN, HttpResponseStatus.FORBIDDEN); + map.put(RestStatus.NOT_FOUND, HttpResponseStatus.NOT_FOUND); + map.put(RestStatus.METHOD_NOT_ALLOWED, HttpResponseStatus.METHOD_NOT_ALLOWED); + map.put(RestStatus.NOT_ACCEPTABLE, HttpResponseStatus.NOT_ACCEPTABLE); + map.put(RestStatus.PROXY_AUTHENTICATION, HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED); + map.put(RestStatus.REQUEST_TIMEOUT, HttpResponseStatus.REQUEST_TIMEOUT); + map.put(RestStatus.CONFLICT, HttpResponseStatus.CONFLICT); + map.put(RestStatus.GONE, HttpResponseStatus.GONE); + map.put(RestStatus.LENGTH_REQUIRED, HttpResponseStatus.LENGTH_REQUIRED); + map.put(RestStatus.PRECONDITION_FAILED, HttpResponseStatus.PRECONDITION_FAILED); + map.put(RestStatus.REQUEST_ENTITY_TOO_LARGE, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); + map.put(RestStatus.REQUEST_URI_TOO_LONG, HttpResponseStatus.REQUEST_URI_TOO_LONG); + map.put(RestStatus.UNSUPPORTED_MEDIA_TYPE, HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE); + map.put(RestStatus.REQUESTED_RANGE_NOT_SATISFIED, HttpResponseStatus.REQUESTED_RANGE_NOT_SATISFIABLE); + map.put(RestStatus.EXPECTATION_FAILED, HttpResponseStatus.EXPECTATION_FAILED); + map.put(RestStatus.UNPROCESSABLE_ENTITY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.LOCKED, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.FAILED_DEPENDENCY, HttpResponseStatus.BAD_REQUEST); + map.put(RestStatus.TOO_MANY_REQUESTS, HttpResponseStatus.TOO_MANY_REQUESTS); + map.put(RestStatus.INTERNAL_SERVER_ERROR, HttpResponseStatus.INTERNAL_SERVER_ERROR); + map.put(RestStatus.NOT_IMPLEMENTED, HttpResponseStatus.NOT_IMPLEMENTED); + map.put(RestStatus.BAD_GATEWAY, HttpResponseStatus.BAD_GATEWAY); + map.put(RestStatus.SERVICE_UNAVAILABLE, HttpResponseStatus.SERVICE_UNAVAILABLE); + map.put(RestStatus.GATEWAY_TIMEOUT, HttpResponseStatus.GATEWAY_TIMEOUT); + map.put(RestStatus.HTTP_VERSION_NOT_SUPPORTED, HttpResponseStatus.HTTP_VERSION_NOT_SUPPORTED); + MAP = Collections.unmodifiableMap(map); } - public FullHttpResponse getResponse() { - return response; + private static HttpResponseStatus getStatus(RestStatus status) { + return MAP.getOrDefault(status, HttpResponseStatus.INTERNAL_SERVER_ERROR); } } diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java index 57aaebb16a1a2..5aac491a6abd4 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/NioHttpServerTransport.java @@ -42,7 +42,6 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.http.AbstractHttpServerTransport; import org.elasticsearch.http.BindHttpException; -import org.elasticsearch.http.HttpHandlingSettings; import org.elasticsearch.http.HttpServerTransport; import org.elasticsearch.http.HttpStats; import org.elasticsearch.http.nio.cors.NioCorsConfig; @@ -53,11 +52,11 @@ import org.elasticsearch.nio.InboundChannelBuffer; import org.elasticsearch.nio.NioChannel; import org.elasticsearch.nio.NioGroup; +import org.elasticsearch.nio.NioSelector; import org.elasticsearch.nio.NioServerSocketChannel; import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.ServerChannelContext; import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.nio.NioSelector; import org.elasticsearch.rest.RestUtils; import org.elasticsearch.threadpool.ThreadPool; @@ -104,12 +103,6 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { (s) -> Integer.toString(EsExecutors.numberOfProcessors(s) * 2), (s) -> Setting.parseInt(s, 1, "http.nio.worker_count"), Setting.Property.NodeScope); - private final BigArrays bigArrays; - private final ThreadPool threadPool; - private final NamedXContentRegistry xContentRegistry; - - private final HttpHandlingSettings httpHandlingSettings; - private final boolean tcpNoDelay; private final boolean tcpKeepAlive; private final boolean reuseAddress; @@ -124,16 +117,12 @@ public class NioHttpServerTransport extends AbstractHttpServerTransport { public NioHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, NamedXContentRegistry xContentRegistry, HttpServerTransport.Dispatcher dispatcher) { - super(settings, networkService, threadPool, dispatcher); - this.bigArrays = bigArrays; - this.threadPool = threadPool; - this.xContentRegistry = xContentRegistry; + super(settings, networkService, bigArrays, threadPool, xContentRegistry, dispatcher); ByteSizeValue maxChunkSize = SETTING_HTTP_MAX_CHUNK_SIZE.get(settings); ByteSizeValue maxHeaderSize = SETTING_HTTP_MAX_HEADER_SIZE.get(settings); ByteSizeValue maxInitialLineLength = SETTING_HTTP_MAX_INITIAL_LINE_LENGTH.get(settings); int pipeliningMaxEvents = SETTING_PIPELINING_MAX_EVENTS.get(settings); - this.httpHandlingSettings = HttpHandlingSettings.fromSettings(settings);; this.corsConfig = buildCorsConfig(settings); this.tcpNoDelay = SETTING_HTTP_TCP_NO_DELAY.get(settings); @@ -148,10 +137,6 @@ public NioHttpServerTransport(Settings settings, NetworkService networkService, maxChunkSize, maxHeaderSize, maxInitialLineLength, maxContentLength, pipeliningMaxEvents); } - BigArrays getBigArrays() { - return bigArrays; - } - public Logger getLogger() { return logger; } @@ -335,17 +320,17 @@ private void acceptChannel(NioSocketChannel socketChannel) { socketChannels.add(socketChannel); } - private class HttpChannelFactory extends ChannelFactory { + private class HttpChannelFactory extends ChannelFactory { private HttpChannelFactory() { super(new RawChannelFactory(tcpNoDelay, tcpKeepAlive, reuseAddress, tcpSendBufferSize, tcpReceiveBufferSize)); } @Override - public NioSocketChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { - NioSocketChannel nioChannel = new NioSocketChannel(channel); + public NioHttpChannel createChannel(NioSelector selector, SocketChannel channel) throws IOException { + NioHttpChannel nioChannel = new NioHttpChannel(channel); HttpReadWriteHandler httpReadWritePipeline = new HttpReadWriteHandler(nioChannel,NioHttpServerTransport.this, - httpHandlingSettings, xContentRegistry, corsConfig, threadPool.getThreadContext()); + handlingSettings, corsConfig); Consumer exceptionHandler = (e) -> exceptionCaught(nioChannel, e); SocketChannelContext context = new BytesChannelContext(nioChannel, selector, exceptionHandler, httpReadWritePipeline, InboundChannelBuffer.allocatingInstance()); diff --git a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java index 6358510703779..98ae2d523ca81 100644 --- a/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java +++ b/plugins/transport-nio/src/main/java/org/elasticsearch/http/nio/cors/NioCorsHandler.java @@ -22,6 +22,7 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpHeaders; @@ -30,6 +31,7 @@ import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseStatus; import org.elasticsearch.common.Strings; +import org.elasticsearch.http.nio.NioHttpResponse; import java.util.regex.Pattern; import java.util.stream.Collectors; @@ -76,6 +78,14 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception ctx.fireChannelRead(msg); } + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assert msg instanceof NioHttpResponse : "Invalid message type: " + msg.getClass(); + NioHttpResponse response = (NioHttpResponse) msg; + setCorsResponseHeaders(response.getRequest().nettyRequest(), response, config); + ctx.write(response, promise); + } + public static void setCorsResponseHeaders(HttpRequest request, HttpResponse resp, NioCorsConfig config) { if (!config.isCorsSupportEnabled()) { return; diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java index 6ad53521ee12a..5bda7e1b83d81 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/HttpReadWriteHandlerTests.java @@ -23,29 +23,31 @@ import io.netty.buffer.Unpooled; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; -import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpObjectAggregator; import io.netty.handler.codec.http.HttpRequestEncoder; -import io.netty.handler.codec.http.HttpResponse; import io.netty.handler.codec.http.HttpResponseDecoder; import io.netty.handler.codec.http.HttpResponseStatus; import io.netty.handler.codec.http.HttpUtil; import io.netty.handler.codec.http.HttpVersion; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; -import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.http.HttpHandlingSettings; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.http.HttpResponse; +import org.elasticsearch.http.HttpTransportSettings; +import org.elasticsearch.http.nio.cors.NioCorsConfig; import org.elasticsearch.http.nio.cors.NioCorsConfigBuilder; +import org.elasticsearch.http.nio.cors.NioCorsHandler; import org.elasticsearch.nio.FlushOperation; import org.elasticsearch.nio.InboundChannelBuffer; -import org.elasticsearch.nio.NioSocketChannel; import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -55,6 +57,9 @@ import java.util.List; import java.util.function.BiConsumer; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; +import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_COMPRESSION_LEVEL; @@ -64,7 +69,12 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_MAX_INITIAL_LINE_LENGTH; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_RESET_COOKIES; import static org.elasticsearch.http.HttpTransportSettings.SETTING_PIPELINING_MAX_EVENTS; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; import static org.mockito.Matchers.any; +import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -72,7 +82,7 @@ public class HttpReadWriteHandlerTests extends ESTestCase { private HttpReadWriteHandler handler; - private NioSocketChannel nioSocketChannel; + private NioHttpChannel nioHttpChannel; private NioHttpServerTransport transport; private final RequestEncoder requestEncoder = new RequestEncoder(); @@ -96,15 +106,13 @@ public void setMocks() { SETTING_HTTP_DETAILED_ERRORS_ENABLED.getDefault(settings), SETTING_PIPELINING_MAX_EVENTS.getDefault(settings), SETTING_CORS_ENABLED.getDefault(settings)); - ThreadContext threadContext = new ThreadContext(settings); - nioSocketChannel = mock(NioSocketChannel.class); - handler = new HttpReadWriteHandler(nioSocketChannel, transport, httpHandlingSettings, NamedXContentRegistry.EMPTY, - NioCorsConfigBuilder.forAnyOrigin().build(), threadContext); + nioHttpChannel = mock(NioHttpChannel.class); + handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, NioCorsConfigBuilder.forAnyOrigin().build()); } public void testSuccessfulDecodeHttpRequest() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); ByteBuf buf = requestEncoder.encode(httpRequest); int slicePoint = randomInt(buf.writerIndex() - 1); @@ -113,22 +121,21 @@ public void testSuccessfulDecodeHttpRequest() throws IOException { ByteBuf slicedBuf2 = buf.retainedSlice(slicePoint, buf.writerIndex()); handler.consumeReads(toChannelBuffer(slicedBuf)); - verify(transport, times(0)).dispatchRequest(any(RestRequest.class), any(RestChannel.class)); + verify(transport, times(0)).incomingRequest(any(HttpRequest.class), any(NioHttpChannel.class)); handler.consumeReads(toChannelBuffer(slicedBuf2)); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RestRequest.class); - verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(transport).incomingRequest(requestCaptor.capture(), any(NioHttpChannel.class)); - NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); - FullHttpRequest nettyHttpRequest = nioHttpRequest.getRequest(); - assertEquals(httpRequest.protocolVersion(), nettyHttpRequest.protocolVersion()); - assertEquals(httpRequest.method(), nettyHttpRequest.method()); + HttpRequest nioHttpRequest = requestCaptor.getValue(); + assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); + assertEquals(RestRequest.Method.GET, nioHttpRequest.method()); } public void testDecodeHttpRequestError() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, uri); ByteBuf buf = requestEncoder.encode(httpRequest); buf.setByte(0, ' '); @@ -137,15 +144,15 @@ public void testDecodeHttpRequestError() throws IOException { handler.consumeReads(toChannelBuffer(buf)); - ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Throwable.class); - verify(transport).dispatchBadRequest(any(RestRequest.class), any(RestChannel.class), exceptionCaptor.capture()); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(transport).incomingRequestError(any(HttpRequest.class), any(NioHttpChannel.class), exceptionCaptor.capture()); assertTrue(exceptionCaptor.getValue() instanceof IllegalArgumentException); } public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() throws IOException { String uri = "localhost:9090/" + randomAlphaOfLength(8); - HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false); + io.netty.handler.codec.http.HttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, uri, false); HttpUtil.setContentLength(httpRequest, 1025); HttpUtil.setKeepAlive(httpRequest, false); @@ -153,60 +160,176 @@ public void testDecodeHttpRequestContentLengthToLongGeneratesOutboundMessage() t handler.consumeReads(toChannelBuffer(buf)); - verify(transport, times(0)).dispatchBadRequest(any(), any(), any()); - verify(transport, times(0)).dispatchRequest(any(), any()); + verify(transport, times(0)).incomingRequestError(any(), any(), any()); + verify(transport, times(0)).incomingRequest(any(), any()); List flushOperations = handler.pollFlushOperations(); assertFalse(flushOperations.isEmpty()); FlushOperation flushOperation = flushOperations.get(0); - HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); + FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); flushOperation.getListener().accept(null, null); // Since we have keep-alive set to false, we should close the channel after the response has been // flushed - verify(nioSocketChannel).close(); + verify(nioHttpChannel).close(); } @SuppressWarnings("unchecked") public void testEncodeHttpResponse() throws IOException { prepareHandlerForResponse(handler); - FullHttpResponse fullHttpResponse = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); - NioHttpResponse pipelinedResponse = new NioHttpResponse(0, fullHttpResponse); + DefaultFullHttpRequest nettyRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + NioHttpRequest nioHttpRequest = new NioHttpRequest(nettyRequest, 0); + NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); + httpResponse.addHeader(HttpHeaderNames.CONTENT_LENGTH.toString(), "0"); SocketChannelContext context = mock(SocketChannelContext.class); - HttpWriteOperation writeOperation = new HttpWriteOperation(context, pipelinedResponse, mock(BiConsumer.class)); + HttpWriteOperation writeOperation = new HttpWriteOperation(context, httpResponse, mock(BiConsumer.class)); List flushOperations = handler.writeToBytes(writeOperation); - HttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); + FullHttpResponse response = responseDecoder.decode(Unpooled.wrappedBuffer(flushOperations.get(0).getBuffersToWrite())); assertEquals(HttpResponseStatus.OK, response.status()); assertEquals(HttpVersion.HTTP_1_1, response.protocolVersion()); } - private FullHttpRequest prepareHandlerForResponse(HttpReadWriteHandler adaptor) throws IOException { - HttpMethod method = HttpMethod.GET; - HttpVersion version = HttpVersion.HTTP_1_1; + public void testCorsEnabledWithoutAllowOrigins() throws IOException { + // Set up a HTTP transport with only the CORS enabled setting + Settings settings = Settings.builder() + .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, "remote-host", "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); + } + + public void testCorsEnabledWithAllowOrigins() throws IOException { + final String originValue = "remote-host"; + // create a http transport with CORS enabled and allow origin configured + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testCorsAllowOriginWithSameHost() throws IOException { + String originValue = "remote-host"; + String host = "remote-host"; + // create a http transport with CORS enabled + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .build(); + FullHttpResponse response = executeCorsRequest(settings, originValue, host); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = "http://" + originValue; + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue + ":5555"; + host = host + ":5555"; + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + + originValue = originValue.replace("http", "https"); + response = executeCorsRequest(settings, originValue, host); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + } + + public void testThatStringLiteralWorksOnMatch() throws IOException { + final String originValue = "remote-host"; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") + .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + } + + public void testThatAnyOriginWorks() throws IOException { + final String originValue = NioCorsHandler.ANY_ORIGIN; + Settings settings = Settings.builder() + .put(SETTING_CORS_ENABLED.getKey(), true) + .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) + .build(); + io.netty.handler.codec.http.HttpResponse response = executeCorsRequest(settings, originValue, "request-host"); + // inspect response and validate + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); + String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); + assertThat(allowedOrigins, is(originValue)); + assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); + } + + private FullHttpResponse executeCorsRequest(final Settings settings, final String originValue, final String host) throws IOException { + HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); + NioCorsConfig nioCorsConfig = NioHttpServerTransport.buildCorsConfig(settings); + HttpReadWriteHandler handler = new HttpReadWriteHandler(nioHttpChannel, transport, httpHandlingSettings, nioCorsConfig); + prepareHandlerForResponse(handler); + DefaultFullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + if (originValue != null) { + httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); + } + httpRequest.headers().add(HttpHeaderNames.HOST, host); + NioHttpRequest nioHttpRequest = new NioHttpRequest(httpRequest, 0); + BytesArray content = new BytesArray("content"); + HttpResponse response = nioHttpRequest.createResponse(RestStatus.OK, content); + response.addHeader("Content-Length", Integer.toString(content.length())); + + SocketChannelContext context = mock(SocketChannelContext.class); + List flushOperations = handler.writeToBytes(handler.createWriteOperation(context, response, (v, e) -> {})); + + FlushOperation flushOperation = flushOperations.get(0); + return responseDecoder.decode(Unpooled.wrappedBuffer(flushOperation.getBuffersToWrite())); + } + + + + private NioHttpRequest prepareHandlerForResponse(HttpReadWriteHandler handler) throws IOException { + HttpMethod method = randomBoolean() ? HttpMethod.GET : HttpMethod.HEAD; + HttpVersion version = randomBoolean() ? HttpVersion.HTTP_1_0 : HttpVersion.HTTP_1_1; String uri = "http://localhost:9090/" + randomAlphaOfLength(8); - HttpRequest request = new DefaultFullHttpRequest(version, method, uri); + io.netty.handler.codec.http.HttpRequest request = new DefaultFullHttpRequest(version, method, uri); ByteBuf buf = requestEncoder.encode(request); handler.consumeReads(toChannelBuffer(buf)); - ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(RestRequest.class); - verify(transport).dispatchRequest(requestCaptor.capture(), any(RestChannel.class)); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(NioHttpRequest.class); + verify(transport, atLeastOnce()).incomingRequest(requestCaptor.capture(), any(HttpChannel.class)); - NioHttpRequest nioHttpRequest = (NioHttpRequest) requestCaptor.getValue(); - FullHttpRequest requestParsed = nioHttpRequest.getRequest(); - assertNotNull(requestParsed); - assertEquals(requestParsed.method(), method); - assertEquals(requestParsed.protocolVersion(), version); - assertEquals(requestParsed.uri(), uri); - return requestParsed; + NioHttpRequest nioHttpRequest = requestCaptor.getValue(); + assertNotNull(nioHttpRequest); + assertEquals(method.name(), nioHttpRequest.method().name()); + if (version == HttpVersion.HTTP_1_1) { + assertEquals(HttpRequest.HttpVersion.HTTP_1_1, nioHttpRequest.protocolVersion()); + } else { + assertEquals(HttpRequest.HttpVersion.HTTP_1_0, nioHttpRequest.protocolVersion()); + } + assertEquals(nioHttpRequest.uri(), uri); + return nioHttpRequest; } private InboundChannelBuffer toChannelBuffer(ByteBuf buf) { @@ -226,11 +349,13 @@ private InboundChannelBuffer toChannelBuffer(ByteBuf buf) { return buffer; } + private static final int MAX = 16 * 1024 * 1024; + private static class RequestEncoder { - private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder()); + private final EmbeddedChannel requestEncoder = new EmbeddedChannel(new HttpRequestEncoder(), new HttpObjectAggregator(MAX)); - private ByteBuf encode(HttpRequest httpRequest) { + private ByteBuf encode(io.netty.handler.codec.http.HttpRequest httpRequest) { requestEncoder.writeOutbound(httpRequest); return requestEncoder.readOutbound(); } @@ -238,9 +363,9 @@ private ByteBuf encode(HttpRequest httpRequest) { private static class ResponseDecoder { - private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder()); + private final EmbeddedChannel responseDecoder = new EmbeddedChannel(new HttpResponseDecoder(), new HttpObjectAggregator(MAX)); - private HttpResponse decode(ByteBuf response) { + private FullHttpResponse decode(ByteBuf response) { responseDecoder.writeInbound(response); return responseDecoder.readInbound(); } diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java deleted file mode 100644 index 5fa0a7ae0a679..0000000000000 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpChannelTests.java +++ /dev/null @@ -1,349 +0,0 @@ -/* - * Licensed to Elasticsearch under one or more contributor - * license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright - * ownership. Elasticsearch licenses this file to you under - * the Apache License, Version 2.0 (the "License"); you may - * not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.elasticsearch.http.nio; - -import io.netty.buffer.Unpooled; -import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.FullHttpRequest; -import io.netty.handler.codec.http.FullHttpResponse; -import io.netty.handler.codec.http.HttpHeaderNames; -import io.netty.handler.codec.http.HttpHeaderValues; -import io.netty.handler.codec.http.HttpMethod; -import io.netty.handler.codec.http.HttpResponse; -import io.netty.handler.codec.http.HttpVersion; -import org.elasticsearch.common.bytes.BytesReference; -import org.elasticsearch.common.io.stream.BytesStreamOutput; -import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; -import org.elasticsearch.common.lease.Releasable; -import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.common.util.MockBigArrays; -import org.elasticsearch.common.util.MockPageCacheRecycler; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; -import org.elasticsearch.common.xcontent.XContentBuilder; -import org.elasticsearch.common.xcontent.json.JsonXContent; -import org.elasticsearch.http.HttpHandlingSettings; -import org.elasticsearch.http.HttpTransportSettings; -import org.elasticsearch.http.nio.cors.NioCorsConfig; -import org.elasticsearch.http.nio.cors.NioCorsHandler; -import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; -import org.elasticsearch.nio.NioSocketChannel; -import org.elasticsearch.nio.SocketChannelContext; -import org.elasticsearch.rest.BytesRestResponse; -import org.elasticsearch.rest.RestResponse; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; -import org.junit.After; -import org.junit.Before; -import org.mockito.ArgumentCaptor; - -import java.io.IOException; -import java.nio.channels.ClosedChannelException; -import java.nio.charset.StandardCharsets; -import java.util.function.BiConsumer; - -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_CREDENTIALS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_METHODS; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ALLOW_ORIGIN; -import static org.elasticsearch.http.HttpTransportSettings.SETTING_CORS_ENABLED; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.not; -import static org.hamcrest.Matchers.notNullValue; -import static org.hamcrest.Matchers.nullValue; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.atLeastOnce; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -public class NioHttpChannelTests extends ESTestCase { - - private ThreadPool threadPool; - private MockBigArrays bigArrays; - private NioSocketChannel nioChannel; - private SocketChannelContext channelContext; - - @Before - public void setup() throws Exception { - nioChannel = mock(NioSocketChannel.class); - channelContext = mock(SocketChannelContext.class); - when(nioChannel.getContext()).thenReturn(channelContext); - threadPool = new TestThreadPool("test"); - bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); - } - - @After - public void shutdown() throws Exception { - if (threadPool != null) { - threadPool.shutdownNow(); - } - } - - public void testResponse() { - final FullHttpResponse response = executeRequest(Settings.EMPTY, "request-host"); - assertThat(response.content(), equalTo(ByteBufUtils.toByteBuf(new TestResponse().content()))); - } - - public void testCorsEnabledWithoutAllowOrigins() { - // Set up a HTTP transport with only the CORS enabled setting - Settings settings = Settings.builder() - .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, "remote-host", "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); - } - - public void testCorsEnabledWithAllowOrigins() { - final String originValue = "remote-host"; - // create a http transport with CORS enabled and allow origin configured - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testCorsAllowOriginWithSameHost() { - String originValue = "remote-host"; - String host = "remote-host"; - // create a http transport with CORS enabled - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, host); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = "http://" + originValue; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue + ":5555"; - host = host + ":5555"; - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - - originValue = originValue.replace("http", "https"); - response = executeRequest(settings, originValue, host); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - } - - public void testThatStringLiteralWorksOnMatch() { - final String originValue = "remote-host"; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") - .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); - } - - public void testThatAnyOriginWorks() { - final String originValue = NioCorsHandler.ANY_ORIGIN; - Settings settings = Settings.builder() - .put(SETTING_CORS_ENABLED.getKey(), true) - .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) - .build(); - HttpResponse response = executeRequest(settings, originValue, "request-host"); - // inspect response and validate - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); - String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); - assertThat(allowedOrigins, is(originValue)); - assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); - } - - public void testHeadersSet() { - Settings settings = Settings.builder().build(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - httpRequest.headers().add(HttpHeaderNames.ORIGIN, "remote"); - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - // send a response - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, corsConfig, - threadPool.getThreadContext()); - TestResponse resp = new TestResponse(); - final String customHeader = "custom-header"; - final String customHeaderValue = "xyz"; - resp.addHeader(customHeader, customHeaderValue); - channel.sendResponse(resp); - - // inspect what was written - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); - verify(channelContext).sendMessage(responseCaptor.capture(), any()); - Object nioResponse = responseCaptor.getValue(); - HttpResponse response = ((NioHttpResponse) nioResponse).getResponse(); - assertThat(response.headers().get("non-existent-header"), nullValue()); - assertThat(response.headers().get(customHeader), equalTo(customHeaderValue)); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_LENGTH), equalTo(Integer.toString(resp.content().length()))); - assertThat(response.headers().get(HttpHeaderNames.CONTENT_TYPE), equalTo(resp.contentType())); - } - - @SuppressWarnings("unchecked") - public void testReleaseInListener() throws IOException { - final Settings settings = Settings.builder().build(); - final NamedXContentRegistry registry = xContentRegistry(); - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - final NioHttpRequest request = new NioHttpRequest(registry, httpRequest); - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, - corsConfig, threadPool.getThreadContext()); - final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, - JsonXContent.contentBuilder().startObject().endObject()); - assertThat(response.content(), not(instanceOf(Releasable.class))); - - // ensure we have reserved bytes - if (randomBoolean()) { - BytesStreamOutput out = channel.bytesOutput(); - assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); - } else { - try (XContentBuilder builder = channel.newBuilder()) { - // do something builder - builder.startObject().endObject(); - } - } - - channel.sendResponse(response); - Class> listenerClass = (Class>) (Class) BiConsumer.class; - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); - verify(channelContext).sendMessage(any(), listenerCaptor.capture()); - BiConsumer listener = listenerCaptor.getValue(); - if (randomBoolean()) { - listener.accept(null, null); - } else { - listener.accept(null, new ClosedChannelException()); - } - // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released - } - - - @SuppressWarnings("unchecked") - public void testConnectionClose() throws Exception { - final Settings settings = Settings.builder().build(); - final FullHttpRequest httpRequest; - final boolean close = randomBoolean(); - if (randomBoolean()) { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); - } - } else { - httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/"); - if (!close) { - httpRequest.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE); - } - } - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - - HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, handlingSettings, - corsConfig, threadPool.getThreadContext()); - final TestResponse resp = new TestResponse(); - channel.sendResponse(resp); - Class> listenerClass = (Class>) (Class) BiConsumer.class; - ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); - verify(channelContext).sendMessage(any(), listenerCaptor.capture()); - BiConsumer listener = listenerCaptor.getValue(); - listener.accept(null, null); - if (close) { - verify(nioChannel, times(1)).close(); - } else { - verify(nioChannel, times(0)).close(); - } - } - - private FullHttpResponse executeRequest(final Settings settings, final String host) { - return executeRequest(settings, null, host); - } - - private FullHttpResponse executeRequest(final Settings settings, final String originValue, final String host) { - // construct request and send it over the transport layer - final FullHttpRequest httpRequest = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); - if (originValue != null) { - httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); - } - httpRequest.headers().add(HttpHeaderNames.HOST, host); - final NioHttpRequest request = new NioHttpRequest(xContentRegistry(), httpRequest); - - HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); - NioCorsConfig corsConfig = NioHttpServerTransport.buildCorsConfig(settings); - NioHttpChannel channel = new NioHttpChannel(nioChannel, bigArrays, request, 1, httpHandlingSettings, corsConfig, - threadPool.getThreadContext()); - channel.sendResponse(new TestResponse()); - - // get the response - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); - verify(channelContext, atLeastOnce()).sendMessage(responseCaptor.capture(), any()); - return ((NioHttpResponse) responseCaptor.getValue()).getResponse(); - } - - private static class TestResponse extends RestResponse { - - private final BytesReference reference; - - TestResponse() { - reference = ByteBufUtils.toBytesReference(Unpooled.copiedBuffer("content", StandardCharsets.UTF_8)); - } - - @Override - public String contentType() { - return "text"; - } - - @Override - public BytesReference content() { - return reference; - } - - @Override - public RestStatus status() { - return RestStatus.OK; - } - - } -} diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java index 94d7db171a563..5f2784a356714 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpPipeliningHandlerTests.java @@ -19,15 +19,12 @@ package org.elasticsearch.http.nio; -import io.netty.buffer.ByteBuf; import io.netty.buffer.ByteBufUtil; -import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.channel.embedded.EmbeddedChannel; import io.netty.handler.codec.http.DefaultFullHttpRequest; -import io.netty.handler.codec.http.DefaultFullHttpResponse; import io.netty.handler.codec.http.FullHttpRequest; import io.netty.handler.codec.http.FullHttpResponse; import io.netty.handler.codec.http.HttpMethod; @@ -35,7 +32,10 @@ import io.netty.handler.codec.http.LastHttpContent; import io.netty.handler.codec.http.QueryStringDecoder; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.http.HttpPipelinedRequest; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.junit.After; @@ -55,7 +55,6 @@ import java.util.stream.IntStream; import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; -import static io.netty.handler.codec.http.HttpResponseStatus.OK; import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; import static org.hamcrest.core.Is.is; @@ -190,11 +189,11 @@ public void testPipeliningRequestsAreReleased() throws InterruptedException { ArrayList promises = new ArrayList<>(); for (int i = 1; i < requests.size(); ++i) { - final FullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK); ChannelPromise promise = embeddedChannel.newPromise(); promises.add(promise); - int sequence = requests.get(i).getSequence(); - NioHttpResponse resp = new NioHttpResponse(sequence, httpResponse); + HttpPipelinedRequest pipelinedRequest = requests.get(i); + NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + NioHttpResponse resp = nioHttpRequest.createResponse(RestStatus.OK, BytesArray.EMPTY); embeddedChannel.writeAndFlush(resp, promise); } @@ -231,10 +230,10 @@ protected void channelRead0(ChannelHandlerContext ctx, HttpRequest request) thro } - private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { + private class WorkEmulatorHandler extends SimpleChannelInboundHandler> { @Override - protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { + protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedRequest pipelinedRequest) { LastHttpContent request = pipelinedRequest.getRequest(); final QueryStringDecoder decoder; if (request instanceof FullHttpRequest) { @@ -244,9 +243,10 @@ protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedReques } final String uri = decoder.path().replace("/", ""); - final ByteBuf content = Unpooled.copiedBuffer(uri, StandardCharsets.UTF_8); - final DefaultFullHttpResponse httpResponse = new DefaultFullHttpResponse(HTTP_1_1, OK, content); - httpResponse.headers().add(CONTENT_LENGTH, content.readableBytes()); + final BytesReference content = new BytesArray(uri.getBytes(StandardCharsets.UTF_8)); + NioHttpRequest nioHttpRequest = new NioHttpRequest(pipelinedRequest.getRequest(), pipelinedRequest.getSequence()); + NioHttpResponse httpResponse = nioHttpRequest.createResponse(RestStatus.OK, content); + httpResponse.addHeader(CONTENT_LENGTH.toString(), Integer.toString(content.length())); final CountDownLatch waitingLatch = new CountDownLatch(1); waitingRequests.put(uri, waitingLatch); @@ -258,7 +258,7 @@ protected void channelRead0(final ChannelHandlerContext ctx, HttpPipelinedReques waitingLatch.await(1000, TimeUnit.SECONDS); final ChannelPromise promise = ctx.newPromise(); eventLoopService.submit(() -> { - ctx.write(new NioHttpResponse(pipelinedRequest.getSequence(), httpResponse), promise); + ctx.write(httpResponse, promise); finishingLatch.countDown(); }); } catch (InterruptedException e) { diff --git a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java index c43fc7d072360..48a5bf617a436 100644 --- a/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java +++ b/plugins/transport-nio/src/test/java/org/elasticsearch/http/nio/NioHttpServerTransportTests.java @@ -280,40 +280,6 @@ public void dispatchBadRequest(final RestRequest request, assertThat(causeReference.get(), instanceOf(TooLongFrameException.class)); } - public void testDispatchDoesNotModifyThreadContext() throws InterruptedException { - final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { - - @Override - public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { - threadContext.putHeader("foo", "bar"); - threadContext.putTransient("bar", "baz"); - } - - @Override - public void dispatchBadRequest(final RestRequest request, - final RestChannel channel, - final ThreadContext threadContext, - final Throwable cause) { - threadContext.putHeader("foo_bad", "bar"); - threadContext.putTransient("bar_bad", "baz"); - } - - }; - - try (NioHttpServerTransport transport = - new NioHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher)) { - transport.start(); - - transport.dispatchRequest(null, null); - assertNull(threadPool.getThreadContext().getHeader("foo")); - assertNull(threadPool.getThreadContext().getTransient("bar")); - - transport.dispatchBadRequest(null, null, null); - assertNull(threadPool.getThreadContext().getHeader("foo_bad")); - assertNull(threadPool.getThreadContext().getTransient("bar_bad")); - } - } - // public void testReadTimeout() throws Exception { // final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { // diff --git a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java index c75754bde5855..4fad4159f55d8 100644 --- a/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java +++ b/server/src/main/java/org/elasticsearch/http/AbstractHttpServerTransport.java @@ -21,6 +21,7 @@ import com.carrotsearch.hppc.IntHashSet; import com.carrotsearch.hppc.IntSet; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.Strings; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.network.NetworkService; @@ -29,7 +30,9 @@ import org.elasticsearch.common.transport.PortsRange; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.threadpool.ThreadPool; @@ -48,11 +51,14 @@ import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_HOST; import static org.elasticsearch.http.HttpTransportSettings.SETTING_HTTP_PUBLISH_PORT; -public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements org.elasticsearch.http.HttpServerTransport { +public abstract class AbstractHttpServerTransport extends AbstractLifecycleComponent implements HttpServerTransport { + public final HttpHandlingSettings handlingSettings; protected final NetworkService networkService; + protected final BigArrays bigArrays; protected final ThreadPool threadPool; protected final Dispatcher dispatcher; + private final NamedXContentRegistry xContentRegistry; protected final String[] bindHosts; protected final String[] publishHosts; @@ -61,11 +67,15 @@ public abstract class AbstractHttpServerTransport extends AbstractLifecycleCompo protected volatile BoundTransportAddress boundAddress; - protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, ThreadPool threadPool, Dispatcher dispatcher) { + protected AbstractHttpServerTransport(Settings settings, NetworkService networkService, BigArrays bigArrays, ThreadPool threadPool, + NamedXContentRegistry xContentRegistry, Dispatcher dispatcher) { super(settings); this.networkService = networkService; + this.bigArrays = bigArrays; this.threadPool = threadPool; + this.xContentRegistry = xContentRegistry; this.dispatcher = dispatcher; + this.handlingSettings = HttpHandlingSettings.fromSettings(settings); // we can't make the network.bind_host a fallback since we already fall back to http.host hence the extra conditional here List httpBindHost = SETTING_HTTP_BIND_HOST.get(settings); @@ -156,17 +166,94 @@ static int resolvePublishPort(Settings settings, List boundAdd return publishPort; } - public void dispatchRequest(final RestRequest request, final RestChannel channel) { + /** + * This method handles an incoming http request. + * + * @param httpRequest that is incoming + * @param httpChannel that received the http request + */ + public void incomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel) { + handleIncomingRequest(httpRequest, httpChannel, null); + } + + /** + * This method handles an incoming http request that has encountered an error. + * + * @param httpRequest that is incoming + * @param httpChannel that received the http request + * @param exception that was encountered + */ + public void incomingRequestError(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { + handleIncomingRequest(httpRequest, httpChannel, exception); + } + + // Visible for testing + void dispatchRequest(final RestRequest restRequest, final RestChannel channel, final Throwable badRequestCause) { final ThreadContext threadContext = threadPool.getThreadContext(); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - dispatcher.dispatchRequest(request, channel, threadContext); + if (badRequestCause != null) { + dispatcher.dispatchBadRequest(restRequest, channel, threadContext, badRequestCause); + } else { + dispatcher.dispatchRequest(restRequest, channel, threadContext); + } } } - public void dispatchBadRequest(final RestRequest request, final RestChannel channel, final Throwable cause) { - final ThreadContext threadContext = threadPool.getThreadContext(); - try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - dispatcher.dispatchBadRequest(request, channel, threadContext, cause); + private void handleIncomingRequest(final HttpRequest httpRequest, final HttpChannel httpChannel, final Exception exception) { + Exception badRequestCause = exception; + + /* + * We want to create a REST request from the incoming request from Netty. However, creating this request could fail if there + * are incorrectly encoded parameters, or the Content-Type header is invalid. If one of these specific failures occurs, we + * attempt to create a REST request again without the input that caused the exception (e.g., we remove the Content-Type header, + * or skip decoding the parameters). Once we have a request in hand, we then dispatch the request as a bad request with the + * underlying exception that caused us to treat the request as bad. + */ + final RestRequest restRequest; + { + RestRequest innerRestRequest; + try { + innerRestRequest = RestRequest.request(xContentRegistry, httpRequest, httpChannel); + } catch (final RestRequest.ContentTypeHeaderException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + innerRestRequest = requestWithoutContentTypeHeader(httpRequest, httpChannel, badRequestCause); + } catch (final RestRequest.BadParameterException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + innerRestRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); + } + restRequest = innerRestRequest; + } + + /* + * We now want to create a channel used to send the response on. However, creating this channel can fail if there are invalid + * parameter values for any of the filter_path, human, or pretty parameters. We detect these specific failures via an + * IllegalArgumentException from the channel constructor and then attempt to create a new channel that bypasses parsing of these + * parameter values. + */ + final RestChannel channel; + { + RestChannel innerChannel; + ThreadContext threadContext = threadPool.getThreadContext(); + try { + innerChannel = new DefaultRestChannel(httpChannel, httpRequest, restRequest, bigArrays, handlingSettings, threadContext); + } catch (final IllegalArgumentException e) { + badRequestCause = ExceptionsHelper.useOrSuppress(badRequestCause, e); + final RestRequest innerRequest = RestRequest.requestWithoutParameters(xContentRegistry, httpRequest, httpChannel); + innerChannel = new DefaultRestChannel(httpChannel, httpRequest, innerRequest, bigArrays, handlingSettings, threadContext); + } + channel = innerChannel; + } + + dispatchRequest(restRequest, channel, badRequestCause); + } + + private RestRequest requestWithoutContentTypeHeader(HttpRequest httpRequest, HttpChannel httpChannel, Exception badRequestCause) { + HttpRequest httpRequestWithoutContentType = httpRequest.removeHeader("Content-Type"); + try { + return RestRequest.request(xContentRegistry, httpRequestWithoutContentType, httpChannel); + } catch (final RestRequest.BadParameterException e) { + badRequestCause.addSuppressed(e); + return RestRequest.requestWithoutParameters(xContentRegistry, httpRequestWithoutContentType, httpChannel); } } } diff --git a/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java new file mode 100644 index 0000000000000..f5924bb239eae --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/DefaultRestChannel.java @@ -0,0 +1,172 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.lease.Releasables; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +/** + * The default rest channel for incoming requests. This class implements the basic logic for sending a rest + * response. It will set necessary headers nad ensure that bytes are released after the response is sent. + */ +public class DefaultRestChannel extends AbstractRestChannel implements RestChannel { + + static final String CLOSE = "close"; + static final String CONNECTION = "connection"; + static final String KEEP_ALIVE = "keep-alive"; + static final String CONTENT_TYPE = "content-type"; + static final String CONTENT_LENGTH = "content-length"; + static final String SET_COOKIE = "set-cookie"; + static final String X_OPAQUE_ID = "X-Opaque-Id"; + + private final HttpRequest httpRequest; + private final BigArrays bigArrays; + private final HttpHandlingSettings settings; + private final ThreadContext threadContext; + private final HttpChannel httpChannel; + + DefaultRestChannel(HttpChannel httpChannel, HttpRequest httpRequest, RestRequest request, BigArrays bigArrays, + HttpHandlingSettings settings, ThreadContext threadContext) { + super(request, settings.getDetailedErrorsEnabled()); + this.httpChannel = httpChannel; + this.httpRequest = httpRequest; + this.bigArrays = bigArrays; + this.settings = settings; + this.threadContext = threadContext; + } + + @Override + protected BytesStreamOutput newBytesOutput() { + return new ReleasableBytesStreamOutput(bigArrays); + } + + @Override + public void sendResponse(RestResponse restResponse) { + HttpResponse httpResponse; + if (RestRequest.Method.HEAD == request.method()) { + httpResponse = httpRequest.createResponse(restResponse.status(), BytesArray.EMPTY); + } else { + httpResponse = httpRequest.createResponse(restResponse.status(), restResponse.content()); + } + + // TODO: Ideally we should move the setting of Cors headers into :server + // NioCorsHandler.setCorsResponseHeaders(nettyRequest, resp, corsConfig); + + String opaque = request.header(X_OPAQUE_ID); + if (opaque != null) { + setHeaderField(httpResponse, X_OPAQUE_ID, opaque); + } + + // Add all custom headers + addCustomHeaders(httpResponse, restResponse.getHeaders()); + addCustomHeaders(httpResponse, threadContext.getResponseHeaders()); + + ArrayList toClose = new ArrayList<>(3); + + boolean success = false; + try { + // If our response doesn't specify a content-type header, set one + setHeaderField(httpResponse, CONTENT_TYPE, restResponse.contentType(), false); + // If our response has no content-length, calculate and set one + setHeaderField(httpResponse, CONTENT_LENGTH, String.valueOf(restResponse.content().length()), false); + + addCookies(httpResponse); + + BytesReference content = restResponse.content(); + if (content instanceof Releasable) { + toClose.add((Releasable) content); + } + BytesStreamOutput bytesStreamOutput = bytesOutputOrNull(); + if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) { + toClose.add((Releasable) bytesStreamOutput); + } + + if (isCloseConnection()) { + toClose.add(httpChannel); + } + + ActionListener listener = ActionListener.wrap(() -> Releasables.close(toClose)); + httpChannel.sendResponse(httpResponse, listener); + success = true; + } finally { + if (success == false) { + Releasables.close(toClose); + } + } + + } + + private void setHeaderField(HttpResponse response, String headerField, String value) { + setHeaderField(response, headerField, value, true); + } + + private void setHeaderField(HttpResponse response, String headerField, String value, boolean override) { + if (override || !response.containsHeader(headerField)) { + response.addHeader(headerField, value); + } + } + + private void addCustomHeaders(HttpResponse response, Map> customHeaders) { + if (customHeaders != null) { + for (Map.Entry> headerEntry : customHeaders.entrySet()) { + for (String headerValue : headerEntry.getValue()) { + setHeaderField(response, headerEntry.getKey(), headerValue); + } + } + } + } + + private void addCookies(HttpResponse response) { + if (settings.isResetCookies()) { + List cookies = request.getHttpRequest().strictCookies(); + if (cookies.isEmpty() == false) { + for (String cookie : cookies) { + response.addHeader(SET_COOKIE, cookie); + } + } + } + } + + // Determine if the request connection should be closed on completion. + private boolean isCloseConnection() { + final boolean http10 = isHttp10(); + return CLOSE.equalsIgnoreCase(request.header(CONNECTION)) || (http10 && !KEEP_ALIVE.equalsIgnoreCase(request.header(CONNECTION))); + } + + // Determine if the request protocol version is HTTP 1.0 + private boolean isHttp10() { + return request.getHttpRequest().protocolVersion() == HttpRequest.HttpVersion.HTTP_1_0; + } +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpChannel.java b/server/src/main/java/org/elasticsearch/http/HttpChannel.java new file mode 100644 index 0000000000000..baea3e0c3b3c3 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpChannel.java @@ -0,0 +1,58 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.lease.Releasable; + +import java.net.InetSocketAddress; + +public interface HttpChannel extends Releasable { + + /** + * Sends a http response to the channel. The listener will be executed once the send process has been + * completed. + * + * @param response to send to channel + * @param listener to execute upon send completion + */ + void sendResponse(HttpResponse response, ActionListener listener); + + /** + * Returns the local address for this channel. + * + * @return the local address of this channel. + */ + InetSocketAddress getLocalAddress(); + + /** + * Returns the remote address for this channel. Can be null if channel does not have a remote address. + * + * @return the remote address of this channel. + */ + InetSocketAddress getRemoteAddress(); + + /** + * Closes the channel. This might be an asynchronous process. There is no guarantee that the channel + * will be closed when this method returns. + */ + void close(); + +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java b/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java index 7db8666e73ae3..ae1520cba6002 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java +++ b/server/src/main/java/org/elasticsearch/http/HttpPipelinedMessage.java @@ -18,20 +18,17 @@ */ package org.elasticsearch.http; -public class HttpPipelinedMessage implements Comparable { +public interface HttpPipelinedMessage extends Comparable { - private final int sequence; - - public HttpPipelinedMessage(int sequence) { - this.sequence = sequence; - } - - public int getSequence() { - return sequence; - } + /** + * Get the sequence number for this message. + * + * @return the sequence number + */ + int getSequence(); @Override - public int compareTo(HttpPipelinedMessage o) { - return Integer.compare(sequence, o.sequence); + default int compareTo(HttpPipelinedMessage o) { + return Integer.compare(getSequence(), o.getSequence()); } } diff --git a/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java b/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java index df8bd7ee1eb80..db3a2bae16714 100644 --- a/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java +++ b/server/src/main/java/org/elasticsearch/http/HttpPipelinedRequest.java @@ -18,15 +18,21 @@ */ package org.elasticsearch.http; -public class HttpPipelinedRequest extends HttpPipelinedMessage { +public class HttpPipelinedRequest implements HttpPipelinedMessage { private final R request; + private final int sequence; HttpPipelinedRequest(int sequence, R request) { - super(sequence); + this.sequence = sequence; this.request = request; } + @Override + public int getSequence() { + return sequence; + } + public R getRequest() { return request; } diff --git a/server/src/main/java/org/elasticsearch/http/HttpRequest.java b/server/src/main/java/org/elasticsearch/http/HttpRequest.java new file mode 100644 index 0000000000000..496fec23312b0 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpRequest.java @@ -0,0 +1,65 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; + +import java.util.List; +import java.util.Map; + +/** + * A basic http request abstraction. Http modules needs to implement this interface to integrate with the + * server package's rest handling. + */ +public interface HttpRequest { + + enum HttpVersion { + HTTP_1_0, + HTTP_1_1 + } + + RestRequest.Method method(); + + /** + * The uri of the rest request, with the query string. + */ + String uri(); + + BytesReference content(); + + /** + * Get all of the headers and values associated with the headers. Modifications of this map are not supported. + */ + Map> getHeaders(); + + List strictCookies(); + + HttpVersion protocolVersion(); + + HttpRequest removeHeader(String header); + + /** + * Create an http response from this request and the supplied status and content. + */ + HttpResponse createResponse(RestStatus status, BytesReference content); + +} diff --git a/server/src/main/java/org/elasticsearch/http/HttpResponse.java b/server/src/main/java/org/elasticsearch/http/HttpResponse.java new file mode 100644 index 0000000000000..2d363f663c3ef --- /dev/null +++ b/server/src/main/java/org/elasticsearch/http/HttpResponse.java @@ -0,0 +1,32 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +/** + * A basic http response abstraction. Http modules must implement this interface as the server package rest + * handling needs to set http headers for a response. + */ +public interface HttpResponse { + + void addHeader(String name, String value); + + boolean containsHeader(String name); + +} diff --git a/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java b/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java index d376b65ef2d88..4e3d652ec5d7e 100644 --- a/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java +++ b/server/src/main/java/org/elasticsearch/rest/AbstractRestChannel.java @@ -40,7 +40,7 @@ public abstract class AbstractRestChannel implements RestChannel { private static final Predicate EXCLUDE_FILTER = INCLUDE_FILTER.negate(); protected final RestRequest request; - protected final boolean detailedErrorsEnabled; + private final boolean detailedErrorsEnabled; private final String format; private final String filterPath; private final boolean pretty; diff --git a/server/src/main/java/org/elasticsearch/rest/RestController.java b/server/src/main/java/org/elasticsearch/rest/RestController.java index aae63f041fad5..82fcf7178d1dd 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestController.java +++ b/server/src/main/java/org/elasticsearch/rest/RestController.java @@ -272,8 +272,9 @@ boolean dispatchRequest(final RestRequest request, final RestChannel channel, fi */ private static boolean hasContentType(final RestRequest restRequest, final RestHandler restHandler) { if (restRequest.getXContentType() == null) { - if (restHandler.supportsContentStream() && restRequest.header("Content-Type") != null) { - final String lowercaseMediaType = restRequest.header("Content-Type").toLowerCase(Locale.ROOT); + String contentTypeHeader = restRequest.header("Content-Type"); + if (restHandler.supportsContentStream() && contentTypeHeader != null) { + final String lowercaseMediaType = contentTypeHeader.toLowerCase(Locale.ROOT); // we also support newline delimited JSON: http://specs.okfnlabs.org/ndjson/ if (lowercaseMediaType.equals("application/x-ndjson")) { restRequest.setXContentType(XContentType.JSON); diff --git a/server/src/main/java/org/elasticsearch/rest/RestRequest.java b/server/src/main/java/org/elasticsearch/rest/RestRequest.java index 65b4f9d1d3614..813d6feb55167 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestRequest.java +++ b/server/src/main/java/org/elasticsearch/rest/RestRequest.java @@ -35,10 +35,11 @@ import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpRequest; import java.io.IOException; import java.io.InputStream; -import java.net.SocketAddress; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -51,7 +52,7 @@ import static org.elasticsearch.common.unit.ByteSizeValue.parseBytesSizeValue; import static org.elasticsearch.common.unit.TimeValue.parseTimeValue; -public abstract class RestRequest implements ToXContent.Params { +public class RestRequest implements ToXContent.Params { // tchar pattern as defined by RFC7230 section 3.2.6 private static final Pattern TCHAR_PATTERN = Pattern.compile("[a-zA-z0-9!#$%&'*+\\-.\\^_`|~]+"); @@ -62,18 +63,47 @@ public abstract class RestRequest implements ToXContent.Params { private final String rawPath; private final Set consumedParams = new HashSet<>(); private final SetOnce xContentType = new SetOnce<>(); + private final HttpRequest httpRequest; + private final HttpChannel httpChannel; + + protected RestRequest(NamedXContentRegistry xContentRegistry, Map params, String path, + Map> headers, HttpRequest httpRequest, HttpChannel httpChannel) { + final XContentType xContentType; + try { + xContentType = parseContentType(headers.get("Content-Type")); + } catch (final IllegalArgumentException e) { + throw new ContentTypeHeaderException(e); + } + if (xContentType != null) { + this.xContentType.set(xContentType); + } + this.xContentRegistry = xContentRegistry; + this.httpRequest = httpRequest; + this.httpChannel = httpChannel; + this.params = params; + this.rawPath = path; + this.headers = Collections.unmodifiableMap(headers); + } + + protected RestRequest(RestRequest restRequest) { + this(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(), + restRequest.getHttpRequest(), restRequest.getHttpChannel()); + } /** - * Creates a new REST request. + * Creates a new REST request. This method will throw {@link BadParameterException} if the path cannot be + * decoded * * @param xContentRegistry the content registry - * @param uri the raw URI that will be parsed into the path and the parameters - * @param headers a map of the header; this map should implement a case-insensitive lookup + * @param httpRequest the http request + * @param httpChannel the http channel * @throws BadParameterException if the parameters can not be decoded * @throws ContentTypeHeaderException if the Content-Type header can not be parsed */ - public RestRequest(final NamedXContentRegistry xContentRegistry, final String uri, final Map> headers) { - this(xContentRegistry, params(uri), path(uri), headers); + public static RestRequest request(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, HttpChannel httpChannel) { + Map params = params(httpRequest.uri()); + String path = path(httpRequest.uri()); + return new RestRequest(xContentRegistry, params, path, httpRequest.getHeaders(), httpRequest, httpChannel); } private static Map params(final String uri) { @@ -99,46 +129,34 @@ private static String path(final String uri) { } /** - * Creates a new REST request. In contrast to - * {@link RestRequest#RestRequest(NamedXContentRegistry, Map, String, Map)}, the path is not decoded so this constructor will not throw - * a {@link BadParameterException}. + * Creates a new REST request. The path is not decoded so this constructor will not throw a + * {@link BadParameterException}. * * @param xContentRegistry the content registry - * @param params the request parameters - * @param path the raw path (which is not parsed) - * @param headers a map of the header; this map should implement a case-insensitive lookup + * @param httpRequest the http request + * @param httpChannel the http channel * @throws ContentTypeHeaderException if the Content-Type header can not be parsed */ - public RestRequest( - final NamedXContentRegistry xContentRegistry, - final Map params, - final String path, - final Map> headers) { - final XContentType xContentType; - try { - xContentType = parseContentType(headers.get("Content-Type")); - } catch (final IllegalArgumentException e) { - throw new ContentTypeHeaderException(e); - } - if (xContentType != null) { - this.xContentType.set(xContentType); - } - this.xContentRegistry = xContentRegistry; - this.params = params; - this.rawPath = path; - this.headers = Collections.unmodifiableMap(headers); + public static RestRequest requestWithoutParameters(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, + HttpChannel httpChannel) { + Map params = Collections.emptyMap(); + return new RestRequest(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel); } public enum Method { GET, POST, PUT, DELETE, OPTIONS, HEAD, PATCH, TRACE, CONNECT } - public abstract Method method(); + public Method method() { + return httpRequest.method(); + } /** * The uri of the rest request, with the query string. */ - public abstract String uri(); + public String uri() { + return httpRequest.uri(); + } /** * The non decoded, raw path provided. @@ -154,9 +172,13 @@ public final String path() { return RestUtils.decodeComponent(rawPath()); } - public abstract boolean hasContent(); + public boolean hasContent() { + return content().length() > 0; + } - public abstract BytesReference content(); + public BytesReference content() { + return httpRequest.content(); + } /** * @return content of the request body or throw an exception if the body or content type is missing @@ -216,14 +238,12 @@ final void setXContentType(XContentType xContentType) { this.xContentType.set(xContentType); } - @Nullable - public SocketAddress getRemoteAddress() { - return null; + public HttpChannel getHttpChannel() { + return httpChannel; } - @Nullable - public SocketAddress getLocalAddress() { - return null; + public HttpRequest getHttpRequest() { + return httpRequest; } public final boolean hasParam(String key) { diff --git a/server/src/main/java/org/elasticsearch/rest/RestResponse.java b/server/src/main/java/org/elasticsearch/rest/RestResponse.java index 7e031f8d004e1..d0d6fa752d68e 100644 --- a/server/src/main/java/org/elasticsearch/rest/RestResponse.java +++ b/server/src/main/java/org/elasticsearch/rest/RestResponse.java @@ -20,10 +20,10 @@ package org.elasticsearch.rest; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.bytes.BytesReference; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -31,8 +31,7 @@ public abstract class RestResponse { - protected Map> customHeaders; - + private Map> customHeaders; /** * The response content type. @@ -81,10 +80,13 @@ public void addHeader(String name, String value) { } /** - * Returns custom headers that have been added, or null if none have been set. + * Returns custom headers that have been added. This method should not be used to mutate headers. */ - @Nullable public Map> getHeaders() { - return customHeaders; + if (customHeaders == null) { + return Collections.emptyMap(); + } else { + return customHeaders; + } } } diff --git a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java index ee74d98002faa..a7629e5f48b6c 100644 --- a/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java +++ b/server/src/test/java/org/elasticsearch/http/AbstractHttpServerTransportTests.java @@ -19,13 +19,27 @@ package org.elasticsearch.http; +import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.network.NetworkUtils; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; +import java.io.IOException; +import java.net.InetAddress; import java.net.UnknownHostException; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import static java.net.InetAddress.getByName; @@ -36,6 +50,27 @@ public class AbstractHttpServerTransportTests extends ESTestCase { + private NetworkService networkService; + private ThreadPool threadPool; + private MockBigArrays bigArrays; + + @Before + public void setup() throws Exception { + networkService = new NetworkService(Collections.emptyList()); + threadPool = new TestThreadPool("test"); + bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + @After + public void shutdown() throws Exception { + if (threadPool != null) { + threadPool.shutdownNow(); + } + threadPool = null; + networkService = null; + bigArrays = null; + } + public void testHttpPublishPort() throws Exception { int boundPort = randomIntBetween(9000, 9100); int otherBoundPort = randomIntBetween(9200, 9300); @@ -71,6 +106,64 @@ public void testHttpPublishPort() throws Exception { } } + public void testDispatchDoesNotModifyThreadContext() { + final HttpServerTransport.Dispatcher dispatcher = new HttpServerTransport.Dispatcher() { + + @Override + public void dispatchRequest(final RestRequest request, final RestChannel channel, final ThreadContext threadContext) { + threadContext.putHeader("foo", "bar"); + threadContext.putTransient("bar", "baz"); + } + + @Override + public void dispatchBadRequest(final RestRequest request, + final RestChannel channel, + final ThreadContext threadContext, + final Throwable cause) { + threadContext.putHeader("foo_bad", "bar"); + threadContext.putTransient("bar_bad", "baz"); + } + + }; + + try (AbstractHttpServerTransport transport = + new AbstractHttpServerTransport(Settings.EMPTY, networkService, bigArrays, threadPool, xContentRegistry(), dispatcher) { + @Override + protected TransportAddress bindAddress(InetAddress hostAddress) { + return null; + } + + @Override + protected void doStart() { + + } + + @Override + protected void doStop() { + + } + + @Override + protected void doClose() throws IOException { + + } + + @Override + public HttpStats stats() { + return null; + } + }) { + + transport.dispatchRequest(null, null, null); + assertNull(threadPool.getThreadContext().getHeader("foo")); + assertNull(threadPool.getThreadContext().getTransient("bar")); + + transport.dispatchRequest(null, null, new Exception()); + assertNull(threadPool.getThreadContext().getHeader("foo_bad")); + assertNull(threadPool.getThreadContext().getTransient("bar_bad")); + } + } + private TransportAddress address(String host, int port) throws UnknownHostException { return new TransportAddress(getByName(host), port); } diff --git a/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java new file mode 100644 index 0000000000000..bc499ed8a420a --- /dev/null +++ b/server/src/test/java/org/elasticsearch/http/DefaultRestChannelTests.java @@ -0,0 +1,444 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.elasticsearch.http; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput; +import org.elasticsearch.common.lease.Releasable; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.MockBigArrays; +import org.elasticsearch.common.util.MockPageCacheRecycler; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.indices.breaker.NoneCircuitBreakerService; +import org.elasticsearch.rest.BytesRestResponse; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.junit.After; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.io.IOException; +import java.nio.channels.ClosedChannelException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasItem; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class DefaultRestChannelTests extends ESTestCase { + + private ThreadPool threadPool; + private MockBigArrays bigArrays; + private HttpChannel httpChannel; + + @Before + public void setup() { + httpChannel = mock(HttpChannel.class); + threadPool = new TestThreadPool("test"); + bigArrays = new MockBigArrays(new MockPageCacheRecycler(Settings.EMPTY), new NoneCircuitBreakerService()); + } + + @After + public void shutdown() { + if (threadPool != null) { + threadPool.shutdownNow(); + } + } + + public void testResponse() { + final TestResponse response = executeRequest(Settings.EMPTY, "request-host"); + assertThat(response.content(), equalTo(new TestRestResponse().content())); + } + + // TODO: Enable these Cors tests when the Cors logic lives in :server + +// public void testCorsEnabledWithoutAllowOrigins() { +// // Set up a HTTP transport with only the CORS enabled setting +// Settings settings = Settings.builder() +// .put(HttpTransportSettings.SETTING_CORS_ENABLED.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, "remote-host", "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), nullValue()); +// } +// +// public void testCorsEnabledWithAllowOrigins() { +// final String originValue = "remote-host"; +// // create a http transport with CORS enabled and allow origin configured +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// } +// +// public void testCorsAllowOriginWithSameHost() { +// String originValue = "remote-host"; +// String host = "remote-host"; +// // create a http transport with CORS enabled +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, host); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = "http://" + originValue; +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = originValue + ":5555"; +// host = host + ":5555"; +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// +// originValue = originValue.replace("http", "https"); +// response = executeRequest(settings, originValue, host); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// } +// +// public void testThatStringLiteralWorksOnMatch() { +// final String originValue = "remote-host"; +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .put(SETTING_CORS_ALLOW_METHODS.getKey(), "get, options, post") +// .put(SETTING_CORS_ALLOW_CREDENTIALS.getKey(), true) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); +// } +// +// public void testThatAnyOriginWorks() { +// final String originValue = NioCorsHandler.ANY_ORIGIN; +// Settings settings = Settings.builder() +// .put(SETTING_CORS_ENABLED.getKey(), true) +// .put(SETTING_CORS_ALLOW_ORIGIN.getKey(), originValue) +// .build(); +// HttpResponse response = executeRequest(settings, originValue, "request-host"); +// // inspect response and validate +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN), notNullValue()); +// String allowedOrigins = response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN); +// assertThat(allowedOrigins, is(originValue)); +// assertThat(response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS), nullValue()); +// } + + public void testHeadersSet() { + Settings settings = Settings.builder().build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc")); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + // send a response + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + TestRestResponse resp = new TestRestResponse(); + final String customHeader = "custom-header"; + final String customHeaderValue = "xyz"; + resp.addHeader(customHeader, customHeaderValue); + channel.sendResponse(resp); + + // inspect what was written + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel).sendResponse(responseCaptor.capture(), any()); + TestResponse httpResponse = responseCaptor.getValue(); + Map> headers = httpResponse.headers; + assertNull(headers.get("non-existent-header")); + assertEquals(customHeaderValue, headers.get(customHeader).get(0)); + assertEquals("abc", headers.get(DefaultRestChannel.X_OPAQUE_ID).get(0)); + assertEquals(Integer.toString(resp.content().length()), headers.get(DefaultRestChannel.CONTENT_LENGTH).get(0)); + assertEquals(resp.contentType(), headers.get(DefaultRestChannel.CONTENT_TYPE).get(0)); + } + + public void testCookiesSet() { + Settings settings = Settings.builder().put(HttpTransportSettings.SETTING_HTTP_RESET_COOKIES.getKey(), true).build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + httpRequest.getHeaders().put(DefaultRestChannel.X_OPAQUE_ID, Collections.singletonList("abc")); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + // send a response + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + + // inspect what was written + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel).sendResponse(responseCaptor.capture(), any()); + TestResponse nioResponse = responseCaptor.getValue(); + Map> headers = nioResponse.headers; + assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie")); + assertThat(headers.get(DefaultRestChannel.SET_COOKIE), hasItem("cookie2")); + } + + @SuppressWarnings("unchecked") + public void testReleaseInListener() throws IOException { + final Settings settings = Settings.builder().build(); + final TestRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + final BytesRestResponse response = new BytesRestResponse(RestStatus.INTERNAL_SERVER_ERROR, + JsonXContent.contentBuilder().startObject().endObject()); + assertThat(response.content(), not(instanceOf(Releasable.class))); + + // ensure we have reserved bytes + if (randomBoolean()) { + BytesStreamOutput out = channel.bytesOutput(); + assertThat(out, instanceOf(ReleasableBytesStreamOutput.class)); + } else { + try (XContentBuilder builder = channel.newBuilder()) { + // do something builder + builder.startObject().endObject(); + } + } + + channel.sendResponse(response); + Class> listenerClass = (Class>) (Class) ActionListener.class; + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); + verify(httpChannel).sendResponse(any(), listenerCaptor.capture()); + ActionListener listener = listenerCaptor.getValue(); + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new ClosedChannelException()); + } + // ESTestCase#after will invoke ensureAllArraysAreReleased which will fail if the response content was not released + } + + @SuppressWarnings("unchecked") + public void testConnectionClose() throws Exception { + final Settings settings = Settings.builder().build(); + final HttpRequest httpRequest; + final boolean close = randomBoolean(); + if (randomBoolean()) { + httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + if (close) { + httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.CLOSE)); + } + } else { + httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_0, RestRequest.Method.GET, "/"); + if (!close) { + httpRequest.getHeaders().put(DefaultRestChannel.CONNECTION, Collections.singletonList(DefaultRestChannel.KEEP_ALIVE)); + } + } + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + + HttpHandlingSettings handlingSettings = HttpHandlingSettings.fromSettings(settings); + + DefaultRestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, handlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + Class> listenerClass = (Class>) (Class) ActionListener.class; + ArgumentCaptor> listenerCaptor = ArgumentCaptor.forClass(listenerClass); + verify(httpChannel).sendResponse(any(), listenerCaptor.capture()); + ActionListener listener = listenerCaptor.getValue(); + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new ClosedChannelException()); + } + if (close) { + verify(httpChannel, times(1)).close(); + } else { + verify(httpChannel, times(0)).close(); + } + } + + private TestResponse executeRequest(final Settings settings, final String host) { + return executeRequest(settings, null, host); + } + + private TestResponse executeRequest(final Settings settings, final String originValue, final String host) { + HttpRequest httpRequest = new TestRequest(HttpRequest.HttpVersion.HTTP_1_1, RestRequest.Method.GET, "/"); + // TODO: These exist for the Cors tests +// if (originValue != null) { +// httpRequest.headers().add(HttpHeaderNames.ORIGIN, originValue); +// } +// httpRequest.headers().add(HttpHeaderNames.HOST, host); + final RestRequest request = RestRequest.request(xContentRegistry(), httpRequest, httpChannel); + + HttpHandlingSettings httpHandlingSettings = HttpHandlingSettings.fromSettings(settings); + RestChannel channel = new DefaultRestChannel(httpChannel, httpRequest, request, bigArrays, httpHandlingSettings, + threadPool.getThreadContext()); + channel.sendResponse(new TestRestResponse()); + + // get the response + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(TestResponse.class); + verify(httpChannel, atLeastOnce()).sendResponse(responseCaptor.capture(), any()); + return responseCaptor.getValue(); + } + + private static class TestRequest implements HttpRequest { + + private final HttpVersion version; + private final RestRequest.Method method; + private final String uri; + private HashMap> headers = new HashMap<>(); + + private TestRequest(HttpVersion version, RestRequest.Method method, String uri) { + + this.version = version; + this.method = method; + this.uri = uri; + } + + @Override + public RestRequest.Method method() { + return method; + } + + @Override + public String uri() { + return uri; + } + + @Override + public BytesReference content() { + return BytesArray.EMPTY; + } + + @Override + public Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + return Arrays.asList("cookie", "cookie2"); + } + + @Override + public HttpVersion protocolVersion() { + return version; + } + + @Override + public HttpRequest removeHeader(String header) { + throw new UnsupportedOperationException("Do not support removing header on test request."); + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + return new TestResponse(status, content); + } + } + + private static class TestResponse implements HttpResponse { + + private final RestStatus status; + private final BytesReference content; + private final Map> headers = new HashMap<>(); + + TestResponse(RestStatus status, BytesReference content) { + this.status = status; + this.content = content; + } + + public String contentType() { + return "text"; + } + + public BytesReference content() { + return content; + } + + public RestStatus status() { + return status; + } + + @Override + public void addHeader(String name, String value) { + if (headers.containsKey(name) == false) { + ArrayList values = new ArrayList<>(); + values.add(value); + headers.put(name, values); + } else { + headers.get(name).add(value); + } + } + + @Override + public boolean containsHeader(String name) { + return headers.containsKey(name); + } + } + + private static class TestRestResponse extends RestResponse { + + private final BytesReference content; + + TestRestResponse() { + content = new BytesArray("content".getBytes(StandardCharsets.UTF_8)); + } + + public String contentType() { + return "text"; + } + + public BytesReference content() { + return content; + } + + public RestStatus status() { + return RestStatus.OK; + } + } +} diff --git a/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java b/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java index a0e6f7020302d..a80c3b1bd4238 100644 --- a/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java +++ b/server/src/test/java/org/elasticsearch/rest/BytesRestResponseTests.java @@ -29,7 +29,6 @@ import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.transport.TransportAddress; -import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentParser; import org.elasticsearch.common.xcontent.XContentType; @@ -165,28 +164,7 @@ public void testConvert() throws IOException { public void testResponseWhenPathContainsEncodingError() throws IOException { final String path = "%a"; - final RestRequest request = - new RestRequest(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, Collections.emptyMap()) { - @Override - public Method method() { - return null; - } - - @Override - public String uri() { - return null; - } - - @Override - public boolean hasContent() { - return false; - } - - @Override - public BytesReference content() { - return null; - } - }; + final RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withPath(path).build(); final IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestUtils.decodeComponent(request.rawPath())); final RestChannel channel = new DetailedExceptionRestChannel(request); // if we try to decode the path, this will throw an IllegalArgumentException again diff --git a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java index f36638a43909f..a090cc40b6857 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestControllerTests.java @@ -110,21 +110,21 @@ public void testApplyRelevantHeaders() throws Exception { RestRequest fakeRequest = new FakeRestRequest.Builder(xContentRegistry()).withHeaders(restHeaders).build(); final RestController spyRestController = spy(restController); when(spyRestController.getAllHandlers(fakeRequest)) - .thenReturn(new Iterator() { - @Override - public boolean hasNext() { - return false; - } - - @Override - public MethodHandlers next() { - return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> { - assertEquals("true", threadContext.getHeader("header.1")); - assertEquals("true", threadContext.getHeader("header.2")); - assertNull(threadContext.getHeader("header.3")); - }, RestRequest.Method.GET); - } - }); + .thenReturn(new Iterator() { + @Override + public boolean hasNext() { + return false; + } + + @Override + public MethodHandlers next() { + return new MethodHandlers("/", (RestRequest request, RestChannel channel, NodeClient client) -> { + assertEquals("true", threadContext.getHeader("header.1")); + assertEquals("true", threadContext.getHeader("header.2")); + assertNull(threadContext.getHeader("header.3")); + }, RestRequest.Method.GET); + } + }); AssertingChannel channel = new AssertingChannel(fakeRequest, false, RestStatus.BAD_REQUEST); restController.dispatchRequest(fakeRequest, channel, threadContext); // the rest controller relies on the caller to stash the context, so we should expect these values here as we didn't stash the @@ -136,7 +136,7 @@ public MethodHandlers next() { public void testCanTripCircuitBreaker() throws Exception { RestController controller = new RestController(Settings.EMPTY, Collections.emptySet(), null, null, circuitBreakerService, - usageService); + usageService); // trip circuit breaker by default controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true)); controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false)); @@ -209,7 +209,7 @@ public void testRestHandlerWrapper() throws Exception { return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true); }; final RestController restController = new RestController(Settings.EMPTY, Collections.emptySet(), wrapper, null, - circuitBreakerService, usageService); + circuitBreakerService, usageService); final ThreadContext threadContext = new ThreadContext(Settings.EMPTY); restController.dispatchRequest(new FakeRestRequest.Builder(xContentRegistry()).build(), null, null, Optional.of(handler)); assertTrue(wrapperCalled.get()); @@ -240,7 +240,7 @@ public boolean canTripCircuitBreaker() { public void testDispatchRequestAddsAndFreesBytesOnSuccess() { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON); + RestRequest request = testRestRequest("/", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.OK); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -252,7 +252,7 @@ public void testDispatchRequestAddsAndFreesBytesOnSuccess() { public void testDispatchRequestAddsAndFreesBytesOnError() { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON); + RestRequest request = testRestRequest("/error", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.BAD_REQUEST); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -265,7 +265,7 @@ public void testDispatchRequestAddsAndFreesBytesOnlyOnceOnError() { int contentLength = BREAKER_LIMIT.bytesAsInt(); String content = randomAlphaOfLength(contentLength); // we will produce an error in the rest handler and one more when sending the error response - TestRestRequest request = new TestRestRequest("/error", content, XContentType.JSON); + RestRequest request = testRestRequest("/error", content, XContentType.JSON); ExceptionThrowingChannel channel = new ExceptionThrowingChannel(request, true); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -277,7 +277,7 @@ public void testDispatchRequestAddsAndFreesBytesOnlyOnceOnError() { public void testDispatchRequestLimitsBytes() { int contentLength = BREAKER_LIMIT.bytesAsInt() + 1; String content = randomAlphaOfLength(contentLength); - TestRestRequest request = new TestRestRequest("/", content, XContentType.JSON); + RestRequest request = testRestRequest("/", content, XContentType.JSON); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.SERVICE_UNAVAILABLE); restController.dispatchRequest(request, channel, new ThreadContext(Settings.EMPTY)); @@ -288,11 +288,11 @@ public void testDispatchRequestLimitsBytes() { public void testDispatchRequiresContentTypeForRequestsWithContent() { String content = randomAlphaOfLengthBetween(1, BREAKER_LIMIT.bytesAsInt()); - TestRestRequest request = new TestRestRequest("/", content, null); + RestRequest request = testRestRequest("/", content, null); AssertingChannel channel = new AssertingChannel(request, true, RestStatus.NOT_ACCEPTABLE); restController = new RestController( Settings.builder().put(HttpTransportSettings.SETTING_HTTP_CONTENT_TYPE_REQUIRED.getKey(), true).build(), - Collections.emptySet(), null, null, circuitBreakerService, usageService); + Collections.emptySet(), null, null, circuitBreakerService, usageService); restController.registerHandler(RestRequest.Method.GET, "/", (r, c, client) -> c.sendResponse( new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY))); @@ -412,8 +412,8 @@ public boolean supportsContentStream() { public void testNonStreamingXContentCausesErrorResponse() throws IOException { FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY) - .withContent(BytesReference.bytes(YamlXContent.contentBuilder().startObject().endObject()), - XContentType.YAML).withPath("/foo").build(); + .withContent(BytesReference.bytes(YamlXContent.contentBuilder().startObject().endObject()), + XContentType.YAML).withPath("/foo").build(); AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.NOT_ACCEPTABLE); restController.registerHandler(RestRequest.Method.GET, "/foo", new RestHandler() { @Override @@ -457,10 +457,10 @@ public void testDispatchBadRequest() { final FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).build(); final AssertingChannel channel = new AssertingChannel(fakeRestRequest, true, RestStatus.BAD_REQUEST); restController.dispatchBadRequest( - fakeRestRequest, - channel, - new ThreadContext(Settings.EMPTY), - randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request")); + fakeRestRequest, + channel, + new ThreadContext(Settings.EMPTY), + randomBoolean() ? new IllegalStateException("bad request") : new Throwable("bad request")); assertTrue(channel.getSendResponseCalled()); assertThat(channel.getRestResponse().content().utf8ToString(), containsString("bad request")); } @@ -495,7 +495,7 @@ protected void doClose() { @Override public BoundTransportAddress boundAddress() { TransportAddress transportAddress = buildNewFakeTransportAddress(); - return new BoundTransportAddress(new TransportAddress[] {transportAddress} ,transportAddress); + return new BoundTransportAddress(new TransportAddress[]{transportAddress}, transportAddress); } @Override @@ -547,35 +547,11 @@ public void sendResponse(RestResponse response) { } } - private static final class TestRestRequest extends RestRequest { - - private final BytesReference content; - - private TestRestRequest(String path, String content, XContentType xContentType) { - super(NamedXContentRegistry.EMPTY, Collections.emptyMap(), path, xContentType == null ? - Collections.emptyMap() : Collections.singletonMap("Content-Type", Collections.singletonList(xContentType.mediaType()))); - this.content = new BytesArray(content); - } - - @Override - public Method method() { - return Method.GET; - } - - @Override - public String uri() { - return null; - } - - @Override - public boolean hasContent() { - return true; - } - - @Override - public BytesReference content() { - return content; - } - + private static RestRequest testRestRequest(String path, String content, XContentType xContentType) { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withPath(path); + builder.withContent(new BytesArray(content), xContentType); + return builder.build(); } } + diff --git a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java index 1b4bbff7322de..3ad9c61de3c8e 100644 --- a/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java +++ b/server/src/test/java/org/elasticsearch/rest/RestRequestTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.rest.FakeRestRequest; import java.io.IOException; import java.util.ArrayList; @@ -44,66 +45,66 @@ public class RestRequestTests extends ESTestCase { public void testContentParser() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentParser()); + contentRestRequest("", emptyMap()).contentParser()); assertEquals("request body is required", e.getMessage()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", singletonMap("source", "{}")).contentParser()); + contentRestRequest("", singletonMap("source", "{}")).contentParser()); assertEquals("request body is required", e.getMessage()); - assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentParser().map()); + assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentParser().map()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap(), emptyMap()).contentParser()); + contentRestRequest("", emptyMap(), emptyMap()).contentParser()); assertEquals("request body is required", e.getMessage()); } public void testApplyContentParser() throws IOException { - new ContentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called")); - new ContentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called")); + contentRestRequest("", emptyMap()).applyContentParser(p -> fail("Shouldn't have been called")); + contentRestRequest("", singletonMap("source", "{}")).applyContentParser(p -> fail("Shouldn't have been called")); AtomicReference source = new AtomicReference<>(); - new ContentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map())); + contentRestRequest("{}", emptyMap()).applyContentParser(p -> source.set(p.map())); assertEquals(emptyMap(), source.get()); } public void testContentOrSourceParam() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentOrSourceParam()); + contentRestRequest("", emptyMap()).contentOrSourceParam()); assertEquals("request body or source parameter is required", e.getMessage()); - assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2()); + assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).contentOrSourceParam().v2()); assertEquals(new BytesArray("stuff"), - new ContentRestRequest("stuff", MapBuilder.newMapBuilder() + contentRestRequest("stuff", MapBuilder.newMapBuilder() .put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).contentOrSourceParam().v2()); assertEquals(new BytesArray("{\"foo\": \"stuff\"}"), - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap()) .contentOrSourceParam().v2()); e = expectThrows(IllegalStateException.class, () -> - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "stuff2").immutableMap()).contentOrSourceParam()); assertEquals("source and source_content_type parameters are required", e.getMessage()); } public void testHasContentOrSourceParam() throws IOException { - assertEquals(false, new ContentRestRequest("", emptyMap()).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("stuff", emptyMap()).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam()); - assertEquals(true, new ContentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam()); + assertEquals(false, contentRestRequest("", emptyMap()).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("stuff", emptyMap()).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("stuff", singletonMap("source", "stuff2")).hasContentOrSourceParam()); + assertEquals(true, contentRestRequest("", singletonMap("source", "stuff")).hasContentOrSourceParam()); } public void testContentOrSourceParamParser() throws IOException { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).contentOrSourceParamParser()); + contentRestRequest("", emptyMap()).contentOrSourceParamParser()); assertEquals("request body or source parameter is required", e.getMessage()); - assertEquals(emptyMap(), new ContentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map()); - assertEquals(emptyMap(), new ContentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map()); - assertEquals(emptyMap(), new ContentRestRequest("", MapBuilder.newMapBuilder() + assertEquals(emptyMap(), contentRestRequest("{}", emptyMap()).contentOrSourceParamParser().map()); + assertEquals(emptyMap(), contentRestRequest("{}", singletonMap("source", "stuff2")).contentOrSourceParamParser().map()); + assertEquals(emptyMap(), contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{}").put("source_content_type", "application/json").immutableMap()).contentOrSourceParamParser().map()); } public void testWithContentOrSourceParamParserOrNull() throws IOException { - new ContentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser)); - new ContentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); - new ContentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser -> + contentRestRequest("", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertNull(parser)); + contentRestRequest("{}", emptyMap()).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); + contentRestRequest("{}", singletonMap("source", "stuff2")).withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); - new ContentRestRequest("", MapBuilder.newMapBuilder().put("source_content_type", "application/json") + contentRestRequest("", MapBuilder.newMapBuilder().put("source_content_type", "application/json") .put("source", "{}").immutableMap()) .withContentOrSourceParamParserOrNull(parser -> assertEquals(emptyMap(), parser.map())); @@ -113,18 +114,18 @@ public void testContentTypeParsing() { for (XContentType xContentType : XContentType.values()) { Map> map = new HashMap<>(); map.put("Content-Type", Collections.singletonList(xContentType.mediaType())); - ContentRestRequest restRequest = new ContentRestRequest("", Collections.emptyMap(), map); + RestRequest restRequest = contentRestRequest("", Collections.emptyMap(), map); assertEquals(xContentType, restRequest.getXContentType()); map = new HashMap<>(); map.put("Content-Type", Collections.singletonList(xContentType.mediaTypeWithoutParameters())); - restRequest = new ContentRestRequest("", Collections.emptyMap(), map); + restRequest = contentRestRequest("", Collections.emptyMap(), map); assertEquals(xContentType, restRequest.getXContentType()); } } public void testPlainTextSupport() { - ContentRestRequest restRequest = new ContentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(), + RestRequest restRequest = contentRestRequest(randomAlphaOfLengthBetween(1, 30), Collections.emptyMap(), Collections.singletonMap("Content-Type", Collections.singletonList(randomFrom("text/plain", "text/plain; charset=utf-8", "text/plain;charset=utf-8")))); assertNull(restRequest.getXContentType()); @@ -136,7 +137,7 @@ public void testMalformedContentTypeHeader() { RestRequest.ContentTypeHeaderException.class, () -> { final Map> headers = Collections.singletonMap("Content-Type", Collections.singletonList(type)); - new ContentRestRequest("", Collections.emptyMap(), headers); + contentRestRequest("", Collections.emptyMap(), headers); }); assertNotNull(e.getCause()); assertThat(e.getCause(), instanceOf(IllegalArgumentException.class)); @@ -144,7 +145,7 @@ public void testMalformedContentTypeHeader() { } public void testNoContentTypeHeader() { - ContentRestRequest contentRestRequest = new ContentRestRequest("", Collections.emptyMap(), Collections.emptyMap()); + RestRequest contentRestRequest = contentRestRequest("", Collections.emptyMap(), Collections.emptyMap()); assertNull(contentRestRequest.getXContentType()); } @@ -152,7 +153,7 @@ public void testMultipleContentTypeHeaders() { List headers = new ArrayList<>(randomUnique(() -> randomAlphaOfLengthBetween(1, 16), randomIntBetween(2, 10))); final RestRequest.ContentTypeHeaderException e = expectThrows( RestRequest.ContentTypeHeaderException.class, - () -> new ContentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers))); + () -> contentRestRequest("", Collections.emptyMap(), Collections.singletonMap("Content-Type", headers))); assertNotNull(e.getCause()); assertThat(e.getCause(), instanceOf((IllegalArgumentException.class))); assertThat(e.getMessage(), equalTo("java.lang.IllegalArgumentException: only one Content-Type header should be provided")); @@ -160,52 +161,64 @@ public void testMultipleContentTypeHeaders() { public void testRequiredContent() { Exception e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", emptyMap()).requiredContent()); + contentRestRequest("", emptyMap()).requiredContent()); assertEquals("request body is required", e.getMessage()); - assertEquals(new BytesArray("stuff"), new ContentRestRequest("stuff", emptyMap()).requiredContent()); + assertEquals(new BytesArray("stuff"), contentRestRequest("stuff", emptyMap()).requiredContent()); assertEquals(new BytesArray("stuff"), - new ContentRestRequest("stuff", MapBuilder.newMapBuilder() + contentRestRequest("stuff", MapBuilder.newMapBuilder() .put("source", "stuff2").put("source_content_type", "application/json").immutableMap()).requiredContent()); e = expectThrows(ElasticsearchParseException.class, () -> - new ContentRestRequest("", MapBuilder.newMapBuilder() + contentRestRequest("", MapBuilder.newMapBuilder() .put("source", "{\"foo\": \"stuff\"}").put("source_content_type", "application/json").immutableMap()) .requiredContent()); assertEquals("request body is required", e.getMessage()); e = expectThrows(IllegalStateException.class, () -> - new ContentRestRequest("test", null, Collections.emptyMap()).requiredContent()); + contentRestRequest("test", null, Collections.emptyMap()).requiredContent()); assertEquals("unknown content type", e.getMessage()); } + private static RestRequest contentRestRequest(String content, Map params) { + Map> headers = new HashMap<>(); + headers.put("Content-Type", Collections.singletonList("application/json")); + return contentRestRequest(content, params, headers); + } + + private static RestRequest contentRestRequest(String content, Map params, Map> headers) { + FakeRestRequest.Builder builder = new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY); + builder.withHeaders(headers); + builder.withContent(new BytesArray(content), null); + builder.withParams(params); + return new ContentRestRequest(builder.build()); + } + private static final class ContentRestRequest extends RestRequest { - private final BytesArray content; - ContentRestRequest(String content, Map params) { - this(content, params, Collections.singletonMap("Content-Type", Collections.singletonList("application/json"))); - } + private final RestRequest restRequest; - ContentRestRequest(String content, Map params, Map> headers) { - super(NamedXContentRegistry.EMPTY, params, "not used by this test", headers); - this.content = new BytesArray(content); + private ContentRestRequest(RestRequest restRequest) { + super(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders(), + restRequest.getHttpRequest(), restRequest.getHttpChannel()); + this.restRequest = restRequest; } @Override - public boolean hasContent() { - return Strings.hasLength(content); + public Method method() { + return restRequest.method(); } @Override - public BytesReference content() { - return content; + public String uri() { + return restRequest.uri(); } @Override - public String uri() { - throw new UnsupportedOperationException("Not used by this test"); + public boolean hasContent() { + return Strings.hasLength(content()); } @Override - public Method method() { - throw new UnsupportedOperationException("Not used by this test"); + public BytesReference content() { + return restRequest.content(); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java index d0403736400cd..4d4743156c73d 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/FakeRestRequest.java @@ -19,12 +19,18 @@ package org.elasticsearch.test.rest; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.HttpRequest; +import org.elasticsearch.http.HttpResponse; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; -import java.net.SocketAddress; +import java.net.InetSocketAddress; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -32,45 +38,115 @@ public class FakeRestRequest extends RestRequest { - private final BytesReference content; - private final Method method; - private final SocketAddress remoteAddress; - public FakeRestRequest() { - this(NamedXContentRegistry.EMPTY, new HashMap<>(), new HashMap<>(), null, Method.GET, "/", null); + this(NamedXContentRegistry.EMPTY, new FakeHttpRequest(Method.GET, "", BytesArray.EMPTY, new HashMap<>()), new HashMap<>(), + new FakeHttpChannel(null)); } - private FakeRestRequest(NamedXContentRegistry xContentRegistry, Map> headers, - Map params, BytesReference content, Method method, String path, SocketAddress remoteAddress) { - super(xContentRegistry, params, path, headers); - this.content = content; - this.method = method; - this.remoteAddress = remoteAddress; + private FakeRestRequest(NamedXContentRegistry xContentRegistry, HttpRequest httpRequest, Map params, + HttpChannel httpChannel) { + super(xContentRegistry, params, httpRequest.uri(), httpRequest.getHeaders(), httpRequest, httpChannel); } @Override - public Method method() { - return method; + public boolean hasContent() { + return content() != null; } - @Override - public String uri() { - return rawPath(); - } + private static class FakeHttpRequest implements HttpRequest { - @Override - public boolean hasContent() { - return content != null; - } + private final Method method; + private final String uri; + private final BytesReference content; + private final Map> headers; - @Override - public BytesReference content() { - return content; + private FakeHttpRequest(Method method, String uri, BytesReference content, Map> headers) { + this.method = method; + this.uri = uri; + this.content = content; + this.headers = headers; + } + + @Override + public Method method() { + return method; + } + + @Override + public String uri() { + return uri; + } + + @Override + public BytesReference content() { + return content; + } + + @Override + public Map> getHeaders() { + return headers; + } + + @Override + public List strictCookies() { + return Collections.emptyList(); + } + + @Override + public HttpVersion protocolVersion() { + return HttpVersion.HTTP_1_1; + } + + @Override + public HttpRequest removeHeader(String header) { + headers.remove(header); + return this; + } + + @Override + public HttpResponse createResponse(RestStatus status, BytesReference content) { + Map headers = new HashMap<>(); + return new HttpResponse() { + @Override + public void addHeader(String name, String value) { + headers.put(name, value); + } + + @Override + public boolean containsHeader(String name) { + return headers.containsKey(name); + } + }; + } } - @Override - public SocketAddress getRemoteAddress() { - return remoteAddress; + private static class FakeHttpChannel implements HttpChannel { + + private final InetSocketAddress remoteAddress; + + private FakeHttpChannel(InetSocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + } + + @Override + public void sendResponse(HttpResponse response, ActionListener listener) { + + } + + @Override + public InetSocketAddress getLocalAddress() { + return null; + } + + @Override + public InetSocketAddress getRemoteAddress() { + return remoteAddress; + } + + @Override + public void close() { + + } } public static class Builder { @@ -86,7 +162,7 @@ public static class Builder { private Method method = Method.GET; - private SocketAddress address = null; + private InetSocketAddress address = null; public Builder(NamedXContentRegistry xContentRegistry) { this.xContentRegistry = xContentRegistry; @@ -120,15 +196,14 @@ public Builder withMethod(Method method) { return this; } - public Builder withRemoteAddress(SocketAddress address) { + public Builder withRemoteAddress(InetSocketAddress address) { this.address = address; return this; } public FakeRestRequest build() { - return new FakeRestRequest(xContentRegistry, headers, params, content, method, path, address); + FakeHttpRequest fakeHttpRequest = new FakeHttpRequest(method, path, content, headers); + return new FakeRestRequest(xContentRegistry, fakeHttpRequest, params, new FakeHttpChannel(address)); } - } - } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java index aec5b3a04d255..71424ec507f52 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/rest/RestRequestFilter.java @@ -6,7 +6,6 @@ package org.elasticsearch.xpack.core.security.rest; import org.elasticsearch.ElasticsearchException; -import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.collect.Tuple; @@ -17,7 +16,6 @@ import org.elasticsearch.rest.RestRequest; import java.io.IOException; -import java.net.SocketAddress; import java.util.Map; import java.util.Set; @@ -33,37 +31,15 @@ public interface RestRequestFilter { default RestRequest getFilteredRequest(RestRequest restRequest) throws IOException { Set fields = getFilteredFields(); if (restRequest.hasContent() && fields.isEmpty() == false) { - return new RestRequest(restRequest.getXContentRegistry(), restRequest.params(), restRequest.path(), restRequest.getHeaders()) { + return new RestRequest(restRequest) { private BytesReference filteredBytes = null; - @Override - public Method method() { - return restRequest.method(); - } - - @Override - public String uri() { - return restRequest.uri(); - } - @Override public boolean hasContent() { return true; } - @Nullable - @Override - public SocketAddress getRemoteAddress() { - return restRequest.getRemoteAddress(); - } - - @Nullable - @Override - public SocketAddress getLocalAddress() { - return restRequest.getLocalAddress(); - } - @Override public BytesReference content() { if (filteredBytes == null) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java index 1976722d65f36..1991c2685f24e 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrail.java @@ -69,7 +69,6 @@ import java.io.IOException; import java.net.InetAddress; import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.net.UnknownHostException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; @@ -829,10 +828,9 @@ private Message message(String type, @Nullable String action, @Nullable Authenti msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); } msg.builder.field(Field.ORIGIN_TYPE, "rest"); - SocketAddress address = request.getRemoteAddress(); - if (address instanceof InetSocketAddress) { - msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress()) - .getAddress())); + InetSocketAddress address = request.getHttpChannel().getRemoteAddress(); + if (address != null) { + msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress())); } else { msg.builder.field(Field.ORIGIN_ADDRESS, address); } @@ -854,10 +852,9 @@ private Message message(String type, @Nullable Tuple realms, @Nul msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); } msg.builder.field(Field.ORIGIN_TYPE, "rest"); - SocketAddress address = request.getRemoteAddress(); - if (address instanceof InetSocketAddress) { - msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(((InetSocketAddress) request.getRemoteAddress()) - .getAddress())); + InetSocketAddress address = request.getHttpChannel().getRemoteAddress(); + if (address != null) { + msg.builder.field(Field.ORIGIN_ADDRESS, NetworkAddress.format(address.getAddress())); } else { msg.builder.field(Field.ORIGIN_ADDRESS, address); } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java index 3b9a42179a577..5706f79011ac5 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java @@ -38,7 +38,6 @@ import java.net.InetAddress; import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.util.Arrays; import java.util.Collections; import java.util.EnumSet; @@ -544,13 +543,8 @@ static String subject(Authentication authentication) { } private static String hostAttributes(RestRequest request) { - String formattedAddress; - final SocketAddress socketAddress = request.getRemoteAddress(); - if (socketAddress instanceof InetSocketAddress) { - formattedAddress = NetworkAddress.format(((InetSocketAddress) socketAddress).getAddress()); - } else { - formattedAddress = socketAddress.toString(); - } + final InetSocketAddress socketAddress = request.getHttpChannel().getRemoteAddress(); + String formattedAddress = NetworkAddress.format(socketAddress.getAddress()); return "origin_address=[" + formattedAddress + "]"; } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java index dcee6535cf337..ed50a5cfe84e7 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/RemoteHostHeader.java @@ -20,7 +20,7 @@ public class RemoteHostHeader { * then be copied to the subsequent action requests. */ public static void process(RestRequest request, ThreadContext threadContext) { - threadContext.putTransient(KEY, request.getRemoteAddress()); + threadContext.putTransient(KEY, request.getHttpChannel().getRemoteAddress()); } /** diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java index 0f4da8b847c58..9109bb37e8c41 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/SecurityRestFilter.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.security.rest; +import io.netty.channel.Channel; import io.netty.handler.ssl.SslHandler; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; @@ -13,7 +14,8 @@ import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.logging.ESLoggerFactory; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.http.netty4.Netty4HttpRequest; +import org.elasticsearch.http.HttpChannel; +import org.elasticsearch.http.netty4.Netty4HttpChannel; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestChannel; @@ -50,10 +52,11 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c if (licenseState.isSecurityEnabled() && licenseState.isAuthAllowed() && request.method() != Method.OPTIONS) { // CORS - allow for preflight unauthenticated OPTIONS request if (extractClientCertificate) { - Netty4HttpRequest nettyHttpRequest = (Netty4HttpRequest) request; - SslHandler handler = nettyHttpRequest.getChannel().pipeline().get(SslHandler.class); + HttpChannel httpChannel = request.getHttpChannel(); + Channel nettyChannel = ((Netty4HttpChannel) httpChannel).getNettyChannel(); + SslHandler handler = nettyChannel.pipeline().get(SslHandler.class); assert handler != null; - ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyHttpRequest.getChannel()); + ServerTransportFilter.extractClientCertificates(logger, threadContext, handler.engine(), nettyChannel); } service.authenticate(maybeWrapRestRequest(request), ActionListener.wrap( authentication -> { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java index 01916b9138031..ac586c4945794 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/netty4/SecurityNetty4HttpServerTransport.java @@ -104,7 +104,7 @@ public ChannelHandler configureServerChannelHandler() { private final class HttpSslChannelHandler extends HttpChannelHandler { HttpSslChannelHandler() { - super(SecurityNetty4HttpServerTransport.this, httpHandlingSettings, threadPool.getThreadContext()); + super(SecurityNetty4HttpServerTransport.this, handlingSettings); } @Override diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java index 7878fdb92336a..2e2a931f78f87 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/index/IndexAuditTrailTests.java @@ -33,6 +33,7 @@ import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.set.Sets; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.plugins.MetaDataUpgrader; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestRequest; @@ -914,7 +915,9 @@ public void clearCredentials() { private RestRequest mockRestRequest() { RestRequest request = mock(RestRequest.class); - when(request.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200)); + HttpChannel httpChannel = mock(HttpChannel.class); + when(request.getHttpChannel()).thenReturn(httpChannel); + when(httpChannel.getRemoteAddress()).thenReturn(new InetSocketAddress(InetAddress.getLoopbackAddress(), 9200)); when(request.uri()).thenReturn("_uri"); return request; } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java index 335673f1c0cbb..127784dcfc0db 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/RestRequestFilterTests.java @@ -88,6 +88,6 @@ public void testRemoteAddressWorks() throws IOException { new FakeRestRequest.Builder(NamedXContentRegistry.EMPTY).withContent(content, XContentType.JSON) .withRemoteAddress(address).build(); RestRequest filtered = filter.getFilteredRequest(restRequest); - assertEquals(address, filtered.getRemoteAddress()); + assertEquals(address, filtered.getHttpChannel().getRemoteAddress()); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java index 2857aee9b61ad..5db634c8d7be9 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/SecurityRestFilterTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.xcontent.NamedXContentRegistry; import org.elasticsearch.common.xcontent.XContentType; import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.http.HttpChannel; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.BytesRestResponse; import org.elasticsearch.rest.RestChannel; @@ -67,6 +68,7 @@ public void init() throws Exception { public void testProcess() throws Exception { RestRequest request = mock(RestRequest.class); + when(request.getHttpChannel()).thenReturn(mock(HttpChannel.class)); Authentication authentication = mock(Authentication.class); doAnswer((i) -> { ActionListener callback =