diff --git a/binder/src/test/java/io/grpc/binder/PeerUidTestHelperTest.java b/binder/src/test/java/io/grpc/binder/PeerUidTestHelperTest.java new file mode 100644 index 00000000000..6c4c95412df --- /dev/null +++ b/binder/src/test/java/io/grpc/binder/PeerUidTestHelperTest.java @@ -0,0 +1,123 @@ +package io.grpc.binder; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableList; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientInterceptors; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.ServerServiceDefinition; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.MetadataUtils; +import io.grpc.stub.ServerCalls; +import io.grpc.testing.GrpcCleanupRule; +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class PeerUidTestHelperTest { + + @Rule + public final GrpcCleanupRule grpcCleanup = new GrpcCleanupRule(); + + private static final int FAKE_UID = 12345; + + private final AtomicReference clientUidCapture = new AtomicReference<>(); + + @Test + public void keyPopulatedWithInterceptorAndHeader() throws Exception { + makeServiceCall(/* includeInterceptor= */ true, /* includeUidInHeader= */ true, FAKE_UID); + assertThat(clientUidCapture.get()).isEqualTo(new PeerUid(FAKE_UID)); + } + + @Test + public void keyNotPopulatedWithInterceptorAndNoHeader() throws Exception { + makeServiceCall(/* includeInterceptor= */ true, /* includeUidInHeader= */ false, /* uid= */ -1); + assertThat(clientUidCapture.get()).isNull(); + } + + @Test + public void keyNotPopulatedWithoutInterceptorAndWithHeader() throws Exception { + makeServiceCall( + /* includeInterceptor= */ false, /* includeUidInHeader= */ true, /* uid= */ FAKE_UID); + assertThat(clientUidCapture.get()).isNull(); + } + + private final MethodDescriptor method = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/method") + .setType(MethodDescriptor.MethodType.UNARY) + .build(); + + private void makeServiceCall(boolean includeInterceptor, boolean includeUidInHeader, int uid) + throws Exception { + ServerCallHandler callHandler = + ServerCalls.asyncUnaryCall( + (req, respObserver) -> { + clientUidCapture.set(PeerUids.REMOTE_PEER.get()); + respObserver.onNext(req); + respObserver.onCompleted(); + }); + ImmutableList interceptors; + if (includeInterceptor) { + interceptors = ImmutableList.of(PeerUidTestHelper.newTestPeerIdentifyingServerInterceptor()); + } else { + interceptors = ImmutableList.of(); + } + ServerServiceDefinition serviceDef = + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test").addMethod(method, callHandler).build(), + interceptors); + + InProcessServerBuilder server = + InProcessServerBuilder.forName("test").directExecutor().addService(serviceDef); + + grpcCleanup.register(server.build().start()); + + Channel channel = InProcessChannelBuilder.forName("test").directExecutor().build(); + grpcCleanup.register((ManagedChannel) channel); + + if (includeUidInHeader) { + Metadata header = new Metadata(); + header.put(PeerUidTestHelper.UID_KEY, uid); + channel = + ClientInterceptors.intercept(channel, MetadataUtils.newAttachHeadersInterceptor(header)); + } + + ClientCalls.blockingUnaryCall(channel, method, CallOptions.DEFAULT, "hello"); + } + + private static class StringMarshaller implements MethodDescriptor.Marshaller { + + public static final StringMarshaller INSTANCE = new StringMarshaller(); + + @Override + public InputStream stream(String value) { + return new ByteArrayInputStream(value.getBytes(UTF_8)); + } + + @Override + public String parse(InputStream stream) { + try { + return new String(stream.readAllBytes(), UTF_8); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } +} diff --git a/binder/src/testFixtures/java/io/grpc/binder/PeerUidTestHelper.java b/binder/src/testFixtures/java/io/grpc/binder/PeerUidTestHelper.java new file mode 100644 index 00000000000..0492f2c1e91 --- /dev/null +++ b/binder/src/testFixtures/java/io/grpc/binder/PeerUidTestHelper.java @@ -0,0 +1,63 @@ +package io.grpc.binder; + +import io.grpc.Context; +import io.grpc.Contexts; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; + +/** + * Class which helps set up {@link PeerUids} to be used in tests. + */ +public final class PeerUidTestHelper { + + /** + * The UID of the calling package is set with the value of this key. + */ + public static final Metadata.Key UID_KEY = + Metadata.Key.of("binder-remote-uid-for-unit-testing", PeerUidTestMarshaller.INSTANCE); + + /** + * Creates an interceptor that associates the {@link PeerUids#REMOTE_PEER} key in the request + * {@link Context} with a UID provided by the client in the {@link #UID_KEY} request header, if + * present. + * + *

The returned interceptor works with any gRPC transport but is meant for in-process unit + * testing of gRPC/binder services that depend on {@link PeerUids}. + */ + public static ServerInterceptor newTestPeerIdentifyingServerInterceptor() { + return new ServerInterceptor() { + @Override + public ServerCall.Listener interceptCall( + ServerCall call, Metadata headers, ServerCallHandler next) { + if (headers.containsKey(UID_KEY)) { + Context context = + Context.current().withValue(PeerUids.REMOTE_PEER, new PeerUid(headers.get(UID_KEY))); + return Contexts.interceptCall(context, call, headers, next); + } + return next.startCall(call, headers); + } + }; + } + + private PeerUidTestHelper() { + } + + private static class PeerUidTestMarshaller implements Metadata.AsciiMarshaller { + + public static final PeerUidTestMarshaller INSTANCE = new PeerUidTestMarshaller(); + + @Override + public String toAsciiString(Integer value) { + return value.toString(); + } + + @Override + public Integer parseAsciiString(String serialized) { + return Integer.parseInt(serialized); + } + } + + ; +}