Skip to content

Commit

Permalink
Plumb target to load balancer
Browse files Browse the repository at this point in the history
gRFC A78 has WRR and pick-first include a `grpc.target` label, defined
in A66:

> `grpc.target` : Canonicalized target URI used when creating gRPC
> Channel, e.g. "dns:///pubsub.googleapis.com:443",
> "xds:///helloworld-gke:8000". Canonicalized target URI is the form
> with the scheme included if the user didn't mention the scheme
> (`scheme://[authority]/path`). For channels such as inprocess channels
> where a target URI is not available, implementations can synthesize a
> target URI.
  • Loading branch information
ejona86 committed May 1, 2024
1 parent 27d5758 commit 4561bb5
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 101 deletions.
7 changes: 7 additions & 0 deletions api/src/main/java/io/grpc/LoadBalancer.java
Original file line number Diff line number Diff line change
Expand Up @@ -1185,6 +1185,13 @@ public ScheduledExecutorService getScheduledExecutorService() {
*/
public abstract String getAuthority();

/**
* Returns the target string of the channel, guaranteed to include its scheme.
*/
public String getChannelTarget() {
throw new UnsupportedOperationException();
}

/**
* Returns the ChannelCredentials used to construct the channel, without bearer tokens.
*
Expand Down
53 changes: 34 additions & 19 deletions core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ public Result selectConfig(PickSubchannelArgs args) {
@Nullable
private final String authorityOverride;
private final NameResolverRegistry nameResolverRegistry;
private final URI targetUri;
private final NameResolverProvider nameResolverProvider;
private final NameResolver.Args nameResolverArgs;
private final AutoConfiguredLoadBalancerFactory loadBalancerFactory;
private final ClientTransportFactory originalTransportFactory;
Expand Down Expand Up @@ -383,8 +385,7 @@ private void shutdownNameResolverAndLoadBalancer(boolean channelIsActive) {
nameResolverStarted = false;
if (channelIsActive) {
nameResolver = getNameResolver(
target, authorityOverride, nameResolverRegistry, nameResolverArgs,
transportFactory.getSupportedSocketAddressTypes());
targetUri, authorityOverride, nameResolverProvider, nameResolverArgs);
} else {
nameResolver = null;
}
Expand Down Expand Up @@ -621,6 +622,10 @@ ClientStream newSubstream(
this.retryEnabled = builder.retryEnabled;
this.loadBalancerFactory = new AutoConfiguredLoadBalancerFactory(builder.defaultLbPolicy);
this.nameResolverRegistry = builder.nameResolverRegistry;
ResolvedNameResolver resolvedResolver = getNameResolverProvider(
target, nameResolverRegistry, transportFactory.getSupportedSocketAddressTypes());
this.targetUri = resolvedResolver.targetUri;
this.nameResolverProvider = resolvedResolver.provider;
ScParser serviceConfigParser =
new ScParser(
retryEnabled,
Expand All @@ -640,8 +645,7 @@ ClientStream newSubstream(
.setOverrideAuthority(this.authorityOverride)
.build();
this.nameResolver = getNameResolver(
target, authorityOverride, nameResolverRegistry, nameResolverArgs,
transportFactory.getSupportedSocketAddressTypes());
targetUri, authorityOverride, nameResolverProvider, nameResolverArgs);
this.balancerRpcExecutorPool = checkNotNull(balancerRpcExecutorPool, "balancerRpcExecutorPool");
this.balancerRpcExecutorHolder = new ExecutorHolder(balancerRpcExecutorPool);
this.delayedTransport = new DelayedClientTransport(this.executor, this.syncContext);
Expand Down Expand Up @@ -713,8 +717,20 @@ public CallTracer create() {
}
}

private static NameResolver getNameResolver(
String target, NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs,
@VisibleForTesting
static class ResolvedNameResolver {
public final URI targetUri;
public final NameResolverProvider provider;

public ResolvedNameResolver(URI targetUri, NameResolverProvider provider) {
this.targetUri = checkNotNull(targetUri, "targetUri");
this.provider = checkNotNull(provider, "provider");
}
}

@VisibleForTesting
static ResolvedNameResolver getNameResolverProvider(
String target, NameResolverRegistry nameResolverRegistry,
Collection<Class<? extends SocketAddress>> channelTransportSocketAddressTypes) {
// Finding a NameResolver. Try using the target string as the URI. If that fails, try prepending
// "dns:///".
Expand Down Expand Up @@ -761,23 +777,17 @@ private static NameResolver getNameResolver(
}
}

NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs);
if (resolver != null) {
return resolver;
}

throw new IllegalArgumentException(String.format(
"cannot create a NameResolver for %s%s",
target, uriSyntaxErrors.length() > 0 ? " (" + uriSyntaxErrors + ")" : ""));
return new ResolvedNameResolver(targetUri, provider);
}

@VisibleForTesting
static NameResolver getNameResolver(
String target, @Nullable final String overrideAuthority,
NameResolverRegistry nameResolverRegistry, NameResolver.Args nameResolverArgs,
Collection<Class<? extends SocketAddress>> channelTransportSocketAddressTypes) {
NameResolver resolver = getNameResolver(target, nameResolverRegistry, nameResolverArgs,
channelTransportSocketAddressTypes);
URI targetUri, @Nullable final String overrideAuthority,
NameResolverProvider provider, NameResolver.Args nameResolverArgs) {
NameResolver resolver = provider.newNameResolver(targetUri, nameResolverArgs);
if (resolver == null) {
throw new IllegalArgumentException("cannot create a NameResolver for " + targetUri);
}

// We wrap the name resolver in a RetryingNameResolver to give it the ability to retry failures.
// TODO: After a transition period, all NameResolver implementations that need retry should use
Expand Down Expand Up @@ -1703,6 +1713,11 @@ public String getAuthority() {
return ManagedChannelImpl.this.authority();
}

@Override
public String getChannelTarget() {
return targetUri.toString();
}

@Override
public SynchronizationContext getSynchronizationContext() {
return syncContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,22 @@
package io.grpc.internal;

import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;

import io.grpc.ChannelLogger;
import io.grpc.NameResolver;
import io.grpc.NameResolver.Args;
import io.grpc.NameResolver.ServiceConfigParser;
import io.grpc.NameResolverProvider;
import io.grpc.NameResolverRegistry;
import io.grpc.ProxyDetector;
import io.grpc.SynchronizationContext;
import io.grpc.inprocess.InProcessSocketAddress;
import java.lang.Thread.UncaughtExceptionHandler;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Collections;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Unit tests for ManagedChannelImpl#getNameResolver(). */
/** Unit tests for ManagedChannelImpl#getNameResolverProvider(). */
@RunWith(JUnit4.class)
public class ManagedChannelImplGetNameResolverTest {
private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder()
.setDefaultPort(447)
.setProxyDetector(mock(ProxyDetector.class))
.setSynchronizationContext(new SynchronizationContext(mock(UncaughtExceptionHandler.class)))
.setServiceConfigParser(mock(ServiceConfigParser.class))
.setChannelLogger(mock(ChannelLogger.class))
.setScheduledExecutorService(new FakeClock().getScheduledExecutorService())
.build();

@Test
public void invalidUriTarget() {
testInvalidTarget("defaultscheme:///[invalid]");
Expand All @@ -68,18 +50,6 @@ public void validAuthorityTarget() throws Exception {
new URI("defaultscheme", "", "/foo.googleapis.com:8080", null));
}

@Test
public void validAuthorityTarget_overrideAuthority() throws Exception {
String target = "foo.googleapis.com:8080";
String overrideAuthority = "override.authority";
URI expectedUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null);
NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme());
NameResolver nameResolver = ManagedChannelImpl.getNameResolver(
target, overrideAuthority, nameResolverRegistry, NAMERESOLVER_ARGS,
Collections.singleton(InetSocketAddress.class));
assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority);
}

