Skip to content

Commit

Permalink
Close subscribeToLogs RPC when a Session closes
Browse files Browse the repository at this point in the history
  • Loading branch information
devinrsmith committed Apr 26, 2024
1 parent 3ca98c9 commit a70368a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
import javax.inject.Inject;
import javax.inject.Provider;
import javax.inject.Singleton;
import java.io.Closeable;
import java.io.IOException;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
Expand Down Expand Up @@ -152,13 +154,13 @@ public void startConsole(
public void subscribeToLogs(
@NotNull final LogSubscriptionRequest request,
@NotNull final StreamObserver<LogSubscriptionData> responseObserver) {
sessionService.getCurrentSession();
final SessionState session = sessionService.getCurrentSession();
if (REMOTE_CONSOLE_DISABLED) {
GrpcUtil.safelyError(responseObserver, Code.FAILED_PRECONDITION, "Remote console disabled");
return;
}
final LogsClient client =
new LogsClient(request, (ServerCallStreamObserver<LogSubscriptionData>) responseObserver);
new LogsClient(session, request, (ServerCallStreamObserver<LogSubscriptionData>) responseObserver);
client.start();
}

Expand Down Expand Up @@ -371,7 +373,8 @@ public void cancelAutoComplete(
super.cancelAutoComplete(request, responseObserver);
}

private final class LogsClient implements LogBufferRecordListener, Runnable {
private final class LogsClient implements LogBufferRecordListener, Runnable, Closeable {
private final SessionState session;
private final LogSubscriptionRequest request;
private final ServerCallStreamObserver<LogSubscriptionData> client;
private final LockFreeArrayQueue<LogSubscriptionData> buffer;
Expand All @@ -380,8 +383,10 @@ private final class LogsClient implements LogBufferRecordListener, Runnable {
private volatile boolean tooSlow;

public LogsClient(
final SessionState session,
final LogSubscriptionRequest request,
final ServerCallStreamObserver<LogSubscriptionData> client) {
this.session = Objects.requireNonNull(session);
this.request = Objects.requireNonNull(request);
this.client = Objects.requireNonNull(client);
// Our buffer capacity should always be greater than the capacity of the logBuffer; otherwise, the initial
Expand All @@ -391,20 +396,17 @@ public LogsClient(
// clients subscribing to logs.
this.buffer = LockFreeArrayQueue.of(Math.max(SUBSCRIBE_TO_LOGS_BUFFER_SIZE, logBuffer.capacity() * 2));
this.guard = new AtomicBoolean(false);
this.client.setOnReadyHandler(this::onReady);
this.client.setOnCancelHandler(this::onCancel);
this.client.setOnCloseHandler(this::onClose);
this.session.addOnCloseCallback(this);
this.client.setOnReadyHandler(this::onReadyHandler);
this.client.setOnCancelHandler(this::onCancelHandler);
this.client.setOnCloseHandler(this::onCloseHandler);
}

public void start() {
logBuffer.subscribe(this);
scheduler.runImmediately(this);
}

public void stop() {
GrpcUtil.safelyComplete(client);
}

// ------------------------------------------------------------------------------------------------------------

@Override
Expand Down Expand Up @@ -492,18 +494,27 @@ public void run() {

// ------------------------------------------------------------------------------------------------------------

private void onReady() {
private void onReadyHandler() {
scheduler.runImmediately(this);
}

private void onClose() {
private void onCloseHandler() {
done = true;
logBuffer.unsubscribe(this);
session.removeOnCloseCallback(this);
}

private void onCancel() {
private void onCancelHandler() {
done = true;
logBuffer.unsubscribe(this);
session.removeOnCloseCallback(this);
}

@Override
public void close() {
// this is exclusively for session.addOnCloseCallback
// Note: SessionCloseableObserver prefers to do onComplete, but I think this is better?
GrpcUtil.safelyError(client, Code.CANCELLED, "Session closed");
}

// ------------------------------------------------------------------------------------------------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,25 @@

import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static org.assertj.core.api.Assertions.assertThat;

public class ConsoleServiceTest extends DeephavenApiServerSingleAuthenticatedBase {

private static class Observer implements ClientResponseObserver<LogSubscriptionRequest, LogSubscriptionData> {
private final CountDownLatch latch;
private final CountDownLatch done;
private class Observer implements ClientResponseObserver<LogSubscriptionRequest, LogSubscriptionData> {
private final CountDownLatch onNext;
private final CountDownLatch onDone;
private ClientCallStreamObserver<?> stream;
private volatile Throwable error;

public Observer(int expected) {
latch = new CountDownLatch(expected);
done = new CountDownLatch(1);
onNext = new CountDownLatch(expected);
onDone = new CountDownLatch(1);
}

@Override
Expand All @@ -41,56 +43,91 @@ public void beforeStart(ClientCallStreamObserver<LogSubscriptionRequest> request

@Override
public void onNext(LogSubscriptionData value) {
if (latch.getCount() == 0) {
if (onNext.getCount() == 0) {
throw new IllegalStateException("Expected latch count exceeded");
}
latch.countDown();
onNext.countDown();
}

@Override
public void onError(Throwable t) {
error = t;
done.countDown();
onDone.countDown();
}

@Override
public void onCompleted() {
done.countDown();
onDone.countDown();
}

void cancel(String message, Throwable cause) {
stream.cancel(message, cause);
}

void subscribeToLogs() {
channel().console().subscribeToLogs(LogSubscriptionRequest.getDefaultInstance(), this);
}

void awaitRpcEstablished(Duration duration) throws InterruptedException, TimeoutException {
// There is no other way afaict (at least w/ the observer interfaces that gRPC libraries provide) to know
// that an RPC has been established _besides_ waiting for an onNext message.
assertThat(onNext.getCount()).isEqualTo(1);
logBuffer().record(record(Instant.now(), LogLevel.STDOUT, "hello, world!"));
awaitOnNext(duration);
}

void awaitOnNext(Duration duration) throws InterruptedException, TimeoutException {
if (!onNext.await(duration.toNanos(), TimeUnit.NANOSECONDS)) {
cancel("onNext latch timed out", null);
throw new TimeoutException();
}
}

void awaitOnDone(Duration duration) throws InterruptedException, TimeoutException {
if (!onDone.await(duration.toNanos(), TimeUnit.NANOSECONDS)) {
cancel("onDone latch timed out", null);
throw new TimeoutException();
}
}
}

@Test
public void subscribeToLogsHistory() throws InterruptedException {
final LogBufferRecord record1 = record(Instant.now(), LogLevel.STDOUT, "hello");
final LogBufferRecord record2 = record(Instant.now(), LogLevel.STDOUT, "world");
logBuffer().record(record1);
logBuffer().record(record2);
final LogSubscriptionRequest request = LogSubscriptionRequest.getDefaultInstance();
public void subscribeToLogsHistory() throws InterruptedException, TimeoutException {
logBuffer().record(record(Instant.now(), LogLevel.STDOUT, "hello"));
logBuffer().record(record(Instant.now(), LogLevel.STDOUT, "world"));
final Observer observer = new Observer(2);
channel().console().subscribeToLogs(request, observer);
assertThat(observer.latch.await(3, TimeUnit.SECONDS)).isTrue();
observer.stream.cancel("done", null);
assertThat(observer.done.await(3, TimeUnit.SECONDS)).isTrue();
observer.subscribeToLogs();
observer.awaitOnNext(Duration.ofSeconds(3));
observer.cancel("done", null);
observer.awaitOnDone(Duration.ofSeconds(3));
assertThat(observer.error).isInstanceOf(StatusRuntimeException.class);
assertThat(observer.error).hasMessage("CANCELLED: done");
}

@Test
public void subscribeToLogsInline() throws InterruptedException {
final LogSubscriptionRequest request = LogSubscriptionRequest.getDefaultInstance();
public void subscribeToLogsInline() throws InterruptedException, TimeoutException {
final Observer observer = new Observer(2);
channel().console().subscribeToLogs(request, observer);
final LogBufferRecord record1 = record(Instant.now(), LogLevel.STDOUT, "hello");
final LogBufferRecord record2 = record(Instant.now(), LogLevel.STDOUT, "world");
logBuffer().record(record1);
logBuffer().record(record2);
assertThat(observer.latch.await(3, TimeUnit.SECONDS)).isTrue();
observer.stream.cancel("done", null);
assertThat(observer.done.await(3, TimeUnit.SECONDS)).isTrue();
observer.subscribeToLogs();
logBuffer().record(record(Instant.now(), LogLevel.STDOUT, "hello"));
logBuffer().record(record(Instant.now(), LogLevel.STDOUT, "world"));
observer.awaitOnNext(Duration.ofSeconds(3));
observer.cancel("done", null);
observer.awaitOnDone(Duration.ofSeconds(3));
assertThat(observer.error).isInstanceOf(StatusRuntimeException.class);
assertThat(observer.error).hasMessage("CANCELLED: done");
}

@Test
public void closingSessionCancelsSubscribeToLogs() throws InterruptedException, TimeoutException {
final Observer observer = new Observer(1);
observer.subscribeToLogs();
observer.awaitRpcEstablished(Duration.ofSeconds(3));
closeSession();
observer.awaitOnDone(Duration.ofSeconds(3));
assertThat(observer.error).isInstanceOf(StatusRuntimeException.class);
assertThat(observer.error).hasMessage("CANCELLED: Session closed");
}

private static LogBufferRecord record(Instant timestamp, LogLevel level, String message) {
final LogBufferRecord record = new LogBufferRecord();
record.setTimestampMicros(timestamp.toEpochMilli() * 1000);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import io.deephaven.UncheckedDeephavenException;
import io.deephaven.auth.AuthenticationException;
import io.deephaven.proto.DeephavenChannel;
import io.deephaven.proto.backplane.grpc.CloseSessionResponse;
import io.deephaven.proto.backplane.grpc.HandshakeRequest;
import io.deephaven.proto.backplane.grpc.HandshakeResponse;
import io.deephaven.server.session.SessionState;
Expand Down Expand Up @@ -49,4 +50,8 @@ public SessionState authenticatedSessionState() {
public DeephavenChannel channel() {
return channel;
}

public CloseSessionResponse closeSession() {
return channel.sessionBlocking().closeSession(HandshakeRequest.newBuilder().build());
}
}

0 comments on commit a70368a

Please sign in to comment.