Skip to content

Commit

Permalink
Merge #3370 into 2.0.0-M4
Browse files Browse the repository at this point in the history
  • Loading branch information
violetagg committed Jul 26, 2024
2 parents 32bf360 + a30aa84 commit f2f410d
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2023 VMware, Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2018-2024 VMware, Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -134,12 +134,13 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
readTimeout,
requestTimeout,
secured,
timestamp);
timestamp,
true);
}
catch (RuntimeException e) {
pendingResponse = false;
request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
HttpServerOperations.sendDecodingFailures(ctx, listener, secured, e, msg, httpMessageLogFactory, true, timestamp, connectionInfo, remoteAddress);
HttpServerOperations.sendDecodingFailures(ctx, listener, secured, e, msg, httpMessageLogFactory, true, timestamp, connectionInfo, remoteAddress, true);
return;
}
ops.bind();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023 VMware, Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2020-2024 VMware, Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -636,7 +636,7 @@ static void configureHttp11OrH2CleartextPipeline(ChannelPipeline p,
NettyPipeline.HttpTrafficHandler,
new HttpTrafficHandler(compressPredicate, formDecoderProvider,
forwardedHeaderHandler, httpMessageLogFactory, idleTimeout, listener, mapHandle, maxKeepAliveRequests,
readTimeout, requestTimeout));
readTimeout, requestTimeout, decoder.validateHeaders()));

if (accessLogEnabled) {
p.addAfter(NettyPipeline.HttpTrafficHandler, NettyPipeline.AccessLogHandler, AccessLogHandlerFactory.H1.create(accessLog));
Expand Down Expand Up @@ -702,7 +702,7 @@ static void configureHttp11Pipeline(ChannelPipeline p,
NettyPipeline.HttpTrafficHandler,
new HttpTrafficHandler(compressPredicate, formDecoderProvider,
forwardedHeaderHandler, httpMessageLogFactory, idleTimeout, listener, mapHandle, maxKeepAliveRequests,
readTimeout, requestTimeout));
readTimeout, requestTimeout, decoder.validateHeaders()));

if (accessLogEnabled) {
p.addAfter(NettyPipeline.HttpTrafficHandler, NettyPipeline.AccessLogHandler, AccessLogHandlerFactory.H1.create(accessLog));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@

import static io.netty5.buffer.DefaultBufferAllocators.preferredAllocator;
import static io.netty5.handler.codec.http.HttpUtil.isTransferEncodingChunked;
import static io.netty5.handler.codec.http.headers.DefaultHttpHeadersFactory.headersFactory;
import static io.netty5.handler.codec.http.headers.DefaultHttpHeadersFactory.trailersFactory;
import static reactor.netty5.ReactorNetty.format;
import static reactor.netty5.http.server.HttpServerFormDecoderProvider.DEFAULT_FORM_DECODER_SPEC;
Expand Down Expand Up @@ -126,6 +127,7 @@ class HttpServerOperations extends HttpOperations<HttpServerRequest, HttpServerR
final HttpHeaders responseHeaders;
final String scheme;
final ZonedDateTime timestamp;
final boolean validateHeaders;

BiPredicate<HttpServerRequest, HttpServerResponse> compressionPredicate;
boolean isWebsocket;
Expand Down Expand Up @@ -160,6 +162,7 @@ class HttpServerOperations extends HttpOperations<HttpServerRequest, HttpServerR
this.scheme = replaced.scheme;
this.timestamp = replaced.timestamp;
this.trailerHeadersConsumer = replaced.trailerHeadersConsumer;
this.validateHeaders = replaced.validateHeaders;
}

HttpServerOperations(Connection c, ConnectionObserver listener, HttpRequest nettyRequest,
Expand All @@ -172,7 +175,8 @@ class HttpServerOperations extends HttpOperations<HttpServerRequest, HttpServerR
@Nullable Duration readTimeout,
@Nullable Duration requestTimeout,
boolean secured,
ZonedDateTime timestamp) {
ZonedDateTime timestamp,
boolean validateHeaders) {
super(c, listener, httpMessageLogFactory);
this.compressionPredicate = compressionPredicate;
this.configuredCompressionPredicate = compressionPredicate;
Expand All @@ -184,13 +188,15 @@ class HttpServerOperations extends HttpOperations<HttpServerRequest, HttpServerR
this.isHttp2 = isHttp2;
this.mapHandle = mapHandle;
this.nettyRequest = nettyRequest;
this.nettyResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK);
this.nettyResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK,
headersFactory().withNameValidation(validateHeaders).withValueValidation(validateHeaders));
this.readTimeout = readTimeout;
this.requestTimeout = requestTimeout;
this.responseHeaders = nettyResponse.headers();
this.responseHeaders.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
this.scheme = secured ? "https" : "http";
this.timestamp = timestamp;
this.validateHeaders = validateHeaders;
}

