Skip to content

Commit

Permalink
feat: Update logic for determining the endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
lqiu96 committed Nov 7, 2023
1 parent 0726f00 commit 6e6439c
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import com.google.api.gax.core.ExecutorAsBackgroundResource;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.rpc.internal.QuotaProjectIdHidingCredentials;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.api.gax.tracing.ApiTracerFactory;
import com.google.api.gax.tracing.BaseApiTracerFactory;
import com.google.auth.Credentials;
Expand Down Expand Up @@ -146,29 +145,6 @@ public static ClientContext create(ClientSettings settings) throws IOException {
return create(settings.getStubSettings());
}

/** Returns the endpoint that should be used. See https://google.aip.dev/auth/4114. */
static String getEndpoint(
String endpoint,
String mtlsEndpoint,
boolean switchToMtlsEndpointAllowed,
MtlsProvider mtlsProvider)
throws IOException {
if (switchToMtlsEndpointAllowed) {
switch (mtlsProvider.getMtlsEndpointUsagePolicy()) {
case ALWAYS:
return mtlsEndpoint;
case NEVER:
return endpoint;
default:
if (mtlsProvider.useMtlsClientCertificate() && mtlsProvider.getKeyStore() != null) {
return mtlsEndpoint;
}
return endpoint;
}
}
return endpoint;
}

