Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set RequestContext.isTimedOut(true) on DNS, session, write timeout #5156

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.linecorp.armeria.common.RequestId;
import com.linecorp.armeria.common.Response;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.TimeoutException;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.logging.RequestLog;
Expand Down Expand Up @@ -514,6 +515,21 @@ default void timeoutNow() {
cancel(ResponseTimeoutException.get());
}

/**
* Returns whether this {@link ClientRequestContext} has been timed-out, that is the cancellation cause
* is an instance of {@link TimeoutException} or
* {@link UnprocessedRequestException} and wrapped cause is {@link TimeoutException}.
*/
@Override
default boolean isTimedOut() {
injae-kim marked this conversation as resolved.
Show resolved Hide resolved
if (RequestContext.super.isTimedOut()) {
return true;
}
final Throwable cause = cancellationCause();
return cause instanceof TimeoutException ||
cause instanceof UnprocessedRequestException && cause.getCause() instanceof TimeoutException;
}

/**
* Returns the maximum length of the received {@link Response}.
* This value is initially set from {@link ClientOptions#MAX_RESPONSE_LENGTH}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ private static void handleEarlyRequestException(ClientRequestContext ctx,
final RequestLogBuilder logBuilder = ctx.logBuilder();
logBuilder.endRequest(cause);
logBuilder.endResponse(cause);
ctx.cancel(cause);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public final class SessionProtocolNegotiationException extends RuntimeException
* Creates a new instance with the specified expected {@link SessionProtocol}.
*/
public SessionProtocolNegotiationException(SessionProtocol expected, @Nullable String reason) {
super("expected: " + requireNonNull(expected, "expected") + ", reason: " + reason);
super(appendReason("expected: " + requireNonNull(expected, "expected"), reason));
this.expected = expected;
actual = null;
}
Expand All @@ -48,8 +48,8 @@ public SessionProtocolNegotiationException(SessionProtocol expected, @Nullable S
public SessionProtocolNegotiationException(SessionProtocol expected,
@Nullable SessionProtocol actual, @Nullable String reason) {

super("expected: " + requireNonNull(expected, "expected") +
", actual: " + requireNonNull(actual, "actual") + ", reason: " + reason);
super(appendReason("expected: " + requireNonNull(expected, "expected") +
", actual: " + requireNonNull(actual, "actual"), reason));
injae-kim marked this conversation as resolved.
Show resolved Hide resolved
this.expected = expected;
this.actual = actual;
}
Expand Down Expand Up @@ -78,4 +78,11 @@ public Throwable fillInStackTrace() {
}
return this;
}

private static String appendReason(String message, @Nullable String reason) {
if (reason == null) {
return message;
}
return message + ", reason: " + reason;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,21 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.util.function.Function;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import org.junit.jupiter.params.provider.ValueSource;

import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.TimeoutException;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.util.SafeCloseable;
import com.linecorp.armeria.server.ServiceRequestContext;
Expand Down Expand Up @@ -276,6 +282,24 @@ void updateRequestWithInvalidPath(String path) {
.hasMessageContaining("invalid path");
}

@ParameterizedTest
@ArgumentsSource(TimedOutExceptionProvider.class)
void isTimedOut_true(Throwable cause) {
final ClientRequestContext cctx = clientRequestContext();
cctx.cancel(cause);
cctx.whenResponseCancelled().join();
assertThat(cctx.isTimedOut()).isTrue();
}

@ParameterizedTest
@ArgumentsSource(NotTimedOutExceptionProvider.class)
void isTimedOut_false(Throwable cause) {
final ClientRequestContext cctx = clientRequestContext();
cctx.cancel(cause);
cctx.whenResponseCancelled().join();
assertThat(cctx.isTimedOut()).isFalse();
}

private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) {
final RequestContext current = RequestContext.currentOrNull();
if (current == null) {
Expand All @@ -292,4 +316,25 @@ private static ServiceRequestContext serviceRequestContext() {
private static ClientRequestContext clientRequestContext() {
return ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/"));
}

private static class TimedOutExceptionProvider implements ArgumentsProvider {

@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
return Stream.of(new TimeoutException(),
ResponseTimeoutException.get(),
UnprocessedRequestException.of(ResponseTimeoutException.get()))
.map(Arguments::of);
}
}

private static class NotTimedOutExceptionProvider implements ArgumentsProvider {

@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
return Stream.of(new RuntimeException(),
UnprocessedRequestException.of(new RuntimeException()))
.map(Arguments::of);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,39 @@

package com.linecorp.armeria.client;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.awaitility.Awaitility.await;

import java.util.concurrent.CompletionException;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import com.google.common.collect.ImmutableMap;

import com.linecorp.armeria.client.endpoint.dns.TestDnsServer;
import com.linecorp.armeria.common.CommonPools;
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.metric.PrometheusMeterRegistries;
import com.linecorp.armeria.server.AbstractHttpService;
import com.linecorp.armeria.server.Server;
import com.linecorp.armeria.server.ServerBuilder;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.dns.DatagramDnsQuery;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.resolver.dns.DnsServerAddressStreamProvider;
import io.netty.resolver.dns.DnsServerAddresses;
import io.netty.util.ReferenceCountUtil;

class HttpClientFactoryTest {
@RegisterExtension
public static final ServerExtension server = new ServerExtension() {
Expand Down Expand Up @@ -103,4 +121,61 @@ protected HttpResponse doGet(ServiceRequestContext ctx,
});
}
}

@Test
void execute_dnsTimeout_clientRequestContext_isTimedOut() {
try (TestDnsServer dnsServer = new TestDnsServer(ImmutableMap.of(), new AlwaysTimeoutHandler())) {
try (RefreshingAddressResolverGroup group = dnsTimeoutBuilder(dnsServer)
.build(CommonPools.workerGroup().next())) {
final ClientFactory clientFactory = ClientFactory
.builder()
.addressResolverGroupFactory(eventExecutors -> group)
.build();
final Endpoint endpoint = Endpoint
.of("test")
.withIpAddr(null); // to invoke dns resolve address
final WebClient client = WebClient
.builder(endpoint.toUri(SessionProtocol.H1C))
.factory(clientFactory)
.build();

try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) {
assertThatThrownBy(() -> client.get("/").aggregate().join())
.isInstanceOf(CompletionException.class)
.hasCauseInstanceOf(UnprocessedRequestException.class)
.hasRootCauseInstanceOf(DnsTimeoutException.class);
captor.get().whenResponseCancelled().join();
assertThat(captor.get().isTimedOut()).isTrue();
}

clientFactory.close();
endpoint.close();
}
}
}

private static class AlwaysTimeoutHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof DatagramDnsQuery) {
// Just release the msg and return so that the client request is timed out.
ReferenceCountUtil.safeRelease(msg);
return;
}
super.channelRead(ctx, msg);
}
}

private static DnsResolverGroupBuilder dnsTimeoutBuilder(TestDnsServer... servers) {
final DnsServerAddressStreamProvider dnsServerAddressStreamProvider =
hostname -> DnsServerAddresses.sequential(
Stream.of(servers).map(TestDnsServer::addr).collect(toImmutableList())).stream();
final DnsResolverGroupBuilder builder = new DnsResolverGroupBuilder()
.serverAddressStreamProvider(dnsServerAddressStreamProvider)
.meterRegistry(PrometheusMeterRegistries.newRegistry())
.resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY)
.traceEnabled(false)
.queryTimeoutMillis(1); // dns timeout
return builder;
}
}
Loading