@Override
Expand All @@ -212,7 +218,9 @@ public HttpServerOperations withConnection(Consumer<? super Connection> withConn
@Override
protected HttpMessage newFullBodyMessage(Buffer body) {
HttpResponse res =
new DefaultFullHttpResponse(version(), status(), body);
new DefaultFullHttpResponse(version(), status(), body,
headersFactory().withNameValidation(validateHeaders).withValueValidation(validateHeaders),
trailersFactory().withNameValidation(validateHeaders).withValueValidation(validateHeaders));

if (!HttpMethod.HEAD.equals(method())) {
responseHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING);
Expand Down Expand Up @@ -406,7 +414,9 @@ public Flux<?> receiveObject() {
if (!hasSentHeaders()) {
return channel().writeAndFlush(
new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE,
preferredAllocator().allocate(0)));
preferredAllocator().allocate(0),
headersFactory().withNameValidation(validateHeaders).withValueValidation(validateHeaders),
trailersFactory().withNameValidation(validateHeaders).withValueValidation(validateHeaders)));
}
return channel().newSucceededFuture();
})
Expand Down Expand Up @@ -960,7 +970,8 @@ else if (HttpUtil.getContentLength(nettyResponse, -1) != -1) {
responseHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING);
}

return new DefaultFullHttpResponse(version(), status(), body, responseHeaders, trailersFactory().newHeaders());
return new DefaultFullHttpResponse(version(), status(), body, responseHeaders,
trailersFactory().withNameValidation(validateHeaders).withValueValidation(validateHeaders).newHeaders());
}

static long requestsCounter(Channel channel) {
Expand All @@ -982,8 +993,9 @@ static void sendDecodingFailures(
HttpMessageLogFactory httpMessageLogFactory,
@Nullable ZonedDateTime timestamp,
@Nullable ConnectionInfo connectionInfo,
SocketAddress remoteAddress) {
sendDecodingFailures(ctx, listener, secure, t, msg, httpMessageLogFactory, false, timestamp, connectionInfo, remoteAddress);
SocketAddress remoteAddress,
boolean validateHeaders) {
sendDecodingFailures(ctx, listener, secure, t, msg, httpMessageLogFactory, false, timestamp, connectionInfo, remoteAddress, validateHeaders);
}

@SuppressWarnings("FutureReturnValueIgnored")
Expand All @@ -997,7 +1009,8 @@ static void sendDecodingFailures(
boolean isHttp2,
@Nullable ZonedDateTime timestamp,
@Nullable ConnectionInfo connectionInfo,
SocketAddress remoteAddress) {
SocketAddress remoteAddress,
boolean validateHeaders) {

Throwable cause = t.getCause() != null ? t.getCause() : t;

Expand All @@ -1020,7 +1033,9 @@ else if (cause instanceof TooLongHttpHeaderException) {
status = HttpResponseStatus.BAD_REQUEST;
}
FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status,
ctx.bufferAllocator().allocate(0));
ctx.bufferAllocator().allocate(0),
headersFactory().withNameValidation(validateHeaders).withValueValidation(validateHeaders),
trailersFactory().withNameValidation(validateHeaders).withValueValidation(validateHeaders));
response.headers()
.set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO)
.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
Expand All @@ -1031,7 +1046,7 @@ else if (cause instanceof TooLongHttpHeaderException) {
if (msg instanceof HttpRequest request) {
ops = new FailedHttpServerRequest(conn, listener, request, response, httpMessageLogFactory, isHttp2,
secure, timestamp == null ? ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM) : timestamp,
connectionInfo == null ? new ConnectionInfo(ctx.channel().localAddress(), remoteAddress, secure) : connectionInfo);
connectionInfo == null ? new ConnectionInfo(ctx.channel().localAddress(), remoteAddress, secure) : connectionInfo, validateHeaders);
ops.bind();
}
else {
Expand Down Expand Up @@ -1207,9 +1222,10 @@ static final class FailedHttpServerRequest extends HttpServerOperations {
boolean isHttp2,
boolean secure,
ZonedDateTime timestamp,
ConnectionInfo connectionInfo) {
ConnectionInfo connectionInfo,
boolean validateHeaders) {
super(c, listener, nettyRequest, null, connectionInfo,
DEFAULT_FORM_DECODER_SPEC, httpMessageLogFactory, isHttp2, null, null, null, secure, timestamp);
DEFAULT_FORM_DECODER_SPEC, httpMessageLogFactory, isHttp2, null, null, null, secure, timestamp, validateHeaders);
this.customResponse = nettyResponse;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ final class HttpTrafficHandler extends ChannelHandlerAdapter implements Runnable
final int maxKeepAliveRequests;
final Duration readTimeout;
final Duration requestTimeout;
final boolean validateHeaders;

ChannelHandlerContext ctx;

Expand Down Expand Up @@ -111,7 +112,8 @@ final class HttpTrafficHandler extends ChannelHandlerAdapter implements Runnable
@Nullable BiFunction<? super Mono<Void>, ? super Connection, ? extends Mono<Void>> mapHandle,
int maxKeepAliveRequests,
@Nullable Duration readTimeout,
@Nullable Duration requestTimeout) {
@Nullable Duration requestTimeout,
boolean validateHeaders) {
this.listener = listener;
this.formDecoderProvider = formDecoderProvider;
this.forwardedHeaderHandler = forwardedHeaderHandler;
Expand All @@ -122,6 +124,7 @@ final class HttpTrafficHandler extends ChannelHandlerAdapter implements Runnable
this.maxKeepAliveRequests = maxKeepAliveRequests;
this.readTimeout = readTimeout;
this.requestTimeout = requestTimeout;
this.validateHeaders = validateHeaders;
}

