Skip to content

Commit

Permalink
HttpDecoderSpec#validateHeaders() specifies whether request/response …
Browse files Browse the repository at this point in the history
…headers are validated (#3370)

This is related only to HTTP/1.1, for HTTP/2 and HTTP/3 headers are always validated.
By default the validation is enabled.
  • Loading branch information
violetagg committed Jul 26, 2024
1 parent 6d93917 commit a30aa84
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
readTimeout,
requestTimeout,
secured,
timestamp);
timestamp,
true);
}
catch (RuntimeException e) {
pendingResponse = false;
request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
HttpServerOperations.sendDecodingFailures(ctx, listener, secured, e, msg, httpMessageLogFactory, true, timestamp, connectionInfo, remoteAddress);
HttpServerOperations.sendDecodingFailures(ctx, listener, secured, e, msg, httpMessageLogFactory, true, timestamp, connectionInfo, remoteAddress, true);
return;
}
ops.bind();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ final class Http3ServerOperations extends HttpServerOperations {
boolean secured,
ZonedDateTime timestamp) {
super(c, listener, nettyRequest, compressionPredicate, connectionInfo, decoder, encoder, formDecoderProvider,
httpMessageLogFactory, isHttp2, mapHandle, readTimeout, requestTimeout, secured, timestamp);
httpMessageLogFactory, isHttp2, mapHandle, readTimeout, requestTimeout, secured, timestamp, true);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
catch (RuntimeException e) {
pendingResponse = false;
request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
HttpServerOperations.sendDecodingFailures(ctx, listener, true, e, msg, httpMessageLogFactory, true, timestamp, connectionInfo, remoteAddress);
HttpServerOperations.sendDecodingFailures(ctx, listener, true, e, msg, httpMessageLogFactory, true, timestamp, connectionInfo, remoteAddress, true);
return;
}
ops.bind();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,7 @@ static void configureHttp11OrH2CleartextPipeline(ChannelPipeline p,
NettyPipeline.HttpTrafficHandler,
new HttpTrafficHandler(compressPredicate, cookieDecoder, cookieEncoder, formDecoderProvider,
forwardedHeaderHandler, httpMessageLogFactory, idleTimeout, listener, mapHandle, maxKeepAliveRequests,
readTimeout, requestTimeout));
readTimeout, requestTimeout, decoder.validateHeaders()));

if (accessLogEnabled) {
p.addAfter(NettyPipeline.HttpTrafficHandler, NettyPipeline.AccessLogHandler, AccessLogHandlerFactory.H1.create(accessLog));
Expand Down Expand Up @@ -813,7 +813,7 @@ static void configureHttp11Pipeline(ChannelPipeline p,
NettyPipeline.HttpTrafficHandler,
new HttpTrafficHandler(compressPredicate, cookieDecoder, cookieEncoder, formDecoderProvider,
forwardedHeaderHandler, httpMessageLogFactory, idleTimeout, listener, mapHandle, maxKeepAliveRequests,
readTimeout, requestTimeout));
readTimeout, requestTimeout, decoder.validateHeaders()));

if (accessLogEnabled) {
p.addAfter(NettyPipeline.HttpTrafficHandler, NettyPipeline.AccessLogHandler, AccessLogHandlerFactory.H1.create(accessLog));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
import reactor.util.context.Context;

import static io.netty.buffer.Unpooled.EMPTY_BUFFER;
import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.headersFactory;
import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.trailersFactory;
import static io.netty.handler.codec.http.HttpUtil.isTransferEncodingChunked;
import static reactor.netty.ReactorNetty.format;
Expand Down Expand Up @@ -129,6 +130,7 @@ class HttpServerOperations extends HttpOperations<HttpServerRequest, HttpServerR
final HttpHeaders responseHeaders;
final String scheme;
final ZonedDateTime timestamp;
final boolean validateHeaders;

BiPredicate<HttpServerRequest, HttpServerResponse> compressionPredicate;
boolean isWebsocket;
Expand Down Expand Up @@ -165,6 +167,7 @@ class HttpServerOperations extends HttpOperations<HttpServerRequest, HttpServerR
this.scheme = replaced.scheme;
this.timestamp = replaced.timestamp;
this.trailerHeadersConsumer = replaced.trailerHeadersConsumer;
this.validateHeaders = replaced.validateHeaders;
}