@Test
public void validUriTarget() throws Exception {
testValidTarget("scheme:///foo.googleapis.com:8080", "scheme:///foo.googleapis.com:8080",
Expand Down Expand Up @@ -121,47 +91,12 @@ public void validTargetStartingWithSlash() throws Exception {
new URI("defaultscheme", "", "//target", null));
}

@Test
public void validTargetNoResolver() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
NameResolverProvider nameResolverProvider = new NameResolverProvider() {
@Override
protected boolean isAvailable() {
return true;
}

@Override
protected int priority() {
return 5;
}

@Override
public NameResolver newNameResolver(URI targetUri, Args args) {
return null;
}

@Override
public String getDefaultScheme() {
return "defaultscheme";
}
};
nameResolverRegistry.register(nameResolverProvider);
try {
ManagedChannelImpl.getNameResolver(
"foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS,
Collections.singleton(InetSocketAddress.class));
fail("Should fail");
} catch (IllegalArgumentException e) {
// expected
}
}

@Test
public void validTargetNoProvider() {
NameResolverRegistry nameResolverRegistry = new NameResolverRegistry();
try {
ManagedChannelImpl.getNameResolver(
"foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS,
ManagedChannelImpl.getNameResolverProvider(
"foo.googleapis.com:8080", nameResolverRegistry,
Collections.singleton(InetSocketAddress.class));
fail("Should fail");
} catch (IllegalArgumentException e) {
Expand All @@ -173,8 +108,8 @@ public void validTargetNoProvider() {
public void validTargetProviderAddrTypesNotSupported() {
NameResolverRegistry nameResolverRegistry = getTestRegistry("testscheme");
try {
ManagedChannelImpl.getNameResolver(
"testscheme:///foo.googleapis.com:8080", null, nameResolverRegistry, NAMERESOLVER_ARGS,
ManagedChannelImpl.getNameResolverProvider(
"testscheme:///foo.googleapis.com:8080", nameResolverRegistry,
Collections.singleton(InProcessSocketAddress.class));
fail("Should fail");
} catch (IllegalArgumentException e) {
Expand All @@ -184,26 +119,23 @@ public void validTargetProviderAddrTypesNotSupported() {
}
}


private void testValidTarget(String target, String expectedUriString, URI expectedUri) {
NameResolverRegistry nameResolverRegistry = getTestRegistry(expectedUri.getScheme());
FakeNameResolver nameResolver
= (FakeNameResolver) ((RetryingNameResolver) ManagedChannelImpl.getNameResolver(
target, null, nameResolverRegistry, NAMERESOLVER_ARGS,
Collections.singleton(InetSocketAddress.class))).getRetriedNameResolver();
assertNotNull(nameResolver);
assertEquals(expectedUri, nameResolver.uri);
assertEquals(expectedUriString, nameResolver.uri.toString());
ManagedChannelImpl.ResolvedNameResolver resolved = ManagedChannelImpl.getNameResolverProvider(
target, nameResolverRegistry, Collections.singleton(InetSocketAddress.class));
assertThat(resolved.provider).isInstanceOf(FakeNameResolverProvider.class);
assertThat(resolved.targetUri).isEqualTo(expectedUri);
assertThat(resolved.targetUri.toString()).isEqualTo(expectedUriString);
}

private void testInvalidTarget(String target) {
NameResolverRegistry nameResolverRegistry = getTestRegistry("dns");

try {
FakeNameResolver nameResolver = (FakeNameResolver) ManagedChannelImpl.getNameResolver(
target, null, nameResolverRegistry, NAMERESOLVER_ARGS,
Collections.singleton(InetSocketAddress.class));
fail("Should have failed, but got resolver with " + nameResolver.uri);
ManagedChannelImpl.ResolvedNameResolver resolved = ManagedChannelImpl.getNameResolverProvider(
target, nameResolverRegistry, Collections.singleton(InetSocketAddress.class));
FakeNameResolverProvider nameResolverProvider = (FakeNameResolverProvider) resolved.provider;
fail("Should have failed, but got resolver provider " + nameResolverProvider);
} catch (IllegalArgumentException e) {
// expected
}
Expand Down
85 changes: 85 additions & 0 deletions core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@
import io.grpc.NameResolver;
import io.grpc.NameResolver.ConfigOrError;
import io.grpc.NameResolver.ResolutionResult;
import io.grpc.NameResolverProvider;
import io.grpc.NameResolverRegistry;
import io.grpc.ProxiedSocketAddress;
import io.grpc.ProxyDetector;
Expand All @@ -112,6 +113,7 @@
import io.grpc.Status;
import io.grpc.Status.Code;
import io.grpc.StringMarshaller;
import io.grpc.SynchronizationContext;
import io.grpc.internal.ClientTransportFactory.ClientTransportOptions;
import io.grpc.internal.ClientTransportFactory.SwapChannelCredentialsResult;
import io.grpc.internal.InternalSubchannel.TransportLogger;
Expand Down Expand Up @@ -188,6 +190,15 @@ public class ManagedChannelImplTest {
.setUserAgent(USER_AGENT);
private static final String TARGET = "fake://" + SERVICE_NAME;
private static final String MOCK_POLICY_NAME = "mock_lb";
private static final NameResolver.Args NAMERESOLVER_ARGS = NameResolver.Args.newBuilder()
.setDefaultPort(447)
.setProxyDetector(mock(ProxyDetector.class))
.setSynchronizationContext(
new SynchronizationContext(mock(Thread.UncaughtExceptionHandler.class)))
.setServiceConfigParser(mock(NameResolver.ServiceConfigParser.class))
.setScheduledExecutorService(new FakeClock().getScheduledExecutorService())
.build();

private URI expectedUri;
private final SocketAddress socketAddress =
new SocketAddress() {
Expand Down Expand Up @@ -4306,6 +4317,80 @@ public void transportTerminated(Attributes transportAttrs) {
assertEquals(1, terminationCallbackCalled.get());
}

@Test
public void validAuthorityTarget_overrideAuthority() throws Exception {
String overrideAuthority = "override.authority";
String serviceAuthority = "fakeauthority";
NameResolverProvider nameResolverProvider = new NameResolverProvider() {
@Override protected boolean isAvailable() {
return true;
}

@Override protected int priority() {
return 5;
}

@Override public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
return new NameResolver() {
@Override public String getServiceAuthority() {
return serviceAuthority;
}

@Override public void start(final Listener2 listener) {}

@Override public void shutdown() {}
};
}

@Override public String getDefaultScheme() {
return "defaultscheme";
}
};

URI targetUri = new URI("defaultscheme", "", "/foo.googleapis.com:8080", null);
NameResolver nameResolver = ManagedChannelImpl.getNameResolver(
targetUri, null, nameResolverProvider, NAMERESOLVER_ARGS);
assertThat(nameResolver.getServiceAuthority()).isEqualTo(serviceAuthority);

nameResolver = ManagedChannelImpl.getNameResolver(
targetUri, overrideAuthority, nameResolverProvider, NAMERESOLVER_ARGS);
assertThat(nameResolver.getServiceAuthority()).isEqualTo(overrideAuthority);
}

@Test
public void validTargetNoResolver_throws() {
NameResolverProvider nameResolverProvider = new NameResolverProvider() {
@Override
protected boolean isAvailable() {
return true;
}

@Override
protected int priority() {
return 5;
}

@Override
public NameResolver newNameResolver(URI targetUri, NameResolver.Args args) {
return null;
}

@Override
public String getDefaultScheme() {
return "defaultscheme";
}
};
try {
ManagedChannelImpl.getNameResolver(
URI.create("defaultscheme:///foo.gogoleapis.com:8080"),
null, nameResolverProvider, NAMERESOLVER_ARGS);
fail("Should fail");
} catch (IllegalArgumentException e) {
// expected
}
}


private static final class FakeBackoffPolicyProvider implements BackoffPolicy.Provider {
@Override
public BackoffPolicy get() {
Expand Down
Loading

0 comments on commit 4561bb5

Please sign in to comment.