Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
injae-kim committed Sep 1, 2023
1 parent 9396937 commit 92426fb
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import com.linecorp.armeria.common.RequestId;
import com.linecorp.armeria.common.Response;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.TimeoutException;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.logging.RequestLog;
Expand Down Expand Up @@ -507,6 +508,18 @@ default void timeoutNow() {
cancel(ResponseTimeoutException.get());
}

/**
* Returns whether this {@link ClientRequestContext} has been timed-out, that is the cancellation cause
* is an instance of {@link TimeoutException} or
* {@link UnprocessedRequestException} and wrapped cause is {@link TimeoutException}.
*/
@Override
default boolean isTimedOut() {
final Throwable cause = cancellationCause();
return cause instanceof TimeoutException ||
cause instanceof UnprocessedRequestException && cause.getCause() instanceof TimeoutException;
}

/**
* Returns the maximum length of the received {@link Response}.
* This value is initially set from {@link ClientOptions#MAX_RESPONSE_LENGTH}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Ex
ctx.logBuilder().session(null, ctx.sessionProtocol(), timingsBuilder.build());
final UnprocessedRequestException wrappedCause = UnprocessedRequestException.of(cause);
handleEarlyRequestException(ctx, req, wrappedCause);
ctx.cancel(wrappedCause);
res.close(wrappedCause);
ctx.cancel(cause);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,15 @@
import static org.assertj.core.api.Assertions.assertThatThrownBy;

import java.util.function.Function;
import java.util.stream.Stream;

import com.linecorp.armeria.common.TimeoutException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtensionContext;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.ArgumentsProvider;
import org.junit.jupiter.params.provider.ArgumentsSource;
import org.junit.jupiter.params.provider.ValueSource;

import com.linecorp.armeria.common.HttpMethod;
Expand Down Expand Up @@ -276,6 +282,24 @@ void updateRequestWithInvalidPath(String path) {
.hasMessageContaining("invalid path");
}

@ParameterizedTest
@ArgumentsSource(TimedOutExceptionProvider.class)
void isTimedOut_true(Throwable cause) {
final ClientRequestContext cctx = clientRequestContext();
cctx.cancel(cause);
cctx.whenResponseCancelled().join();
assertThat(cctx.isTimedOut()).isTrue();
}

@ParameterizedTest
@ArgumentsSource(NotTimedOutExceptionProvider.class)
void isTimedOut_false(Throwable cause) {
final ClientRequestContext cctx = clientRequestContext();
cctx.cancel(cause);
cctx.whenResponseCancelled().join();
assertThat(cctx.isTimedOut()).isFalse();
}

private static void assertUnwrapAllCurrentCtx(@Nullable RequestContext ctx) {
final RequestContext current = RequestContext.currentOrNull();
if (current == null) {
Expand All @@ -292,4 +316,25 @@ private static ServiceRequestContext serviceRequestContext() {
private static ClientRequestContext clientRequestContext() {
return ClientRequestContext.of(HttpRequest.of(HttpMethod.GET, "/"));
}

private static class TimedOutExceptionProvider implements ArgumentsProvider {

@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
return Stream.of(new TimeoutException(),
ResponseTimeoutException.get(),
UnprocessedRequestException.of(ResponseTimeoutException.get()))
.map(Arguments::of);
}
}

private static class NotTimedOutExceptionProvider implements ArgumentsProvider {

@Override
public Stream<? extends Arguments> provideArguments(ExtensionContext context) throws Exception {
return Stream.of(new RuntimeException(),
UnprocessedRequestException.of(new RuntimeException()))
.map(Arguments::of);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,23 @@

package com.linecorp.armeria.client;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.awaitility.Awaitility.await;

import com.google.common.collect.ImmutableMap;
import com.linecorp.armeria.client.endpoint.dns.TestDnsServer;
import com.linecorp.armeria.common.CommonPools;
import com.linecorp.armeria.common.HttpMethod;
import com.linecorp.armeria.common.metric.PrometheusMeterRegistries;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.dns.DatagramDnsQuery;
import io.netty.resolver.ResolvedAddressTypes;
import io.netty.resolver.dns.DnsServerAddressStreamProvider;
import io.netty.resolver.dns.DnsServerAddresses;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

Expand All @@ -31,6 +45,9 @@
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.testing.junit5.server.ServerExtension;

import java.util.concurrent.CompletionException;
import java.util.stream.Stream;

class HttpClientFactoryTest {
@RegisterExtension
public static final ServerExtension server = new ServerExtension() {
Expand Down Expand Up @@ -103,4 +120,65 @@ protected HttpResponse doGet(ServiceRequestContext ctx,
});
}
}

@Test
void execute_dnsTimeout_clientRequestContext_isTimedOut() {
try (TestDnsServer dnsServer = new TestDnsServer(ImmutableMap.of(), new AlwaysTimeoutHandler())) {
try (RefreshingAddressResolverGroup group = dnsTimeoutBuilder(dnsServer)
.build(CommonPools.workerGroup().next())) {
final ClientFactory clientFactory = ClientFactory
.builder()
.addressResolverGroupFactory(eventExecutors -> group)
.build();
final Endpoint endpoint = Endpoint
.of("test")
.withIpAddr(null); // to invoke dns resolve address
final WebClient client = WebClient
.builder(endpoint.toUri(SessionProtocol.H1C))
.factory(clientFactory)
.build();
final HttpRequest req = HttpRequest.of(HttpMethod.GET, "/");
final ClientRequestContext reqCtx = ClientRequestContext
.builder(req)
.endpoint(endpoint)
.build();

assertThat(reqCtx.isTimedOut()).isFalse();
assertThatThrownBy(() -> client.unwrap().execute(reqCtx, req).aggregate().join())
.isInstanceOf(CompletionException.class)
.hasCauseInstanceOf(UnprocessedRequestException.class)
.hasRootCauseInstanceOf(DnsTimeoutException.class);
reqCtx.whenResponseCancelled().join();
assertThat(reqCtx.isTimedOut()).isTrue();

clientFactory.close();
endpoint.close();
}
}
}

private static class AlwaysTimeoutHandler extends ChannelInboundHandlerAdapter {
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (msg instanceof DatagramDnsQuery) {
// Just release the msg and return so that the client request is timed out.
ReferenceCountUtil.safeRelease(msg);
return;
}
super.channelRead(ctx, msg);
}
}

private static DnsResolverGroupBuilder dnsTimeoutBuilder(TestDnsServer... servers) {
final DnsServerAddressStreamProvider dnsServerAddressStreamProvider =
hostname -> DnsServerAddresses.sequential(
Stream.of(servers).map(TestDnsServer::addr).collect(toImmutableList())).stream();
final DnsResolverGroupBuilder builder = new DnsResolverGroupBuilder()
.serverAddressStreamProvider(dnsServerAddressStreamProvider)
.meterRegistry(PrometheusMeterRegistries.newRegistry())
.resolvedAddressTypes(ResolvedAddressTypes.IPV4_ONLY)
.traceEnabled(false)
.queryTimeoutMillis(1); // dns timeout
return builder;
}
}

0 comments on commit 92426fb

Please sign in to comment.