HttpServerOperations(Connection c, ConnectionObserver listener, HttpRequest nettyRequest,
Expand All @@ -179,7 +182,8 @@ class HttpServerOperations extends HttpOperations<HttpServerRequest, HttpServerR
@Nullable Duration readTimeout,
@Nullable Duration requestTimeout,
boolean secured,
ZonedDateTime timestamp) {
ZonedDateTime timestamp,
boolean validateHeaders) {
super(c, listener, httpMessageLogFactory);
this.compressionPredicate = compressionPredicate;
this.configuredCompressionPredicate = compressionPredicate;
Expand All @@ -193,13 +197,14 @@ class HttpServerOperations extends HttpOperations<HttpServerRequest, HttpServerR
this.isHttp2 = isHttp2;
this.mapHandle = mapHandle;
this.nettyRequest = nettyRequest;
this.nettyResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
this.nettyResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, headersFactory().withValidation(validateHeaders));
this.readTimeout = readTimeout;
this.requestTimeout = requestTimeout;
this.responseHeaders = nettyResponse.headers();
this.responseHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
this.scheme = secured ? "https" : "http";
this.timestamp = timestamp;
this.validateHeaders = validateHeaders;
}

@Override
Expand All @@ -221,7 +226,8 @@ public HttpServerOperations withConnection(Consumer<? super Connection> withConn
@Override
protected HttpMessage newFullBodyMessage(ByteBuf body) {
HttpResponse res =
new DefaultFullHttpResponse(version(), status(), body);
new DefaultFullHttpResponse(version(), status(), body,
headersFactory().withValidation(validateHeaders), trailersFactory().withValidation(validateHeaders));

if (!HttpMethod.HEAD.equals(method())) {
responseHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING);
Expand Down Expand Up @@ -415,7 +421,8 @@ public Flux<?> receiveObject() {
return FutureMono.deferFuture(() -> {
if (!hasSentHeaders()) {
return channel().writeAndFlush(
new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, EMPTY_BUFFER));
new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, EMPTY_BUFFER,
headersFactory().withValidation(validateHeaders), trailersFactory().withValidation(validateHeaders)));
}
return channel().newSucceededFuture();
})
Expand Down Expand Up @@ -966,7 +973,7 @@ else if (contentLength != -1) {
responseHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING);
}

return new DefaultFullHttpResponse(version(), status(), body, responseHeaders, trailersFactory().newHeaders());
return new DefaultFullHttpResponse(version(), status(), body, responseHeaders, trailersFactory().withValidation(validateHeaders).newHeaders());
}

