diff --git a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java index 25ecab100a1..4aea2c21933 100644 --- a/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java +++ b/core/src/main/java/com/linecorp/armeria/client/ClientRequestContext.java @@ -42,6 +42,7 @@ import com.linecorp.armeria.common.RequestId; import com.linecorp.armeria.common.Response; import com.linecorp.armeria.common.RpcRequest; +import com.linecorp.armeria.common.TimeoutException; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.annotation.UnstableApi; import com.linecorp.armeria.common.logging.RequestLog; @@ -514,6 +515,21 @@ default void timeoutNow() { cancel(ResponseTimeoutException.get()); } + /** + * Returns whether this {@link ClientRequestContext} has been timed-out, that is the cancellation cause + * is an instance of {@link TimeoutException} or + * {@link UnprocessedRequestException} and wrapped cause is {@link TimeoutException}. + */ + @Override + default boolean isTimedOut() { + if (RequestContext.super.isTimedOut()) { + return true; + } + final Throwable cause = cancellationCause(); + return cause instanceof TimeoutException || + cause instanceof UnprocessedRequestException && cause.getCause() instanceof TimeoutException; + } + /** * Returns the maximum length of the received {@link Response}. * This value is initially set from {@link ClientOptions#MAX_RESPONSE_LENGTH}. diff --git a/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java b/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java index e54959ac233..db7b7f5dc57 100644 --- a/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java +++ b/core/src/main/java/com/linecorp/armeria/client/HttpClientDelegate.java @@ -244,6 +244,7 @@ private static void handleEarlyRequestException(ClientRequestContext ctx, final RequestLogBuilder logBuilder = ctx.logBuilder(); logBuilder.endRequest(cause); logBuilder.endResponse(cause); + ctx.cancel(cause); } } diff --git a/core/src/main/java/com/linecorp/armeria/client/SessionProtocolNegotiationException.java b/core/src/main/java/com/linecorp/armeria/client/SessionProtocolNegotiationException.java index 8314222a6e7..6eced0cba29 100644 --- a/core/src/main/java/com/linecorp/armeria/client/SessionProtocolNegotiationException.java +++ b/core/src/main/java/com/linecorp/armeria/client/SessionProtocolNegotiationException.java @@ -37,7 +37,7 @@ public final class SessionProtocolNegotiationException extends RuntimeException * Creates a new instance with the specified expected {@link SessionProtocol}. */ public SessionProtocolNegotiationException(SessionProtocol expected, @Nullable String reason) { - super("expected: " + requireNonNull(expected, "expected") + ", reason: " + reason); + super(appendReason("expected: " + requireNonNull(expected, "expected"), reason)); this.expected = expected; actual = null; } @@ -48,8 +48,8 @@ public SessionProtocolNegotiationException(SessionProtocol expected, @Nullable S public SessionProtocolNegotiationException(SessionProtocol expected, @Nullable SessionProtocol actual, @Nullable String reason) { - super("expected: " + requireNonNull(expected, "expected") + - ", actual: " + requireNonNull(actual, "actual") + ", reason: " + reason); + super(appendReason("expected: " + requireNonNull(expected, "expected") + + ", actual: " + actual, reason)); this.expected = expected; this.actual = actual; } @@ -78,4 +78,11 @@ public Throwable fillInStackTrace() { } return this; } + + private static String appendReason(String message, @Nullable String reason) { + if (reason == null) { + return message; + } + return message + ", reason: " + reason; + } } diff --git a/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java b/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java index be006e0fea2..d549701c386 100644 --- a/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/ClientRequestContextTest.java @@ -19,15 +19,21 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy; import java.util.function.Function; +import java.util.stream.Stream; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.ArgumentsProvider; +import org.junit.jupiter.params.provider.ArgumentsSource; import org.junit.jupiter.params.provider.ValueSource; import com.linecorp.armeria.common.HttpMethod; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.RequestContext; import com.linecorp.armeria.common.RequestHeaders; +import com.linecorp.armeria.common.TimeoutException; import com.linecorp.armeria.common.annotation.Nullable; import com.linecorp.armeria.common.util.SafeCloseable; import com.linecorp.armeria.server.ServiceRequestContext; @@ -276,6 +282,24 @@ void updateRequestWithInvalidPath(String path) { .hasMessageContaining("invalid path"); } + @ParameterizedTest + @ArgumentsSource(TimedOutExceptionProvider.class) + void isTimedOut_true(Throwable cause) { + final ClientRequestContext cctx = clientRequestContext(); + cctx.cancel(cause); + cctx.whenResponseCancelled().join(); + assertThat(cctx.isTimedOut()).isTrue(); + } + + @ParameterizedTest + @ArgumentsSource(NotTimedOutExceptionProvider.class) + void isTimedOut_false(Throwable cause) { + final ClientRequestContext cctx = clientRequestContext(); + cctx.cancel(cause); + cctx.whenResponseCancelled().join(); + assertThat(cctx.isTimedOut()).isFalse(); + } + private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) { final RequestContext current = RequestContext.currentOrNull(); if (current == null) { @@ -292,4 +316,25 @@ private static ServiceRequestContext serviceRequestContext() { private static ClientRequestContext clientRequestContext() { return ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/")); } + + private static class TimedOutExceptionProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) throws Exception { + return Stream.of(new TimeoutException(), + ResponseTimeoutException.get(), + UnprocessedRequestException.of(ResponseTimeoutException.get())) + .map(Arguments::of); + } + } + + private static class NotTimedOutExceptionProvider implements ArgumentsProvider { + + @Override + public Stream provideArguments(ExtensionContext context) throws Exception { + return Stream.of(new RuntimeException(), + UnprocessedRequestException.of(new RuntimeException())) + .map(Arguments::of); + } + } } diff --git a/core/src/test/java/com/linecorp/armeria/client/HttpClientFactoryTest.java b/core/src/test/java/com/linecorp/armeria/client/HttpClientFactoryTest.java index 4163e8e52c4..aeb9d55b585 100644 --- a/core/src/test/java/com/linecorp/armeria/client/HttpClientFactoryTest.java +++ b/core/src/test/java/com/linecorp/armeria/client/HttpClientFactoryTest.java @@ -16,21 +16,39 @@ package com.linecorp.armeria.client; +import static com.google.common.collect.ImmutableList.toImmutableList; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.awaitility.Awaitility.await; +import java.util.concurrent.CompletionException; +import java.util.stream.Stream; + import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import com.google.common.collect.ImmutableMap; + +import com.linecorp.armeria.client.endpoint.dns.TestDnsServer; +import com.linecorp.armeria.common.CommonPools; import com.linecorp.armeria.common.HttpRequest; import com.linecorp.armeria.common.HttpResponse; import com.linecorp.armeria.common.SessionProtocol; +import com.linecorp.armeria.common.metric.PrometheusMeterRegistries; import com.linecorp.armeria.server.AbstractHttpService; import com.linecorp.armeria.server.Server; import com.linecorp.armeria.server.ServerBuilder; import com.linecorp.armeria.server.ServiceRequestContext; import com.linecorp.armeria.testing.junit5.server.ServerExtension; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.dns.DatagramDnsQuery; +import io.netty.resolver.ResolvedAddressTypes; +import io.netty.resolver.dns.DnsServerAddressStreamProvider; +import io.netty.resolver.dns.DnsServerAddresses; +import io.netty.util.ReferenceCountUtil; + class HttpClientFactoryTest { @RegisterExtension public static final ServerExtension server = new ServerExtension() { @@ -103,4 +121,61 @@ protected HttpResponse doGet(ServiceRequestContext ctx, }); } } + + @Test + void execute_dnsTimeout_clientRequestContext_isTimedOut() { + try (TestDnsServer dnsServer = new TestDnsServer(ImmutableMap.of(), new AlwaysTimeoutHandler())) { + try (RefreshingAddressResolverGroup group = dnsTimeoutBuilder(dnsServer) + .build(CommonPools.workerGroup().next())) { + final ClientFactory clientFactory = ClientFactory + .builder() + .addressResolverGroupFactory(eventExecutors -> group) + .build(); + final Endpoint endpoint = Endpoint + .of("test") + .withIpAddr(null); // to invoke dns resolve address + final WebClient client = WebClient + .builder(endpoint.toUri(SessionProtocol.H1C)) + .factory(clientFactory) + .build(); + + try (ClientRequestContextCaptor captor = Clients.newContextCaptor()) { + assertThatThrownBy(() -> client.get("/").aggregate().join()) + .isInstanceOf(CompletionException.class) + .hasCauseInstanceOf(UnprocessedRequestException.class) + .hasRootCauseInstanceOf(DnsTimeoutException.class); + captor.get().whenResponseCancelled().join(); + assertThat(captor.get().isTimedOut()).isTrue(); + } + + clientFactory.close(); + endpoint.close(); + } + } + } + + private static class AlwaysTimeoutHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof DatagramDnsQuery) { + // Just release the msg and return so that the client request is timed out. + ReferenceCountUtil.safeRelease(msg); + return; + } + super.channelRead(ctx, msg); + } + } + + private static DnsResolverGroupBuilder dnsTimeoutBuilder(TestDnsServer... servers) { + final DnsServerAddressStreamProvider dnsServerAddressStreamProvider = + hostname -> DnsServerAddresses.sequential( + Stream.of(servers).map(TestDnsServer::addr).collect(toImmutableList())).stream(); + final DnsResolverGroupBuilder builder = new DnsResolverGroupBuilder() + .serverAddressStreamProvider(dnsServerAddressStreamProvider) + .meterRegistry(PrometheusMeterRegistries.newRegistry()) + .resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY) + .traceEnabled(false) + .queryTimeoutMillis(1); // dns timeout + return builder; + } }