/**
* Instantiates the executor, credentials, and transport context based on the given client
* settings.
Expand All @@ -187,8 +163,8 @@ public static ClientContext create(StubSettings settings) throws IOException {
String audienceString;
if (!Strings.isNullOrEmpty(settingsGdchApiAudience)) {
audienceString = settingsGdchApiAudience;
} else if (!Strings.isNullOrEmpty(settings.getEndpoint())) {
audienceString = settings.getEndpoint();
} else if (!Strings.isNullOrEmpty(settings.getUnresolvedEndpoint())) {
audienceString = settings.getUnresolvedEndpoint();
} else {
throw new IllegalArgumentException("Could not infer GDCH api audience from settings");
}
Expand Down Expand Up @@ -230,12 +206,6 @@ public static ClientContext create(StubSettings settings) throws IOException {
EndpointContext endpointContext = settings.getEndpointContext();
String endpoint = endpointContext.resolveEndpoint(credentials);
String universeDomain = endpointContext.resolveUniverseDomain(credentials);
// String endpoint =
// getEndpoint(
// settings.getEndpoint(),
// settings.getMtlsEndpoint(),
// settings.getSwitchToMtlsEndpointAllowed(),
// new MtlsProvider());
if (transportChannelProvider.needsEndpoint()) {
transportChannelProvider = transportChannelProvider.withEndpoint(endpoint);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@
package com.google.api.gax.rpc;

import com.google.api.core.InternalApi;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import java.io.IOException;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Nullable;
Expand All @@ -41,8 +44,10 @@
public abstract class EndpointContext {
private static final String DEFAULT_UNIVERSE_DOMAIN = "googleapis.com";
private static final String DEFAULT_PORT = "443";
private static final String UNIVERSE_DOMAIN_TEMPLATE = "SERVICE_NAME.UNIVERSE_DOMAIN:PORT";
private static final Pattern ENDPOINT_REGEX = Pattern.compile("^[a-zA-Z]+\\.[\\S]+:\\d+$");
private static final String UNIVERSE_DOMAIN_TEMPLATE =
"https://SERVICE_NAME.UNIVERSE_DOMAIN:PORT";
private static final Pattern ENDPOINT_REGEX =
Pattern.compile("^(https\\:\\/\\/)?(www.)?[a-zA-Z]+\\.[\\S]+(\\:\\d)?$");

@Nullable
public abstract String clientSettingsEndpoint();
Expand All @@ -58,6 +63,10 @@ public abstract class EndpointContext {
@Nullable
public abstract String universeDomain();

@VisibleForTesting
@Nullable
public abstract MtlsProvider mtlsProvider();

public abstract Builder toBuilder();

private String resolvedEndpoint;
Expand All @@ -70,7 +79,7 @@ public static Builder newBuilder() {
// By default, the clientSettingsEndpoint value is the default_host endpoint
// value configured in the service. Users can override this value by the Setter
// exposed in the Client/Stub Settings or in the TransportChannelProvider.
private void determineEndpoint() {
private void determineEndpoint() throws IOException {
if (resolvedEndpoint != null && resolvedUniverseDomain != null) {
return;
}
Expand All @@ -88,11 +97,50 @@ private void determineEndpoint() {
// throw new Exception("Invalid endpoint: " + customEndpoint);
return;
}

MtlsProvider mtlsProvider = mtlsProvider() == null ? new MtlsProvider() : mtlsProvider();
boolean isUsingMtlsEndpoint = false;
if (switchToMtlsEndpointAllowed() && mtlsProvider != null) {
switch (mtlsProvider.getMtlsEndpointUsagePolicy()) {
case ALWAYS:
customEndpoint = mtlsEndpoint();
isUsingMtlsEndpoint = true;
break;
case NEVER:
// CustomEndpoint is already set
break;
default:
if (mtlsProvider.useMtlsClientCertificate() && mtlsProvider.getKeyStore() != null) {
customEndpoint = mtlsEndpoint();
isUsingMtlsEndpoint = true;
break;
}
}
}
// mTLS is not supported yet. If mTLS is enabled, use that endpoint.
if (isUsingMtlsEndpoint) {
resolvedEndpoint = mtlsEndpoint();
resolvedUniverseDomain = DEFAULT_UNIVERSE_DOMAIN;
return;
}

if (customEndpoint.contains("https://")) {
customEndpoint = customEndpoint.substring(8);
}

int periodIndex = customEndpoint.indexOf('.');
int colonIndex = customEndpoint.indexOf(':');
String serviceName = customEndpoint.substring(0, periodIndex);
String universeDomain = customEndpoint.substring(periodIndex + 1, colonIndex);
String port = customEndpoint.substring(colonIndex + 1);
String serviceName;
String universeDomain;
String port = "443";
if (colonIndex != -1) {
universeDomain = customEndpoint.substring(periodIndex + 1, colonIndex);
port = customEndpoint.substring(colonIndex + 1);
} else {
universeDomain = customEndpoint.substring(periodIndex + 1);
}
serviceName = customEndpoint.substring(0, periodIndex);

// TODO: Build out logic for resolving endpoint
resolvedEndpoint = buildEndpoint(serviceName, universeDomain, port);
resolvedUniverseDomain = universeDomain;
Expand All @@ -109,14 +157,14 @@ private String buildEndpoint(String serviceName, String universeDomain, String p
.replace("PORT", port);
}

public String resolveEndpoint(Credentials credentials) {
public String resolveEndpoint(Credentials credentials) throws IOException {
if (resolvedEndpoint == null) {
determineEndpoint();
}
return resolvedEndpoint;
}

public String resolveUniverseDomain(Credentials credentials) {
public String resolveUniverseDomain(Credentials credentials) throws IOException {
if (resolvedUniverseDomain == null) {
determineEndpoint();
}
Expand All @@ -135,6 +183,9 @@ public abstract static class Builder {

public abstract Builder setUniverseDomain(String universeDomain);

@VisibleForTesting
public abstract Builder setMtlsProvider(MtlsProvider mtlsProvider);

public abstract EndpointContext build();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import com.google.api.core.ApiClock;
import com.google.api.core.ApiFunction;
import com.google.api.core.BetaApi;
import com.google.api.core.InternalApi;
import com.google.api.core.NanoClock;
import com.google.api.gax.core.CredentialsProvider;
import com.google.api.gax.core.ExecutorProvider;
Expand Down Expand Up @@ -141,12 +142,36 @@ public final ApiClock getClock() {
return clock;
}

/**
* Resolves the endpoint with the correct Universe Domain
*
* @return Resolved Endpoint or null if there is any issue resolving it
*/
public final String getEndpoint() {
try {
return endpointContext.resolveEndpoint(null);
} catch (IOException e) {
return null;
}
}

// This is to return the custom user set endpoint for GDC-H
@InternalApi
final String getUnresolvedEndpoint() {
return endpoint;
}

/**
* Resolves the Universe Domain
*
* @return Resolved Universe Domain or null if there is any issue resolving it
*/
public final String getUniverseDomain() {
return universeDomain;
try {
return endpointContext.resolveUniverseDomain(null);
} catch (IOException e) {
return null;
}
}

public final String getMtlsEndpoint() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -633,8 +633,8 @@ public void testUserAgentConcat() throws Exception {
.containsEntry("user-agent", "user-supplied-agent internal-agent");
}

private static String endpoint = "https://foo.googleapis.com";
private static String mtlsEndpoint = "https://foo.mtls.googleapis.com";
private static String endpoint = "https://foo.googleapis.com:443";
private static String mtlsEndpoint = "https://foo.mtls.googleapis.com:443";