static long requestsCounter(Channel channel) {
Expand All @@ -988,8 +995,9 @@ static void sendDecodingFailures(
HttpMessageLogFactory httpMessageLogFactory,
@Nullable ZonedDateTime timestamp,
@Nullable ConnectionInfo connectionInfo,
SocketAddress remoteAddress) {
sendDecodingFailures(ctx, listener, secure, t, msg, httpMessageLogFactory, false, timestamp, connectionInfo, remoteAddress);
SocketAddress remoteAddress,
boolean validateHeaders) {
sendDecodingFailures(ctx, listener, secure, t, msg, httpMessageLogFactory, false, timestamp, connectionInfo, remoteAddress, validateHeaders);
}

@SuppressWarnings("FutureReturnValueIgnored")
Expand All @@ -1003,7 +1011,8 @@ static void sendDecodingFailures(
boolean isHttp2,
@Nullable ZonedDateTime timestamp,
@Nullable ConnectionInfo connectionInfo,
SocketAddress remoteAddress) {
SocketAddress remoteAddress,
boolean validateHeaders) {

Throwable cause = t.getCause() != null ? t.getCause() : t;

Expand All @@ -1025,7 +1034,8 @@ else if (cause instanceof TooLongHttpHeaderException) {
else {
status = HttpResponseStatus.BAD_REQUEST;
}
FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status);
FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, Unpooled.buffer(0),
headersFactory().withValidation(validateHeaders), trailersFactory().withValidation(validateHeaders));
response.headers()
.setInt(HttpHeaderNames.CONTENT_LENGTH, 0)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
Expand All @@ -1036,7 +1046,7 @@ else if (cause instanceof TooLongHttpHeaderException) {
if (msg instanceof HttpRequest) {
ops = new FailedHttpServerRequest(conn, listener, (HttpRequest) msg, response, httpMessageLogFactory, isHttp2,
secure, timestamp == null ? ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM) : timestamp,
connectionInfo == null ? new ConnectionInfo(ctx.channel().localAddress(), remoteAddress, secure) : connectionInfo);
connectionInfo == null ? new ConnectionInfo(ctx.channel().localAddress(), remoteAddress, secure) : connectionInfo, validateHeaders);
ops.bind();
}
else {
Expand Down Expand Up @@ -1213,10 +1223,11 @@ static final class FailedHttpServerRequest extends HttpServerOperations {
boolean isHttp2,
boolean secure,
ZonedDateTime timestamp,
ConnectionInfo connectionInfo) {
ConnectionInfo connectionInfo,
boolean validateHeaders) {
super(c, listener, nettyRequest, null, connectionInfo,
ServerCookieDecoder.STRICT, ServerCookieEncoder.STRICT, DEFAULT_FORM_DECODER_SPEC, httpMessageLogFactory, isHttp2,
null, null, null, secure, timestamp);
null, null, null, secure, timestamp, validateHeaders);
this.customResponse = nettyResponse;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ final class HttpTrafficHandler extends ChannelDuplexHandler implements Runnable
final int maxKeepAliveRequests;
final Duration readTimeout;
final Duration requestTimeout;
final boolean validateHeaders;

ChannelHandlerContext ctx;

Expand Down Expand Up @@ -116,7 +117,8 @@ final class HttpTrafficHandler extends ChannelDuplexHandler implements Runnable
@Nullable BiFunction<? super Mono<Void>, ? super Connection, ? extends Mono<Void>> mapHandle,
int maxKeepAliveRequests,
@Nullable Duration readTimeout,
@Nullable Duration requestTimeout) {
@Nullable Duration requestTimeout,
boolean validateHeaders) {
this.listener = listener;
this.formDecoderProvider = formDecoderProvider;
this.forwardedHeaderHandler = forwardedHeaderHandler;
Expand All @@ -129,6 +131,7 @@ final class HttpTrafficHandler extends ChannelDuplexHandler implements Runnable
this.maxKeepAliveRequests = maxKeepAliveRequests;
this.readTimeout = readTimeout;
this.requestTimeout = requestTimeout;
this.validateHeaders = validateHeaders;
}

@Override
Expand Down Expand Up @@ -173,7 +176,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
IllegalStateException e = new IllegalStateException(
"Unexpected request [" + request.method() + " " + request.uri() + " HTTP/2.0]");
request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
sendDecodingFailures(e, msg);
sendDecodingFailures(e, msg, validateHeaders);
return;
}