@Override
Expand Down Expand Up @@ -177,7 +180,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
IllegalStateException e = new IllegalStateException(
"Unexpected request [" + request.method() + " " + request.uri() + " HTTP/2.0]");
request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
sendDecodingFailures(e, msg);
sendDecodingFailures(e, msg, validateHeaders);
return;
}

Expand Down Expand Up @@ -207,7 +210,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {

DecoderResult decoderResult = request.decoderResult();
if (decoderResult.isFailure()) {
sendDecodingFailures(decoderResult.cause(), msg);
sendDecodingFailures(decoderResult.cause(), msg, validateHeaders);
return;
}

Expand All @@ -232,11 +235,12 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
readTimeout,
requestTimeout,
secure,
timestamp);
timestamp,
validateHeaders);
}
catch (RuntimeException e) {
request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
sendDecodingFailures(e, msg, timestamp, connectionInfo);
sendDecodingFailures(e, msg, timestamp, connectionInfo, validateHeaders);
return;
}
ops.bind();
Expand All @@ -251,7 +255,7 @@ else if (persistentConnection && pendingResponses == 0) {
if (msg instanceof LastHttpContent<?> lastHttpContent) {
DecoderResult decoderResult = lastHttpContent.decoderResult();
if (decoderResult.isFailure()) {
sendDecodingFailures(decoderResult.cause(), msg);
sendDecodingFailures(decoderResult.cause(), msg, validateHeaders);
return;
}

Expand Down Expand Up @@ -283,7 +287,7 @@ else if (overflow) {
if (msg instanceof DecoderResultProvider decoderResultProvider) {
DecoderResult decoderResult = decoderResultProvider.decoderResult();
if (decoderResult.isFailure()) {
sendDecodingFailures(decoderResult.cause(), msg);
sendDecodingFailures(decoderResult.cause(), msg, validateHeaders);
return;
}
}
Expand Down Expand Up @@ -333,14 +337,14 @@ public void flush(ChannelHandlerContext ctx) {
}
}

void sendDecodingFailures(Throwable t, Object msg) {
sendDecodingFailures(t, msg, null, null);
void sendDecodingFailures(Throwable t, Object msg, boolean validateHeaders) {
sendDecodingFailures(t, msg, null, null, validateHeaders);
}

void sendDecodingFailures(Throwable t, Object msg, @Nullable ZonedDateTime timestamp, @Nullable ConnectionInfo connectionInfo) {
void sendDecodingFailures(Throwable t, Object msg, @Nullable ZonedDateTime timestamp, @Nullable ConnectionInfo connectionInfo, boolean validateHeaders) {
persistentConnection = false;
HttpServerOperations.sendDecodingFailures(ctx, listener, secure, t, msg, httpMessageLogFactory, timestamp, connectionInfo,
remoteAddress);
remoteAddress, validateHeaders);
}

void doPipeline(ChannelHandlerContext ctx, Object msg) {
Expand Down Expand Up @@ -477,7 +481,7 @@ public void run() {

DecoderResult decoderResult = nextRequest.decoderResult();
if (decoderResult.isFailure()) {
sendDecodingFailures(decoderResult.cause(), nextRequest, holder.timestamp, null);
sendDecodingFailures(decoderResult.cause(), nextRequest, holder.timestamp, null, validateHeaders);
discard();
return;
}
Expand All @@ -502,11 +506,12 @@ public void run() {
readTimeout,
requestTimeout,
secure,
holder.timestamp);
holder.timestamp,
validateHeaders);
}
catch (RuntimeException e) {
holder.request.setDecoderResult(DecoderResult.failure(e.getCause() != null ? e.getCause() : e));
sendDecodingFailures(e, holder.request, holder.timestamp, connectionInfo);
sendDecodingFailures(e, holder.request, holder.timestamp, connectionInfo, validateHeaders);
return;
}
ops.bind();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1995,7 +1995,8 @@ private void doTestStatus(HttpResponseStatus status) {
null,
null,
false,
ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM));
ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM),
true);
ops.status(status);
try (Buffer buffer = channel.bufferAllocator().allocate(0)) {
HttpMessage response = ops.newFullBodyMessage(buffer);
Expand Down Expand Up @@ -2993,7 +2994,8 @@ private void doTestIsFormUrlencoded(String headerValue, boolean expectation) {
null,
null,
false,
ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM));
ZonedDateTime.now(ReactorNetty.ZONE_ID_SYSTEM),
true);
assertThat(ops.isFormUrlencoded()).isEqualTo(expectation);
channel.close();
}
Expand Down

0 comments on commit f2f410d

Please sign in to comment.