Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: gRPC stream connection deadline #999

Merged
merged 4 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ public final class Config {
static final String DEFAULT_HOST = "localhost";

static final int DEFAULT_DEADLINE = 500;
static final int DEFAULT_STREAM_DEADLINE_MS = 10 * 60 * 1000;
static final int DEFAULT_MAX_CACHE_SIZE = 1000;
static final long DEFAULT_KEEP_ALIVE = 0;

Expand All @@ -31,6 +32,7 @@ public final class Config {
static final String MAX_EVENT_STREAM_RETRIES_ENV_VAR_NAME = "FLAGD_MAX_EVENT_STREAM_RETRIES";
static final String BASE_EVENT_STREAM_RETRY_BACKOFF_MS_ENV_VAR_NAME = "FLAGD_RETRY_BACKOFF_MS";
static final String DEADLINE_MS_ENV_VAR_NAME = "FLAGD_DEADLINE_MS";
static final String STREAM_DEADLINE_MS_ENV_VAR_NAME = "FLAGD_STREAM_DEADLINE_MS";
static final String SOURCE_SELECTOR_ENV_VAR_NAME = "FLAGD_SOURCE_SELECTOR";
static final String OFFLINE_SOURCE_PATH = "FLAGD_OFFLINE_FLAG_SOURCE_PATH";
static final String KEEP_ALIVE_MS_ENV_VAR_NAME_OLD = "FLAGD_KEEP_ALIVE_TIME";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ public class FlagdOptions {
@Builder.Default
private int deadline = fallBackToEnvOrDefault(Config.DEADLINE_MS_ENV_VAR_NAME, Config.DEFAULT_DEADLINE);

/**
* Streaming connection deadline in milliseconds.
* Set to 0 to disable the deadline.
toddbaert marked this conversation as resolved.
Show resolved Hide resolved
*/
@Builder.Default
private int streamDeadlineMs = fallBackToEnvOrDefault(Config.STREAM_DEADLINE_MS_ENV_VAR_NAME,
Config.DEFAULT_STREAM_DEADLINE_MS);

/**
* Selector to be used with flag sync gRPC contract.
**/
Expand All @@ -101,7 +109,7 @@ public class FlagdOptions {
/**
* gRPC client KeepAlive in milliseconds. Disabled with 0.
* Defaults to 0 (disabled).
*
*
**/
@Builder.Default
private long keepAlive = fallBackToEnvOrDefault(Config.KEEP_ALIVE_MS_ENV_VAR_NAME,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class GrpcConnector {

private final int startEventStreamRetryBackoff;
private final long deadline;
private final long streamDeadlineMs;

private final Cache cache;
private final Consumer<ConnectionEvent> onConnectionEvent;
Expand Down Expand Up @@ -64,6 +65,7 @@ public GrpcConnector(final FlagdOptions options, final Cache cache, final Suppli
this.startEventStreamRetryBackoff = options.getRetryBackoffMs();
this.eventStreamRetryBackoff = options.getRetryBackoffMs();
this.deadline = options.getDeadline();
this.streamDeadlineMs = options.getStreamDeadlineMs();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[suggestion] using the options object instead of separate fields would reduce this over head of adding a new field all the time

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can go either way on this one. In fact it might be a good thing to do in a pure refactor/cleanup PR. There's also some naming that we can probably improve with the provider.

this.cache = cache;
this.onConnectionEvent = onConnectionEvent;
this.connectedSupplier = connectedSupplier;
Expand Down Expand Up @@ -126,7 +128,14 @@ private void observeEventStream() {
while (this.eventStreamAttempt <= this.maxEventStreamRetries) {
final StreamObserver<EventStreamResponse> responseObserver = new EventStreamObserver(sync, this.cache,
this::onConnectionEvent);
this.serviceStub.eventStream(EventStreamRequest.getDefaultInstance(), responseObserver);

ServiceGrpc.ServiceStub localServiceStub = this.serviceStub;

if (this.streamDeadlineMs > 0) {
localServiceStub = localServiceStub.withDeadlineAfter(this.streamDeadlineMs, TimeUnit.MILLISECONDS);
}

localServiceStub.eventStream(EventStreamRequest.getDefaultInstance(), responseObserver);

try {
synchronized (sync) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class GrpcStreamConnector implements Connector {
private final FlagSyncServiceStub serviceStub;
private final FlagSyncServiceBlockingStub serviceBlockingStub;
private final int deadline;
private final int streamDeadlineMs;
private final String selector;

/**
Expand All @@ -55,6 +56,7 @@ public GrpcStreamConnector(final FlagdOptions options) {
serviceStub = FlagSyncServiceGrpc.newStub(channel);
serviceBlockingStub = FlagSyncServiceGrpc.newBlockingStub(channel);
deadline = options.getDeadline();
streamDeadlineMs = options.getStreamDeadlineMs();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[suggestion] Cant we just use the options within this class, and pass it further down the call chain? - adding a new field and parameters might be tedious over time

selector = options.getSelector();
}

Expand All @@ -64,7 +66,8 @@ public GrpcStreamConnector(final FlagdOptions options) {
public void init() {
Thread listener = new Thread(() -> {
try {
observeEventStream(blockingQueue, shutdown, serviceStub, serviceBlockingStub, selector, deadline);
observeEventStream(blockingQueue, shutdown, serviceStub, serviceBlockingStub, selector, deadline,
streamDeadlineMs);
} catch (InterruptedException e) {
log.warn("gRPC event stream interrupted, flag configurations are stale", e);
Thread.currentThread().interrupt();
Expand Down Expand Up @@ -114,7 +117,8 @@ static void observeEventStream(final BlockingQueue<QueuePayload> writeTo,
final FlagSyncServiceStub serviceStub,
final FlagSyncServiceBlockingStub serviceBlockingStub,
final String selector,
final int deadline)
final int deadline,
final int streamDeadlineMs)
throws InterruptedException {

final BlockingQueue<GrpcResponseModel> streamReceiver = new LinkedBlockingQueue<>(QUEUE_SIZE);
Expand All @@ -135,7 +139,13 @@ static void observeEventStream(final BlockingQueue<QueuePayload> writeTo,
}

try (CancellableContext context = Context.current().withCancellation()) {
serviceStub.syncFlags(syncRequest.build(), new GrpcStreamHandler(streamReceiver));
FlagSyncServiceStub localServiceStub = serviceStub;
if (streamDeadlineMs > 0) {
localServiceStub = localServiceStub.withDeadlineAfter(streamDeadlineMs, TimeUnit.MILLISECONDS);
}

localServiceStub.syncFlags(syncRequest.build(), new GrpcStreamHandler(streamReceiver));

try {
metadataResponse = serviceBlockingStub.withDeadlineAfter(deadline, TimeUnit.MILLISECONDS)
.getMetadata(metadataRequest.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,15 @@

import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.argThat;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockConstruction;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

import java.lang.reflect.Field;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -58,7 +51,7 @@ void validate_retry_calls(int retries) throws NoSuchFieldException, IllegalAcces

final Cache cache = new Cache("disabled", 0);

final ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
final ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
doAnswer(invocation -> null).when(mockStub).eventStream(any(), any());

final GrpcConnector connector = new GrpcConnector(options, cache, () -> true,
Expand Down Expand Up @@ -94,7 +87,7 @@ void validate_retry_calls(int retries) throws NoSuchFieldException, IllegalAcces
@Test
void initialization_succeed_with_connected_status() throws NoSuchFieldException, IllegalAccessException {
final Cache cache = new Cache("disabled", 0);
final ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
final ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
Consumer<ConnectionEvent> onConnectionEvent = mock(Consumer.class);
doAnswer((InvocationOnMock invocation) -> {
EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1);
Expand Down Expand Up @@ -128,7 +121,7 @@ void initialization_succeed_with_connected_status() throws NoSuchFieldException,
@Test
void initialization_fail_with_timeout() throws Exception {
final Cache cache = new Cache("disabled", 0);
final ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
final ServiceStub mockStub = createServiceStubMock();
Consumer<ConnectionEvent> onConnectionEvent = mock(Consumer.class);
doAnswer((InvocationOnMock invocation) -> {
EventStreamObserver eventStreamObserver = (EventStreamObserver) invocation.getArgument(1);
Expand Down Expand Up @@ -165,7 +158,7 @@ void host_and_port_arg_should_build_tcp_socket() {
final int port = 1234;

ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class);
ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
Expand Down Expand Up @@ -196,7 +189,7 @@ void no_args_host_and_port_env_set_should_build_tcp_socket() throws Exception {

new EnvironmentVariables("FLAGD_HOST", host, "FLAGD_PORT", String.valueOf(port)).execute(() -> {
ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class);
ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
Expand Down Expand Up @@ -230,7 +223,7 @@ void path_arg_should_build_domain_socket_with_correct_path() {
final String path = "/some/path";

ServiceGrpc.ServiceBlockingStub mockBlockingStub = mock(ServiceGrpc.ServiceBlockingStub.class);
ServiceGrpc.ServiceStub mockStub = mock(ServiceGrpc.ServiceStub.class);
ServiceGrpc.ServiceStub mockStub = createServiceStubMock();
NettyChannelBuilder mockChannelBuilder = getMockChannelBuilderSocket();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
Expand Down Expand Up @@ -304,6 +297,50 @@ void no_args_socket_env_should_build_domain_socket_with_correct_path() throws Ex
});
}

@Test
void initialization_with_stream_deadline() throws NoSuchFieldException, IllegalAccessException {
final FlagdOptions options = FlagdOptions.builder()
.streamDeadlineMs(16983)
.build();

final Cache cache = new Cache("disabled", 0);
final ServiceGrpc.ServiceStub mockStub = createServiceStubMock();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub);

final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, null);

assertDoesNotThrow(connector::initialize);
verify(mockStub).withDeadlineAfter(16983, TimeUnit.MILLISECONDS);
}
}

@Test
void initialization_without_stream_deadline() throws NoSuchFieldException, IllegalAccessException {
final FlagdOptions options = FlagdOptions.builder()
.streamDeadlineMs(0)
.build();

final Cache cache = new Cache("disabled", 0);
final ServiceGrpc.ServiceStub mockStub = createServiceStubMock();

try (MockedStatic<ServiceGrpc> mockStaticService = mockStatic(ServiceGrpc.class)) {
mockStaticService.when(() -> ServiceGrpc.newStub(any())).thenReturn(mockStub);

final GrpcConnector connector = new GrpcConnector(options, cache, () -> true, null);

assertDoesNotThrow(connector::initialize);
verify(mockStub, never()).withDeadlineAfter(16983, TimeUnit.MILLISECONDS);
}
}

private static ServiceStub createServiceStubMock() {
final ServiceStub mockStub = mock(ServiceStub.class);
when(mockStub.withDeadlineAfter(anyLong(), any())).thenReturn(mockStub);
return mockStub;
}

private NettyChannelBuilder getMockChannelBuilderSocket() {
NettyChannelBuilder mockChannelBuilder = mock(NettyChannelBuilder.class);
when(mockChannelBuilder.eventLoopGroup(any(EventLoopGroup.class))).thenReturn(mockChannelBuilder);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
import static org.junit.jupiter.api.Assertions.assertTimeoutPreemptively;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

import java.lang.reflect.Field;
import java.time.Duration;
Expand Down Expand Up @@ -42,6 +37,7 @@ public void connectionParameters() throws Throwable {
final FlagdOptions options = FlagdOptions.builder()
.selector("selector")
.deadline(1337)
.streamDeadlineMs(87699)
.build();

final GrpcStreamConnector connector = new GrpcStreamConnector(options);
Expand All @@ -58,6 +54,37 @@ public void connectionParameters() throws Throwable {
connector.init();
verify(stubMock, timeout(MAX_WAIT_MS.toMillis()).times(1)).syncFlags(any(), any());
verify(blockingStubMock).withDeadlineAfter(1337, TimeUnit.MILLISECONDS);
verify(stubMock).withDeadlineAfter(87699, TimeUnit.MILLISECONDS);

// then
final SyncFlagsRequest flagsRequest = request[0];
assertNotNull(flagsRequest);
assertEquals("selector", flagsRequest.getSelector());
}


@Test
public void disableStreamDeadline() throws Throwable {
// given
final FlagdOptions options = FlagdOptions.builder()
.selector("selector")
.streamDeadlineMs(0)
.build();

final GrpcStreamConnector connector = new GrpcStreamConnector(options);
final FlagSyncServiceStub stubMock = mockStubAndReturn(connector);
final FlagSyncServiceBlockingStub blockingStubMock = mockBlockingStubAndReturn(connector);
final SyncFlagsRequest[] request = new SyncFlagsRequest[1];

doAnswer(invocation -> {
request[0] = invocation.getArgument(0, SyncFlagsRequest.class);
return null;
}).when(stubMock).syncFlags(any(), any());

// when
connector.init();
verify(stubMock, timeout(MAX_WAIT_MS.toMillis()).times(1)).syncFlags(any(), any());
verify(stubMock, never()).withDeadlineAfter(anyLong(), any());

// then
final SyncFlagsRequest flagsRequest = request[0];
Expand Down Expand Up @@ -186,6 +213,7 @@ private static FlagSyncServiceStub mockStubAndReturn(final GrpcStreamConnector c
serviceStubField.setAccessible(true);

final FlagSyncServiceStub stubMock = mock(FlagSyncServiceStub.class);
when(stubMock.withDeadlineAfter(anyLong(), any())).thenReturn(stubMock);

serviceStubField.set(connector, stubMock);

Expand Down