Skip to content
This repository has been archived by the owner on Sep 26, 2023. It is now read-only.

Commit

Permalink
feat: add mtls feature to http and grpc transport provider (#1249)
Browse files Browse the repository at this point in the history
* feat: add mtls support to grpc and http transport
  • Loading branch information
arithmetic1728 authored May 26, 2021
1 parent 3b1859e commit b863041
Show file tree
Hide file tree
Showing 16 changed files with 1,045 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.annotations.VisibleForTesting;
Expand All @@ -46,16 +47,22 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.CharStreams;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.TlsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.io.InputStreamReader;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import javax.net.ssl.KeyManagerFactory;
import org.threeten.bp.Duration;

/**
Expand Down Expand Up @@ -96,6 +103,7 @@ public final class InstantiatingGrpcChannelProvider implements TransportChannelP
@Nullable private final ChannelPrimer channelPrimer;
@Nullable private final Boolean attemptDirectPath;
@VisibleForTesting final ImmutableMap<String, ?> directPathServiceConfig;
@Nullable private final MtlsProvider mtlsProvider;

@Nullable
private final ApiFunction<ManagedChannelBuilder, ManagedChannelBuilder> channelConfigurator;
Expand All @@ -105,6 +113,7 @@ private InstantiatingGrpcChannelProvider(Builder builder) {
this.executor = builder.executor;
this.headerProvider = builder.headerProvider;
this.endpoint = builder.endpoint;
this.mtlsProvider = builder.mtlsProvider;
this.envProvider = builder.envProvider;
this.interceptorProvider = builder.interceptorProvider;
this.maxInboundMessageSize = builder.maxInboundMessageSize;
Expand Down Expand Up @@ -216,8 +225,13 @@ private TransportChannel createChannel() throws IOException {
int realPoolSize = MoreObjects.firstNonNull(poolSize, 1);
ChannelFactory channelFactory =
new ChannelFactory() {
@Override
public ManagedChannel createSingleChannel() throws IOException {
return InstantiatingGrpcChannelProvider.this.createSingleChannel();
try {
return InstantiatingGrpcChannelProvider.this.createSingleChannel();
} catch (GeneralSecurityException e) {
throw new IOException(e);
}
}
};
ManagedChannel outerChannel;
Expand Down Expand Up @@ -264,7 +278,21 @@ static boolean isOnComputeEngine() {
return false;
}

private ManagedChannel createSingleChannel() throws IOException {
@VisibleForTesting
ChannelCredentials createMtlsChannelCredentials() throws IOException, GeneralSecurityException {
if (mtlsProvider.useMtlsClientCertificate()) {
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
if (mtlsKeyStore != null) {
KeyManagerFactory factory =
KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
factory.init(mtlsKeyStore, new char[] {});
return TlsChannelCredentials.newBuilder().keyManager(factory.getKeyManagers()).build();
}
}
return null;
}

private ManagedChannel createSingleChannel() throws IOException, GeneralSecurityException {
GrpcHeaderInterceptor headerInterceptor =
new GrpcHeaderInterceptor(headerProvider.getHeaders());
GrpcMetadataHandlerInterceptor metadataHandlerInterceptor =
Expand All @@ -290,7 +318,12 @@ && isOnComputeEngine()) {
builder.keepAliveTimeout(DIRECT_PATH_KEEP_ALIVE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
builder.defaultServiceConfig(directPathServiceConfig);
} else {
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
ChannelCredentials channelCredentials = createMtlsChannelCredentials();
if (channelCredentials != null) {
builder = Grpc.newChannelBuilder(endpoint, channelCredentials);
} else {
builder = ManagedChannelBuilder.forAddress(serviceAddress, port);
}
}
builder =
builder
Expand Down Expand Up @@ -376,6 +409,7 @@ public static final class Builder {
private HeaderProvider headerProvider;
private String endpoint;
private EnvironmentProvider envProvider;
private MtlsProvider mtlsProvider = new MtlsProvider();
@Nullable private GrpcInterceptorProvider interceptorProvider;
@Nullable private Integer maxInboundMessageSize;
@Nullable private Integer maxInboundMetadataSize;
Expand Down Expand Up @@ -412,6 +446,7 @@ private Builder(InstantiatingGrpcChannelProvider provider) {
this.channelPrimer = provider.channelPrimer;
this.attemptDirectPath = provider.attemptDirectPath;
this.directPathServiceConfig = provider.directPathServiceConfig;
this.mtlsProvider = provider.mtlsProvider;
}

/** Sets the number of available CPUs, used internally for testing. */
Expand Down Expand Up @@ -458,6 +493,12 @@ public Builder setEndpoint(String endpoint) {
return this;
}

@VisibleForTesting
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
this.mtlsProvider = mtlsProvider;
return this;
}