@Test
public void testAutoUseMtlsEndpoint() throws IOException {
Expand All @@ -647,8 +647,14 @@ public void testAutoUseMtlsEndpoint() throws IOException {
FakeMtlsProvider.createTestMtlsKeyStore(),
"",
false);
String endpointSelected =
ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider);
EndpointContext endpointContext =
EndpointContext.newBuilder()
.setClientSettingsEndpoint(endpoint)
.setMtlsEndpoint(mtlsEndpoint)
.setSwitchToMtlsEndpointAllowed(switchToMtlsEndpointAllowed)
.setMtlsProvider(provider)
.build();
String endpointSelected = endpointContext.resolveEndpoint(null);
assertEquals(mtlsEndpoint, endpointSelected);
}

Expand All @@ -664,8 +670,14 @@ public void testEndpointNotOverridable() throws IOException {
FakeMtlsProvider.createTestMtlsKeyStore(),
"",
false);
String endpointSelected =
ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider);
EndpointContext endpointContext =
EndpointContext.newBuilder()
.setClientSettingsEndpoint(endpoint)
.setMtlsEndpoint(mtlsEndpoint)
.setSwitchToMtlsEndpointAllowed(switchToMtlsEndpointAllowed)
.setMtlsProvider(provider)
.build();
String endpointSelected = endpointContext.resolveEndpoint(null);
assertEquals(endpoint, endpointSelected);
}

Expand All @@ -675,8 +687,14 @@ public void testNoClientCertificate() throws IOException {
boolean switchToMtlsEndpointAllowed = true;
MtlsProvider provider =
new FakeMtlsProvider(true, MtlsEndpointUsagePolicy.AUTO, null, "", false);
String endpointSelected =
ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider);
EndpointContext endpointContext =
EndpointContext.newBuilder()
.setClientSettingsEndpoint(endpoint)
.setMtlsEndpoint(mtlsEndpoint)
.setSwitchToMtlsEndpointAllowed(switchToMtlsEndpointAllowed)
.setMtlsProvider(provider)
.build();
String endpointSelected = endpointContext.resolveEndpoint(null);
assertEquals(endpoint, endpointSelected);
}

Expand All @@ -686,8 +704,14 @@ public void testAlwaysUseMtlsEndpoint() throws IOException {
boolean switchToMtlsEndpointAllowed = true;
MtlsProvider provider =
new FakeMtlsProvider(false, MtlsEndpointUsagePolicy.ALWAYS, null, "", false);
String endpointSelected =
ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider);
EndpointContext endpointContext =
EndpointContext.newBuilder()
.setClientSettingsEndpoint(endpoint)
.setMtlsEndpoint(mtlsEndpoint)
.setSwitchToMtlsEndpointAllowed(switchToMtlsEndpointAllowed)
.setMtlsProvider(provider)
.build();
String endpointSelected = endpointContext.resolveEndpoint(null);
assertEquals(mtlsEndpoint, endpointSelected);
}

Expand All @@ -702,8 +726,14 @@ public void testNeverUseMtlsEndpoint() throws IOException {
FakeMtlsProvider.createTestMtlsKeyStore(),
"",
false);
String endpointSelected =
ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider);
EndpointContext endpointContext =
EndpointContext.newBuilder()
.setClientSettingsEndpoint(endpoint)
.setMtlsEndpoint(mtlsEndpoint)
.setSwitchToMtlsEndpointAllowed(switchToMtlsEndpointAllowed)
.setMtlsProvider(provider)
.build();
String endpointSelected = endpointContext.resolveEndpoint(null);
assertEquals(endpoint, endpointSelected);
}

Expand All @@ -714,7 +744,14 @@ public void testGetKeyStoreThrows() throws IOException {
boolean switchToMtlsEndpointAllowed = true;
MtlsProvider provider =
new FakeMtlsProvider(true, MtlsEndpointUsagePolicy.AUTO, null, "", true);
ClientContext.getEndpoint(endpoint, mtlsEndpoint, switchToMtlsEndpointAllowed, provider);
EndpointContext endpointContext =
EndpointContext.newBuilder()
.setClientSettingsEndpoint(endpoint)
.setMtlsEndpoint(mtlsEndpoint)
.setSwitchToMtlsEndpointAllowed(switchToMtlsEndpointAllowed)
.setMtlsProvider(provider)
.build();
String endpointSelected = endpointContext.resolveEndpoint(null);
fail("should throw an exception");
} catch (IOException e) {
assertTrue(
Expand Down

0 comments on commit 6e6439c

Please sign in to comment.