Expand Down Expand Up @@ -219,7 +222,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {

DecoderResult decoderResult = request.decoderResult();
if (decoderResult.isFailure()) {
sendDecodingFailures(decoderResult.cause(), msg);
sendDecodingFailures(decoderResult.cause(), msg, validateHeaders);
return;
}

Expand Down Expand Up @@ -247,11 +250,12 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
readTimeout,
requestTimeout,
secure,
timestamp);
timestamp,
validateHeaders);
}
catch (RuntimeException e) {
request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
sendDecodingFailures(e, msg, timestamp, connectionInfo);
sendDecodingFailures(e, msg, timestamp, connectionInfo, validateHeaders);
return;
}
ops.bind();
Expand All @@ -266,7 +270,7 @@ else if (persistentConnection && pendingResponses == 0) {
if (msg instanceof LastHttpContent) {
DecoderResult decoderResult = ((LastHttpContent) msg).decoderResult();
if (decoderResult.isFailure()) {
sendDecodingFailures(decoderResult.cause(), msg);
sendDecodingFailures(decoderResult.cause(), msg, validateHeaders);
return;
}

Expand Down Expand Up @@ -298,7 +302,7 @@ else if (overflow) {
if (msg instanceof DecoderResultProvider) {
DecoderResult decoderResult = ((DecoderResultProvider) msg).decoderResult();
if (decoderResult.isFailure()) {
sendDecodingFailures(decoderResult.cause(), msg);
sendDecodingFailures(decoderResult.cause(), msg, validateHeaders);
return;
}
}
Expand Down Expand Up @@ -338,14 +342,14 @@ public void flush(ChannelHandlerContext ctx) {
}
}

void sendDecodingFailures(Throwable t, Object msg) {
sendDecodingFailures(t, msg, null, null);
void sendDecodingFailures(Throwable t, Object msg, boolean validateHeaders) {
sendDecodingFailures(t, msg, null, null, validateHeaders);
}

void sendDecodingFailures(Throwable t, Object msg, @Nullable ZonedDateTime timestamp, @Nullable ConnectionInfo connectionInfo) {
void sendDecodingFailures(Throwable t, Object msg, @Nullable ZonedDateTime timestamp, @Nullable ConnectionInfo connectionInfo, boolean validateHeaders) {
persistentConnection = false;
HttpServerOperations.sendDecodingFailures(ctx, listener, secure, t, msg, httpMessageLogFactory, timestamp, connectionInfo,
remoteAddress);
remoteAddress, validateHeaders);
}

void doPipeline(ChannelHandlerContext ctx, Object msg) {
Expand Down Expand Up @@ -478,7 +482,7 @@ public void run() {

DecoderResult decoderResult = nextRequest.decoderResult();
if (decoderResult.isFailure()) {
sendDecodingFailures(decoderResult.cause(), nextRequest, holder.timestamp, null);
sendDecodingFailures(decoderResult.cause(), nextRequest, holder.timestamp, null, validateHeaders);
discard();
return;
}
Expand Down Expand Up @@ -506,11 +510,12 @@ public void run() {
readTimeout,
requestTimeout,
secure,
holder.timestamp);
holder.timestamp,
validateHeaders);
}
catch (RuntimeException e) {
holder.request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
sendDecodingFailures(e, holder.request, holder.timestamp, connectionInfo);
sendDecodingFailures(e, holder.request, holder.timestamp, connectionInfo, validateHeaders);
return;
}
ops.bind();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2133,7 +2133,8 @@ private void doTestStatus(HttpResponseStatus status) {
null,
null,
false,
ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM));
ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM),
true);
ops.status(status);
HttpMessage response = ops.newFullBodyMessage(Unpooled.EMPTY_BUFFER);
assertThat(((FullHttpResponse) response).status().reasonPhrase()).isEqualTo(status.reasonPhrase());
Expand Down Expand Up @@ -3136,7 +3137,8 @@ private void doTestIsFormUrlencoded(String headerValue, boolean expectation) {
null,
null,
false,
ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM));
ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM),
true);
assertThat(ops.isFormUrlencoded()).isEqualTo(expectation);
// "FutureReturnValueIgnored" is suppressed deliberately
channel.close();
Expand Down

0 comments on commit a30aa84

Please sign in to comment.