/**
* Sets the GrpcInterceptorProvider for this TransportChannelProvider.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import com.google.api.gax.grpc.InstantiatingGrpcChannelProvider.Builder;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.oauth2.CloudShellCredentials;
import com.google.auth.oauth2.ComputeEngineCredentials;
import com.google.common.collect.ImmutableList;
Expand All @@ -46,6 +48,7 @@
import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.ComputeEngineChannelBuilder;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -63,8 +66,7 @@
import org.threeten.bp.Duration;

@RunWith(JUnit4.class)
public class InstantiatingGrpcChannelProviderTest {

public class InstantiatingGrpcChannelProviderTest extends AbstractMtlsTransportChannelTest {
@Test
public void testEndpoint() {
String endpoint = "localhost:8080";
Expand Down Expand Up @@ -499,4 +501,17 @@ public void testWithCustomDirectPathServiceConfig() {
ImmutableMap<String, ?> defaultServiceConfig = provider.directPathServiceConfig;
assertThat(defaultServiceConfig).isEqualTo(passedServiceConfig);
}

@Override
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider)
throws IOException, GeneralSecurityException {
InstantiatingGrpcChannelProvider channelProvider =
InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setMtlsProvider(provider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.build();
return channelProvider.createMtlsChannelCredentials();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import com.google.api.gax.rpc.StubSettings;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.UnaryCallSettings;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
Expand Down Expand Up @@ -83,6 +84,7 @@ private static class FakeStubSettings extends StubSettings<FakeStubSettings> {
public static final int DEFAULT_SERVICE_PORT = 443;
public static final String DEFAULT_SERVICE_ENDPOINT =
DEFAULT_SERVICE_ADDRESS + ':' + DEFAULT_SERVICE_PORT;
public static final MtlsProvider DEFAULT_MTLS_PROVIDER = new MtlsProvider();
public static final ImmutableList<String> DEFAULT_SERVICE_SCOPES =
ImmutableList.<String>builder()
.add("https://www.googleapis.com/auth/pubsub")
Expand Down Expand Up @@ -148,7 +150,9 @@ public static InstantiatingExecutorProvider.Builder defaultExecutorProviderBuild

/** Returns a builder for the default TransportChannelProvider for this service. */
public static InstantiatingGrpcChannelProvider.Builder defaultGrpcChannelProviderBuilder() {
return InstantiatingGrpcChannelProvider.newBuilder().setEndpoint(DEFAULT_SERVICE_ENDPOINT);
return InstantiatingGrpcChannelProvider.newBuilder()
.setEndpoint(DEFAULT_SERVICE_ENDPOINT)
.setMtlsProvider(DEFAULT_MTLS_PROVIDER);
}

public static ApiClientHeaderProvider.Builder defaultGoogleServiceHeaderProviderBuilder() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,21 @@
package com.google.api.gax.httpjson;

import com.google.api.client.http.HttpTransport;
import com.google.api.client.http.javanet.NetHttpTransport;
import com.google.api.core.BetaApi;
import com.google.api.core.InternalExtensionOnly;
import com.google.api.gax.core.ExecutorProvider;
import com.google.api.gax.rpc.FixedHeaderProvider;
import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannel;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import com.google.auth.Credentials;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executor;
Expand All @@ -64,24 +69,28 @@ public final class InstantiatingHttpJsonChannelProvider implements TransportChan
private final HeaderProvider headerProvider;
private final String endpoint;
private final HttpTransport httpTransport;
private final MtlsProvider mtlsProvider;

private InstantiatingHttpJsonChannelProvider(
Executor executor, HeaderProvider headerProvider, String endpoint) {
this.executor = executor;
this.headerProvider = headerProvider;
this.endpoint = endpoint;
this.httpTransport = null;
this.mtlsProvider = new MtlsProvider();
}

private InstantiatingHttpJsonChannelProvider(
Executor executor,
HeaderProvider headerProvider,
String endpoint,
HttpTransport httpTransport) {
HttpTransport httpTransport,
MtlsProvider mtlsProvider) {
this.executor = executor;
this.headerProvider = headerProvider;
this.endpoint = endpoint;
this.httpTransport = httpTransport;
this.mtlsProvider = mtlsProvider;
}

@Override
Expand Down Expand Up @@ -145,7 +154,11 @@ public TransportChannel getTransportChannel() throws IOException {
} else if (needsHeaders()) {
throw new IllegalStateException("getTransportChannel() called when needsHeaders() is true");
} else {
return createChannel();
try {
return createChannel();
} catch (GeneralSecurityException e) {
throw new IOException(e);
}
}
}

Expand All @@ -160,20 +173,35 @@ public TransportChannelProvider withCredentials(Credentials credentials) {
"InstantiatingHttpJsonChannelProvider doesn't need credentials");
}

private TransportChannel createChannel() throws IOException {
HttpTransport createHttpTransport() throws IOException, GeneralSecurityException {
if (mtlsProvider.useMtlsClientCertificate()) {
KeyStore mtlsKeyStore = mtlsProvider.getKeyStore();
if (mtlsKeyStore != null) {
return new NetHttpTransport.Builder().trustCertificates(null, mtlsKeyStore, "").build();
}
}
return null;
}

private TransportChannel createChannel() throws IOException, GeneralSecurityException {
Map<String, String> headers = headerProvider.getHeaders();

List<HttpJsonHeaderEnhancer> headerEnhancers = Lists.newArrayList();
for (Map.Entry<String, String> header : headers.entrySet()) {
headerEnhancers.add(HttpJsonHeaderEnhancers.create(header.getKey(), header.getValue()));
}

HttpTransport httpTransportToUse = httpTransport;
if (httpTransportToUse == null) {
httpTransportToUse = createHttpTransport();
}

ManagedHttpJsonChannel channel =
ManagedHttpJsonChannel.newBuilder()
.setEndpoint(endpoint)
.setHeaderEnhancers(headerEnhancers)
.setExecutor(executor)
.setHttpTransport(httpTransport)
.setHttpTransport(httpTransportToUse)
.build();

return HttpJsonTransportChannel.newBuilder().setManagedChannel(channel).build();
Expand Down Expand Up @@ -202,6 +230,7 @@ public static final class Builder {
private HeaderProvider headerProvider;
private String endpoint;
private HttpTransport httpTransport;
private MtlsProvider mtlsProvider = new MtlsProvider();

private Builder() {}

Expand All @@ -210,6 +239,7 @@ private Builder(InstantiatingHttpJsonChannelProvider provider) {
this.headerProvider = provider.headerProvider;
this.endpoint = provider.endpoint;
this.httpTransport = provider.httpTransport;
this.mtlsProvider = provider.mtlsProvider;
}

/**
Expand Down Expand Up @@ -259,9 +289,15 @@ public String getEndpoint() {
return endpoint;
}

@VisibleForTesting
Builder setMtlsProvider(MtlsProvider mtlsProvider) {
this.mtlsProvider = mtlsProvider;
return this;
}

public InstantiatingHttpJsonChannelProvider build() {
return new InstantiatingHttpJsonChannelProvider(
executor, headerProvider, endpoint, httpTransport);
executor, headerProvider, endpoint, httpTransport, mtlsProvider);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,23 @@
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals;

import com.google.api.gax.rpc.HeaderProvider;
import com.google.api.gax.rpc.TransportChannelProvider;
import com.google.api.gax.rpc.mtls.AbstractMtlsTransportChannelTest;
import com.google.api.gax.rpc.mtls.MtlsProvider;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.Collections;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.Mockito;

@RunWith(JUnit4.class)
public class InstantiatingHttpJsonChannelProviderTest {
public class InstantiatingHttpJsonChannelProviderTest extends AbstractMtlsTransportChannelTest {

@Test
public void basicTest() throws IOException {
Expand Down Expand Up @@ -94,4 +99,17 @@ public void basicTest() throws IOException {
// Make sure we can create channels OK.
provider.getTransportChannel().shutdownNow();
}

@Override
protected Object getMtlsObjectFromTransportChannel(MtlsProvider provider)
throws IOException, GeneralSecurityException {
InstantiatingHttpJsonChannelProvider channelProvider =
InstantiatingHttpJsonChannelProvider.newBuilder()
.setEndpoint("localhost:8080")
.setMtlsProvider(provider)
.setHeaderProvider(Mockito.mock(HeaderProvider.class))
.setExecutor(Mockito.mock(Executor.class))
.build();
return channelProvider.createHttpTransport();
}
}
4 changes: 4 additions & 0 deletions gax/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ java_library(
srcs = glob(["src/test/java/**/*.java"]),
javacopts = _JAVA_COPTS,
plugins = ["//:auto_value_plugin"],
resources = glob([
"src/test/resources/com/google/api/gax/rpc/mtls/mtls_context_aware_metadata.json",
"src/test/resources/com/google/api/gax/rpc/mtls/mtlsCertAndKey.pem",
]),
visibility = ["//visibility:public"],
deps = [":gax"] + _COMPILE_DEPS + _TEST_COMPILE_DEPS,
)
Expand Down
Loading

0 comments on commit b863041

Please sign in to comment.