diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java index 50e4dc2d2c855..c30e1620a5e11 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorker.java @@ -75,9 +75,8 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ApplianceGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamingEngineGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamPoolGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.WorkRefreshClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.ChannelzServlet; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer; @@ -217,12 +216,16 @@ private StreamingDataflowWorker( this.workCommitter = windmillServiceEnabled - ? StreamingEngineWorkCommitter.create( - WindmillStreamPool.create( - numCommitThreads, COMMIT_STREAM_TIMEOUT, windmillServer::commitWorkStream) - ::getCloseableStream, - numCommitThreads, - this::onCompleteCommit) + ? StreamingEngineWorkCommitter.builder() + .setCommitWorkStreamFactory( + WindmillStreamPool.create( + numCommitThreads, + COMMIT_STREAM_TIMEOUT, + windmillServer::commitWorkStream) + ::getCloseableStream) + .setNumCommitSenders(numCommitThreads) + .setOnCommitComplete(this::onCompleteCommit) + .build() : StreamingApplianceWorkCommitter.create( windmillServer::commitWork, this::onCompleteCommit); @@ -252,31 +255,26 @@ private StreamingDataflowWorker( ThrottlingGetDataMetricTracker getDataMetricTracker = new ThrottlingGetDataMetricTracker(memoryMonitor); - WindmillStreamPool getDataStreamPool = - WindmillStreamPool.create( - Math.max(1, options.getWindmillGetDataStreamCount()), - GET_DATA_STREAM_TIMEOUT, - windmillServer::getDataStream); - - // Register standard file systems. - FileSystems.setDefaultPipelineOptions(options); - - int stuckCommitDurationMillis = - windmillServiceEnabled && options.getStuckCommitDurationMillis() > 0 - ? options.getStuckCommitDurationMillis() - : 0; - - WorkRefreshClient workRefreshClient; + int stuckCommitDurationMillis; if (windmillServiceEnabled) { - StreamingEngineGetDataClient streamingEngineGetDataClient = - new StreamingEngineGetDataClient(getDataMetricTracker, getDataStreamPool); - this.getDataClient = streamingEngineGetDataClient; - workRefreshClient = streamingEngineGetDataClient; + WindmillStreamPool getDataStreamPool = + WindmillStreamPool.create( + Math.max(1, options.getWindmillGetDataStreamCount()), + GET_DATA_STREAM_TIMEOUT, + windmillServer::getDataStream); + this.getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool); + this.heartbeatSender = + new StreamPoolHeartbeatSender( + options.getUseSeparateWindmillHeartbeatStreams() + ? WindmillStreamPool.create( + 1, GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream) + : getDataStreamPool); + stuckCommitDurationMillis = + options.getStuckCommitDurationMillis() > 0 ? options.getStuckCommitDurationMillis() : 0; } else { - ApplianceGetDataClient applianceGetDataClient = - new ApplianceGetDataClient(windmillServer, getDataMetricTracker); - this.getDataClient = applianceGetDataClient; - workRefreshClient = applianceGetDataClient; + this.getDataClient = new ApplianceGetDataClient(windmillServer, getDataMetricTracker); + this.heartbeatSender = new ApplianceHeartbeatSender(windmillServer::getData); + stuckCommitDurationMillis = 0; } this.activeWorkRefresher = @@ -287,7 +285,7 @@ private StreamingDataflowWorker( computationStateCache::getAllPresentComputations, sampler, executorSupplier.apply("RefreshWork"), - workRefreshClient::refreshActiveWork); + getDataMetricTracker::trackHeartbeats); WorkerStatusPages workerStatusPages = WorkerStatusPages.create(DEFAULT_STATUS_PORT, memoryMonitor); @@ -333,14 +331,8 @@ private StreamingDataflowWorker( ID_GENERATOR, stageInfoMap); - this.heartbeatSender = - options.isEnableStreamingEngine() - ? new StreamPoolHeartbeatSender( - options.getUseSeparateWindmillHeartbeatStreams() - ? WindmillStreamPool.create( - 1, GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream) - : getDataStreamPool) - : new ApplianceHeartbeatSender(windmillServer::getData); + // Register standard file systems. + FileSystems.setDefaultPipelineOptions(options); LOG.debug("windmillServiceEnabled: {}", windmillServiceEnabled); LOG.debug("WindmillServiceEndpoint: {}", options.getWindmillServiceEndpoint()); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java index 934977fe0985e..ec5122a8732ab 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WorkItemCancelledException.java @@ -26,6 +26,10 @@ public WorkItemCancelledException(long sharding_key) { super("Work item cancelled for key " + sharding_key); } + public WorkItemCancelledException(Throwable e) { + super(e); + } + /** Returns whether an exception was caused by a {@link WorkItemCancelledException}. */ public static boolean isWorkItemCancelledException(Throwable t) { while (t != null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java index 64309d0a75010..56b0e3f539a50 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkState.java @@ -206,14 +206,17 @@ synchronized ImmutableListMultimap getReadOnlyActiv .collect( flatteningToImmutableListMultimap( Entry::getKey, - e -> e.getValue().stream().map(ExecutableWork::work).map(Work::refreshableView))); + e -> + e.getValue().stream() + .map(ExecutableWork::work) + .map(work -> (RefreshableWork) work))); } synchronized ImmutableList getRefreshableWork(Instant refreshDeadline) { return activeWork.values().stream() .flatMap(Deque::stream) .map(ExecutableWork::work) - .filter(work -> work.isRefreshable(refreshDeadline)) + .filter(work -> !work.isFailed() && work.getStartTime().isBefore(refreshDeadline)) .collect(toImmutableList()); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java index a1668d9ae7851..c51b04f23719f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/RefreshableWork.java @@ -22,7 +22,6 @@ import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.joda.time.Instant; /** View of {@link Work} that exposes an interface for work refreshing. */ @Internal @@ -32,8 +31,6 @@ public interface RefreshableWork { ShardedKey getShardedKey(); - boolean isRefreshable(Instant refreshDeadline); - HeartbeatSender heartbeatSender(); ImmutableList getHeartbeatLatencyAttributions( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java index 71ffd98ac1c03..e77823602eda7 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/Work.java @@ -150,10 +150,6 @@ private static LatencyAttribution.Builder createLatencyAttributionWithActiveLate return latencyAttribution; } - public RefreshableWork refreshableView() { - return this; - } - public WorkItem getWorkItem() { return workItem; } @@ -209,11 +205,6 @@ public String getLatencyTrackingId() { return latencyTrackingId; } - @Override - public boolean isRefreshable(Instant refreshDeadline) { - return !isFailed && getStartTime().isBefore(refreshDeadline); - } - @Override public HeartbeatSender heartbeatSender() { return processingContext.heartbeatSender(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java index 113b760556dfd..303cdeb94f8c6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/streaming/sideinput/SideInputStateFetcher.java @@ -34,6 +34,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; +import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; @@ -52,6 +53,7 @@ /** Class responsible for fetching side input state from the streaming backend. */ @NotThreadSafe +@Internal public class SideInputStateFetcher { private static final Logger LOG = LoggerFactory.getLogger(SideInputStateFetcher.class); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java index b00c4c9c0c7fe..fd0d1b1a3a92d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/AbstractWindmillStream.java @@ -85,10 +85,11 @@ public abstract class AbstractWindmillStream implements Win private final Supplier> requestObserverSupplier; // Indicates if the current stream in requestObserver is closed by calling close() method private final AtomicBoolean streamClosed; - private @Nullable StreamObserver requestObserver; private final String backendWorkerToken; + private @Nullable StreamObserver requestObserver; protected AbstractWindmillStream( + String debugStreamType, Function, StreamObserver> clientFactory, BackOff backoff, StreamObserverFactory streamObserverFactory, @@ -100,7 +101,7 @@ protected AbstractWindmillStream( Executors.newSingleThreadExecutor( new ThreadFactoryBuilder() .setDaemon(true) - .setNameFormat(createThreadName(streamType(), backendWorkerToken)) + .setNameFormat(createThreadName(debugStreamType, backendWorkerToken)) .build()); this.backoff = backoff; this.streamRegistry = streamRegistry; @@ -122,10 +123,10 @@ protected AbstractWindmillStream( clientFactory, new AbstractWindmillStream.ResponseObserver()); } - private static String createThreadName(Type streamType, String backendWorkerToken) { + private static String createThreadName(String streamType, String backendWorkerToken) { return !backendWorkerToken.isEmpty() - ? String.format("%s-%s-WindmillStream-thread", streamType.name(), backendWorkerToken) - : String.format("%s-WindmillStream-thread", streamType.name()); + ? String.format("%s-%s-WindmillStream-thread", streamType, backendWorkerToken) + : String.format("%s-WindmillStream-thread", streamType); } private static long debugDuration(long nowMs, long startMs) { @@ -151,6 +152,11 @@ private static long debugDuration(long nowMs, long startMs) { */ protected abstract void startThrottleTimer(); + /** Reflects that {@link #shutdown()} was explicitly called. */ + protected boolean isShutdown() { + return isShutdown.get(); + } + private StreamObserver requestObserver() { if (requestObserver == null) { throw new NullPointerException( @@ -274,15 +280,11 @@ public String backendWorkerToken() { @Override public void shutdown() { if (isShutdown.compareAndSet(false, true)) { - halfClose(); + requestObserver() + .onError(new WindmillStreamShutdownException("Explicit call to shutdown stream.")); } } - @Override - public boolean isShutdown() { - return isShutdown.get(); - } - private void setLastError(String error) { lastError.set(error); lastErrorTime.set(DateTime.now()); @@ -313,7 +315,7 @@ public void onCompleted() { private void onStreamFinished(@Nullable Throwable t) { synchronized (this) { - if (clientClosed.get() && !hasPendingRequests()) { + if (isShutdown.get() || (clientClosed.get() && !hasPendingRequests())) { streamRegistry.remove(AbstractWindmillStream.this); finishLatch.countDown(); return; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java index a4bfa69ad7798..ee467c01c8f6e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStream.java @@ -52,18 +52,6 @@ public interface WindmillStream { */ void shutdown(); - /** Reflects that {@link #shutdown()} was explicitly called. */ - boolean isShutdown(); - - Type streamType(); - - enum Type { - GET_WORKER_METADATA, - GET_WORK, - GET_DATA, - COMMIT_WORK, - } - /** Handle representing a stream of GetWork responses. */ @ThreadSafe interface GetWorkStream extends WindmillStream { @@ -72,11 +60,6 @@ interface GetWorkStream extends WindmillStream { /** Returns the remaining in-flight {@link GetWorkBudget}. */ GetWorkBudget remainingBudget(); - - @Override - default Type streamType() { - return Type.GET_WORK; - } } /** Interface for streaming GetDataRequests to Windmill. */ @@ -93,11 +76,6 @@ Windmill.KeyedGetDataResponse requestKeyedData( void refreshActiveWork(Map> heartbeats); void onHeartbeatResponse(List responses); - - @Override - default Type streamType() { - return Type.GET_DATA; - } } /** Interface for streaming CommitWorkRequests to Windmill. */ @@ -109,11 +87,6 @@ interface CommitWorkStream extends WindmillStream { */ CommitWorkStream.RequestBatcher batcher(); - @Override - default Type streamType() { - return Type.COMMIT_WORK; - } - @NotThreadSafe interface RequestBatcher extends Closeable { /** @@ -140,10 +113,11 @@ default void close() { /** Interface for streaming GetWorkerMetadata requests to Windmill. */ @ThreadSafe - interface GetWorkerMetadataStream extends WindmillStream { - @Override - default Type streamType() { - return Type.GET_WORKER_METADATA; + interface GetWorkerMetadataStream extends WindmillStream {} + + class WindmillStreamShutdownException extends RuntimeException { + public WindmillStreamShutdownException(String message) { + super(message); } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java index 911b6809c2429..afdb29560a2b2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitter.java @@ -17,6 +17,7 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.commits; +import com.google.auto.value.AutoBuilder; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -55,10 +56,11 @@ public final class StreamingEngineWorkCommitter implements WorkCommitter { private final int numCommitSenders; private final AtomicBoolean isRunning; - private StreamingEngineWorkCommitter( + StreamingEngineWorkCommitter( Supplier> commitWorkStreamFactory, int numCommitSenders, - Consumer onCommitComplete) { + Consumer onCommitComplete, + String backendWorkerToken) { this.commitWorkStreamFactory = commitWorkStreamFactory; this.commitQueue = WeightedBoundedQueue.create( @@ -69,7 +71,10 @@ private StreamingEngineWorkCommitter( new ThreadFactoryBuilder() .setDaemon(true) .setPriority(Thread.MAX_PRIORITY) - .setNameFormat("CommitThread-%d") + .setNameFormat( + backendWorkerToken.isEmpty() + ? "CommitThread-%d" + : "CommitThread-" + backendWorkerToken + "-%d") .build()); this.activeCommitBytes = new AtomicLong(); this.onCommitComplete = onCommitComplete; @@ -77,32 +82,33 @@ private StreamingEngineWorkCommitter( this.isRunning = new AtomicBoolean(false); } - public static StreamingEngineWorkCommitter create( - Supplier> commitWorkStreamFactory, - int numCommitSenders, - Consumer onCommitComplete) { - return new StreamingEngineWorkCommitter( - commitWorkStreamFactory, numCommitSenders, onCommitComplete); + public static Builder builder() { + return new AutoBuilder_StreamingEngineWorkCommitter_Builder() + .setBackendWorkerToken("") + .setNumCommitSenders(1); } @Override @SuppressWarnings("FutureReturnValueIgnored") public void start() { - if (isRunning.compareAndSet(false, true) && !commitSenders.isShutdown()) { - for (int i = 0; i < numCommitSenders; i++) { - commitSenders.submit(this::streamingCommitLoop); - } + Preconditions.checkState( + isRunning.compareAndSet(false, true), "Multiple calls to WorkCommitter.start()."); + for (int i = 0; i < numCommitSenders; i++) { + commitSenders.submit(this::streamingCommitLoop); } } @Override public void commit(Commit commit) { - if (commit.work().isFailed() || !isRunning.get()) { - LOG.debug( - "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}, workId={} ].", - commit.computationId(), - commit.work().getShardedKey(), - commit.work().id()); + boolean isShutdown = !this.isRunning.get(); + if (commit.work().isFailed() || isShutdown) { + if (isShutdown) { + LOG.debug( + "Trying to queue commit on shutdown, failing commit=[computationId={}, shardingKey={}, workId={} ].", + commit.computationId(), + commit.work().getShardedKey(), + commit.work().id()); + } failCommit(commit); } else { commitQueue.put(commit); @@ -116,17 +122,16 @@ public long currentActiveCommitBytes() { @Override public void stop() { - if (isRunning.compareAndSet(true, false) && !commitSenders.isTerminated()) { - commitSenders.shutdownNow(); - try { - commitSenders.awaitTermination(10, TimeUnit.SECONDS); - } catch (InterruptedException e) { - LOG.warn( - "Commit senders didn't complete shutdown within 10 seconds, continuing to drain queue.", - e); - } - drainCommitQueue(); + Preconditions.checkState(isRunning.compareAndSet(true, false)); + commitSenders.shutdownNow(); + try { + commitSenders.awaitTermination(10, TimeUnit.SECONDS); + } catch (InterruptedException e) { + LOG.warn( + "Commit senders didn't complete shutdown within 10 seconds, continuing to drain queue.", + e); } + drainCommitQueue(); } private void drainCommitQueue() { @@ -150,7 +155,7 @@ public int parallelism() { private void streamingCommitLoop() { @Nullable Commit initialCommit = null; try { - while (true) { + while (isRunning.get()) { if (initialCommit == null) { try { // Block until we have a commit or are shutting down. @@ -169,17 +174,14 @@ private void streamingCommitLoop() { } try (CloseableStream closeableCommitStream = - commitWorkStreamFactory.get()) { - CommitWorkStream commitStream = closeableCommitStream.stream(); - try (CommitWorkStream.RequestBatcher batcher = commitStream.batcher()) { - if (!tryAddToCommitBatch(initialCommit, batcher)) { - throw new AssertionError( - "Initial commit on flushed stream should always be accepted."); - } - // Batch additional commits to the stream and possibly make an un-batched commit the - // next initial commit. - initialCommit = expandBatch(batcher); + commitWorkStreamFactory.get(); + CommitWorkStream.RequestBatcher batcher = closeableCommitStream.stream().batcher()) { + if (!tryAddToCommitBatch(initialCommit, batcher)) { + throw new AssertionError("Initial commit on flushed stream should always be accepted."); } + // Batch additional commits to the stream and possibly make an un-batched commit the + // next initial commit. + initialCommit = expandBatch(batcher); } catch (Exception e) { LOG.error("Error occurred sending commits.", e); } @@ -200,7 +202,7 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch batcher.commitWorkItem( commit.computationId(), commit.request(), - (commitStatus) -> { + commitStatus -> { onCommitComplete.accept(CompleteCommit.create(commit, commitStatus)); activeCommitBytes.addAndGet(-commit.getSize()); }); @@ -214,11 +216,13 @@ private boolean tryAddToCommitBatch(Commit commit, CommitWorkStream.RequestBatch return isCommitAccepted; } - // Helper to batch additional commits into the commit batch as long as they fit. - // Returns a commit that was removed from the queue but not consumed or null. - private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) { + /** + * Helper to batch additional commits into the commit batch as long as they fit. Returns a commit + * that was removed from the queue but not consumed or null. + */ + private @Nullable Commit expandBatch(CommitWorkStream.RequestBatcher batcher) { int commits = 1; - while (true) { + while (isRunning.get()) { Commit commit; try { if (commits < TARGET_COMMIT_BATCH_KEYS) { @@ -245,5 +249,25 @@ private Commit expandBatch(CommitWorkStream.RequestBatcher batcher) { } commits++; } + + return null; + } + + @AutoBuilder + public interface Builder { + Builder setCommitWorkStreamFactory( + Supplier> commitWorkStreamFactory); + + Builder setNumCommitSenders(int numCommitSenders); + + Builder setOnCommitComplete(Consumer onCommitComplete); + + Builder setBackendWorkerToken(String backendWorkerToken); + + StreamingEngineWorkCommitter autoBuild(); + + default WorkCommitter build() { + return autoBuild(); + } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java index dc5adb4e7966c..e0500dde0c538 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ApplianceGetDataClient.java @@ -29,18 +29,15 @@ import org.apache.beam.runners.dataflow.worker.windmill.ApplianceWindmillClient; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.ComputationGetDataRequest; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.Heartbeats; import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.SettableFuture; import org.checkerframework.checker.nullness.qual.Nullable; /** Appliance implementation of {@link GetDataClient}. */ @Internal @ThreadSafe -public final class ApplianceGetDataClient implements GetDataClient, WorkRefreshClient { +public final class ApplianceGetDataClient implements GetDataClient { private static final int MAX_READS_PER_BATCH = 60; private static final int MAX_ACTIVE_READS = 10; @@ -61,19 +58,12 @@ public ApplianceGetDataClient( this.activeReadThreads = 0; } - public static GetDataClient create( - ApplianceWindmillClient windmillClient, ThrottlingGetDataMetricTracker getDataMetricTracker) { - return new ApplianceGetDataClient(windmillClient, getDataMetricTracker); - } - @Override public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling( - ThrottlingGetDataMetricTracker.Type.STATE)) { + String computationId, Windmill.KeyedGetDataRequest request) { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { SettableFuture response = SettableFuture.create(); - ReadBatch batch = addToReadBatch(new QueueEntry(computation, request, response)); + ReadBatch batch = addToReadBatch(new QueueEntry(computationId, request, response)); if (batch != null) { issueReadBatch(batch); } @@ -81,7 +71,7 @@ public Windmill.KeyedGetDataResponse getStateData( } catch (Exception e) { throw new GetDataException( "Error occurred fetching state for computation=" - + computation + + computationId + ", key=" + request.getShardingKey(), e); @@ -90,9 +80,7 @@ public Windmill.KeyedGetDataResponse getStateData( @Override public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling( - ThrottlingGetDataMetricTracker.Type.STATE)) { + try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { return windmillClient .getData(Windmill.GetDataRequest.newBuilder().addGlobalDataFetchRequests(request).build()) .getGlobalData(0); @@ -102,28 +90,6 @@ public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) } } - /** - * Appliance sends heartbeats (used to refresh active work) as KeyedGetDataRequests. So we must - * translate the HeartbeatRequest to a KeyedGetDataRequest. - */ - @Override - public void refreshActiveWork(Map heartbeats) { - Map.Entry heartbeat = - Iterables.getOnlyElement(heartbeats.entrySet()); - HeartbeatSender heartbeatSender = heartbeat.getKey(); - Heartbeats heartbeatToSend = heartbeat.getValue(); - - if (heartbeatToSend.heartbeatRequests().isEmpty()) { - return; - } - - try (AutoCloseable ignored = getDataMetricTracker.trackHeartbeats(heartbeatToSend.size())) { - heartbeatSender.sendHeartbeats(heartbeatToSend); - } catch (Exception e) { - throw new GetDataException("Error occurred refreshing heartbeats=" + heartbeatToSend, e); - } - } - @Override public synchronized void printHtml(PrintWriter writer) { getDataMetricTracker.printHtml(writer); @@ -133,7 +99,8 @@ public synchronized void printHtml(PrintWriter writer) { private void issueReadBatch(ReadBatch batch) { try { - Preconditions.checkState(batch.startRead.get()); + // Possibly block until the batch is allowed to start. + batch.startRead.get(); } catch (InterruptedException e) { // We don't expect this thread to be interrupted. To simplify handling, we just fall through // to issuing the call. @@ -191,7 +158,7 @@ private void issueReadBatch(ReadBatch batch) { } else { // Notify the thread responsible for issuing the next batch read. ReadBatch startBatch = pendingReadBatches.remove(0); - startBatch.startRead.set(true); + startBatch.startRead.set(null); } } } @@ -227,13 +194,13 @@ private void issueReadBatch(ReadBatch batch) { } ReadBatch batch = new ReadBatch(); batch.reads.add(entry); - batch.startRead.set(true); + batch.startRead.set(null); return batch; } private static final class ReadBatch { ArrayList reads = new ArrayList<>(); - SettableFuture startRead = SettableFuture.create(); + SettableFuture startRead = SettableFuture.create(); } private static final class QueueEntry { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FanOutWorkRefreshClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FanOutWorkRefreshClient.java deleted file mode 100644 index 79cde43ffc24b..0000000000000 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FanOutWorkRefreshClient.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.Heartbeats; -import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * {@link WorkRefreshClient} that fans out heartbeats to all {@link HeartbeatSender}(s) in parallel - * passed into {@link #refreshActiveWork(Map)} - */ -@Internal -public final class FanOutWorkRefreshClient implements WorkRefreshClient { - private static final Logger LOG = LoggerFactory.getLogger(FanOutWorkRefreshClient.class); - private static final String FAN_OUT_REFRESH_WORK_EXECUTOR_NAME = - "FanOutActiveWorkRefreshExecutor"; - - private final ThrottlingGetDataMetricTracker getDataMetricTracker; - private final ExecutorService fanOutActiveWorkRefreshExecutor; - - public FanOutWorkRefreshClient(ThrottlingGetDataMetricTracker getDataMetricTracker) { - this.getDataMetricTracker = getDataMetricTracker; - this.fanOutActiveWorkRefreshExecutor = - Executors.newCachedThreadPool( - new ThreadFactoryBuilder() - // FanOutWorkRefreshClient runs as a background process, don't let failures crash - // the worker. - .setUncaughtExceptionHandler( - (t, e) -> LOG.error("Unexpected failure in {}", t.getName(), e)) - .setNameFormat(FAN_OUT_REFRESH_WORK_EXECUTOR_NAME) - .build()); - } - - @Override - public void refreshActiveWork(Map heartbeats) { - List> fanOutRefreshActiveWork = new ArrayList<>(); - for (Map.Entry heartbeat : heartbeats.entrySet()) { - fanOutRefreshActiveWork.add(sendHeartbeatOnStreamFuture(heartbeat)); - } - - // Don't block until we kick off all the refresh active work RPCs. - @SuppressWarnings("rawtypes") - CompletableFuture parallelFanOutRefreshActiveWork = - CompletableFuture.allOf(fanOutRefreshActiveWork.toArray(new CompletableFuture[0])); - parallelFanOutRefreshActiveWork.join(); - } - - private CompletableFuture sendHeartbeatOnStreamFuture( - Map.Entry heartbeat) { - return CompletableFuture.runAsync( - () -> { - try (AutoCloseable ignored = - getDataMetricTracker.trackHeartbeats(heartbeat.getValue().size())) { - HeartbeatSender sender = heartbeat.getKey(); - Heartbeats heartbeats = heartbeat.getValue(); - sender.sendHeartbeats(heartbeats); - } catch (Exception e) { - LOG.error( - "Unable to send {} heartbeats to {}.", - heartbeat.getValue().size(), - heartbeat.getKey(), - new GetDataClient.GetDataException("Error refreshing heartbeats.", e)); - } - }, - fanOutActiveWorkRefreshExecutor); - } -} diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java index 4577b29f8850f..c732591bf12d1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/GetDataClient.java @@ -27,18 +27,30 @@ /** Client for streaming backend GetData API. */ @Internal public interface GetDataClient { - KeyedGetDataResponse getStateData(String computation, KeyedGetDataRequest request); + /** + * Issues a blocking call to fetch state data for a specific computation and {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem}. + * + * @throws GetDataException when there was an unexpected error during the attempted fetch. + */ + KeyedGetDataResponse getStateData(String computationId, KeyedGetDataRequest request) + throws GetDataException; - GlobalData getSideInputData(GlobalDataRequest request); + /** + * Issues a blocking call to fetch side input data. + * + * @throws GetDataException when there was an unexpected error during the attempted fetch. + */ + GlobalData getSideInputData(GlobalDataRequest request) throws GetDataException; - default void printHtml(PrintWriter writer) {} + void printHtml(PrintWriter writer); - class GetDataException extends RuntimeException { - protected GetDataException(String message, Throwable cause) { + final class GetDataException extends RuntimeException { + GetDataException(String message, Throwable cause) { super(message, cause); } - public GetDataException(String message) { + GetDataException(String message) { super(message); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/DirectGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java similarity index 63% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/DirectGetDataClient.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java index 6ee86b6ae7241..b0625384641e2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/DirectGetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamGetDataClient.java @@ -17,25 +17,27 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; +import java.io.PrintWriter; import java.util.function.Function; import org.apache.beam.runners.dataflow.worker.WorkItemCancelledException; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.sdk.annotations.Internal; /** {@link GetDataClient} that fetches data directly from a specific {@link GetDataStream}. */ @Internal -public final class DirectGetDataClient implements GetDataClient { +public final class StreamGetDataClient implements GetDataClient { - private final GetDataStream directGetDataStream; + private final GetDataStream getDataStream; private final Function sideInputGetDataStreamFactory; private final ThrottlingGetDataMetricTracker getDataMetricTracker; - private DirectGetDataClient( - GetDataStream directGetDataStream, + private StreamGetDataClient( + GetDataStream getDataStream, Function sideInputGetDataStreamFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { - this.directGetDataStream = directGetDataStream; + this.getDataStream = getDataStream; this.sideInputGetDataStreamFactory = sideInputGetDataStreamFactory; this.getDataMetricTracker = getDataMetricTracker; } @@ -44,51 +46,56 @@ public static GetDataClient create( GetDataStream getDataStream, Function sideInputGetDataStreamFactory, ThrottlingGetDataMetricTracker getDataMetricTracker) { - return new DirectGetDataClient( + return new StreamGetDataClient( getDataStream, sideInputGetDataStreamFactory, getDataMetricTracker); } + /** + * @throws WorkItemCancelledException when the fetch fails due to the stream being shutdown, + * indicating that the {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} that triggered the + * fetch has been cancelled. + */ @Override public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - if (directGetDataStream.isShutdown()) { + String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { + return getDataStream.requestKeyedData(computationId, request); + } catch (WindmillStream.WindmillStreamShutdownException e) { throw new WorkItemCancelledException(request.getShardingKey()); - } - - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling( - ThrottlingGetDataMetricTracker.Type.STATE)) { - return directGetDataStream.requestKeyedData(computation, request); } catch (Exception e) { - if (directGetDataStream.isShutdown()) { - throw new WorkItemCancelledException(request.getShardingKey()); - } - throw new GetDataException( "Error occurred fetching state for computation=" - + computation + + computationId + ", key=" + request.getShardingKey(), e); } } + /** + * @throws WorkItemCancelledException when the fetch fails due to the stream being shutdown, + * indicating that the {@link + * org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem} that triggered the + * fetch has been cancelled. + */ @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { + public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) + throws GetDataException { GetDataStream sideInputGetDataStream = sideInputGetDataStreamFactory.apply(request.getDataId().getTag()); - if (sideInputGetDataStream.isShutdown()) { - throw new GetDataException( - "Error occurred fetching side input for tag=" + request.getDataId()); - } - - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling( - ThrottlingGetDataMetricTracker.Type.SIDE_INPUT)) { + try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { return sideInputGetDataStream.requestGlobalData(request); + } catch (WindmillStream.WindmillStreamShutdownException e) { + throw new WorkItemCancelledException(e); } catch (Exception e) { throw new GetDataException( "Error occurred fetching side input for tag=" + request.getDataId(), e); } } + + @Override + public void printHtml(PrintWriter writer) { + getDataMetricTracker.printHtml(writer); + } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamingEngineGetDataClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java similarity index 63% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamingEngineGetDataClient.java rename to runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java index 54967f039f2d2..d6b20e425b0ba 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamingEngineGetDataClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/StreamPoolGetDataClient.java @@ -18,7 +18,6 @@ package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; import java.io.PrintWriter; -import java.util.Map; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest; @@ -26,25 +25,22 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.Heartbeats; import org.apache.beam.sdk.annotations.Internal; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; /** * StreamingEngine implementation of {@link GetDataClient}. * - * @implNote Uses {@link WindmillStreamPool} to send/receive requests. Depending on options, may use - * a dedicated stream pool for heartbeats. + * @implNote Uses {@link WindmillStreamPool} to send requests. Depending on options, may use a + * dedicated stream pool for heartbeats. */ @Internal @ThreadSafe -public final class StreamingEngineGetDataClient implements GetDataClient, WorkRefreshClient { +public final class StreamPoolGetDataClient implements GetDataClient { private final WindmillStreamPool getDataStreamPool; private final ThrottlingGetDataMetricTracker getDataMetricTracker; - public StreamingEngineGetDataClient( + public StreamPoolGetDataClient( ThrottlingGetDataMetricTracker getDataMetricTracker, WindmillStreamPool getDataStreamPool) { this.getDataMetricTracker = getDataMetricTracker; @@ -53,16 +49,14 @@ public StreamingEngineGetDataClient( @Override public Windmill.KeyedGetDataResponse getStateData( - String computation, KeyedGetDataRequest request) { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling( - ThrottlingGetDataMetricTracker.Type.STATE); + String computationId, KeyedGetDataRequest request) { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling(); CloseableStream closeableStream = getDataStreamPool.getCloseableStream()) { - return closeableStream.stream().requestKeyedData(computation, request); + return closeableStream.stream().requestKeyedData(computationId, request); } catch (Exception e) { throw new GetDataException( "Error occurred fetching state for computation=" - + computation + + computationId + ", key=" + request.getShardingKey(), e); @@ -71,9 +65,7 @@ public Windmill.KeyedGetDataResponse getStateData( @Override public Windmill.GlobalData getSideInputData(GlobalDataRequest request) { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling( - ThrottlingGetDataMetricTracker.Type.SIDE_INPUT); + try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling(); CloseableStream closeableStream = getDataStreamPool.getCloseableStream()) { return closeableStream.stream().requestGlobalData(request); } catch (Exception e) { @@ -82,24 +74,6 @@ public Windmill.GlobalData getSideInputData(GlobalDataRequest request) { } } - @Override - public void refreshActiveWork(Map heartbeats) { - Map.Entry heartbeat = - Iterables.getOnlyElement(heartbeats.entrySet()); - HeartbeatSender heartbeatSender = heartbeat.getKey(); - Heartbeats heartbeatToSend = heartbeat.getValue(); - - if (heartbeatToSend.heartbeatRequests().isEmpty()) { - return; - } - - try (AutoCloseable ignored = getDataMetricTracker.trackHeartbeats(heartbeatToSend.size())) { - heartbeatSender.sendHeartbeats(heartbeatToSend); - } catch (Exception e) { - throw new GetDataException("Error occurred refreshing heartbeats=" + heartbeatToSend, e); - } - } - @Override public void printHtml(PrintWriter writer) { getDataMetricTracker.printHtml(writer); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java index d356f205817a4..a66cf932bd742 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTracker.java @@ -26,29 +26,45 @@ import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; /** - * Wraps GetData calls that tracks metrics for the number of in-flight requests and throttles - * requests when memory pressure is high. + * Wraps GetData calls to track metrics for the number of in-flight requests and throttles requests + * when memory pressure is high. */ @Internal @ThreadSafe public final class ThrottlingGetDataMetricTracker { + private static final String GET_STATE_DATA_RESOURCE_CONTEXT = "GetStateData"; + private static final String GET_SIDE_INPUT_RESOURCE_CONTEXT = "GetSideInputData"; + private final MemoryMonitor gcThrashingMonitor; - private final GetDataMetrics getDataMetrics; + private final AtomicInteger activeStateReads; + private final AtomicInteger activeSideInputs; + private final AtomicInteger activeHeartbeats; public ThrottlingGetDataMetricTracker(MemoryMonitor gcThrashingMonitor) { this.gcThrashingMonitor = gcThrashingMonitor; - this.getDataMetrics = GetDataMetrics.create(); + this.activeStateReads = new AtomicInteger(); + this.activeSideInputs = new AtomicInteger(); + this.activeHeartbeats = new AtomicInteger(); + } + + /** + * Tracks a state data fetch. If there is memory pressure, may throttle requests. Returns an + * {@link AutoCloseable} that will decrement the metric after the call is finished. + */ + AutoCloseable trackStateDataFetchWithThrottling() { + gcThrashingMonitor.waitForResources(GET_STATE_DATA_RESOURCE_CONTEXT); + activeStateReads.getAndIncrement(); + return activeStateReads::getAndDecrement; } /** - * Tracks a GetData call. If there is memory pressure, may throttle requests. Returns an {@link - * AutoCloseable} that will decrement the metric after the call is finished. + * Tracks a side input fetch. If there is memory pressure, may throttle requests. Returns an + * {@link AutoCloseable} that will decrement the metric after the call is finished. */ - public AutoCloseable trackSingleCallWithThrottling(Type callType) { - gcThrashingMonitor.waitForResources(callType.debugName); - AtomicInteger getDataMetricTracker = getDataMetrics.getMetricFor(callType); - getDataMetricTracker.getAndIncrement(); - return getDataMetricTracker::getAndDecrement; + AutoCloseable trackSideInputFetchWithThrottling() { + gcThrashingMonitor.waitForResources(GET_SIDE_INPUT_RESOURCE_CONTEXT); + activeSideInputs.getAndIncrement(); + return activeSideInputs::getAndDecrement; } /** @@ -56,91 +72,38 @@ public AutoCloseable trackSingleCallWithThrottling(Type callType) { * metric after the call is finished. */ public AutoCloseable trackHeartbeats(int numHeartbeats) { - getDataMetrics - .activeHeartbeats() - .getAndUpdate(currentActiveHeartbeats -> currentActiveHeartbeats + numHeartbeats); - return () -> - getDataMetrics.activeHeartbeats().getAndUpdate(existing -> existing - numHeartbeats); + activeHeartbeats.getAndUpdate( + currentActiveHeartbeats -> currentActiveHeartbeats + numHeartbeats); + return () -> activeHeartbeats.getAndUpdate(existing -> existing - numHeartbeats); } public void printHtml(PrintWriter writer) { writer.println("Active Fetches:"); - getDataMetrics.printMetrics(writer); + writer.println(" Side Inputs: " + activeSideInputs.get()); + writer.println(" State Reads: " + activeStateReads.get()); + writer.println("Heartbeat Keys Active: " + activeHeartbeats.get()); } @VisibleForTesting - GetDataMetrics.ReadOnlySnapshot getMetricsSnapshot() { - return getDataMetrics.snapshot(); - } - - public enum Type { - STATE("GetStateData"), - SIDE_INPUT("GetSideInputData"), - HEARTBEAT("RefreshActiveWork"); - private final String debugName; - - Type(String debugName) { - this.debugName = debugName; - } - - public final String debugName() { - return debugName; - } + ReadOnlySnapshot getMetricsSnapshot() { + return ReadOnlySnapshot.create( + activeSideInputs.get(), activeStateReads.get(), activeHeartbeats.get()); } + @VisibleForTesting @AutoValue - abstract static class GetDataMetrics { - private static GetDataMetrics create() { - return new AutoValue_ThrottlingGetDataMetricTracker_GetDataMetrics( - new AtomicInteger(), new AtomicInteger(), new AtomicInteger()); - } - - abstract AtomicInteger activeSideInputs(); - - abstract AtomicInteger activeStateReads(); - - abstract AtomicInteger activeHeartbeats(); - - private ReadOnlySnapshot snapshot() { - return ReadOnlySnapshot.create( - activeSideInputs().get(), activeStateReads().get(), activeHeartbeats().get()); - } + abstract static class ReadOnlySnapshot { - private AtomicInteger getMetricFor(Type callType) { - switch (callType) { - case STATE: - return activeStateReads(); - case SIDE_INPUT: - return activeSideInputs(); - case HEARTBEAT: - return activeHeartbeats(); - - default: - // Should never happen, switch is exhaustive. - throw new IllegalStateException("Unsupported CallType=" + callType); - } - } - - private void printMetrics(PrintWriter writer) { - writer.println(" Side Inputs: " + activeSideInputs().get()); - writer.println(" State Reads: " + activeStateReads().get()); - writer.println("Heartbeat Keys Active: " + activeHeartbeats().get()); + private static ReadOnlySnapshot create( + int activeSideInputs, int activeStateReads, int activeHeartbeats) { + return new AutoValue_ThrottlingGetDataMetricTracker_ReadOnlySnapshot( + activeSideInputs, activeStateReads, activeHeartbeats); } - @AutoValue - abstract static class ReadOnlySnapshot { + abstract int activeSideInputs(); - private static ReadOnlySnapshot create( - int activeSideInputs, int activeStateReads, int activeHeartbeats) { - return new AutoValue_ThrottlingGetDataMetricTracker_GetDataMetrics_ReadOnlySnapshot( - activeSideInputs, activeStateReads, activeHeartbeats); - } + abstract int activeStateReads(); - abstract int activeSideInputs(); - - abstract int activeStateReads(); - - abstract int activeHeartbeats(); - } + abstract int activeHeartbeats(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java index 232461e34e633..053843a8af253 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcCommitWorkStream.java @@ -69,6 +69,7 @@ private GrpcCommitWorkStream( AtomicLong idGenerator, int streamingRpcBatchLimit) { super( + "CommitWorkStream", startCommitWorkRpcFn, backoff, streamObserverFactory, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java index 0a582ea1c6292..58f72610e2d35 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcDirectGetWorkStream.java @@ -110,6 +110,7 @@ private GrpcDirectGetWorkStream( Supplier workCommitter, WorkItemScheduler workItemScheduler) { super( + "GetWorkStream", startGetWorkRpcFn, backoff, streamObserverFactory, @@ -120,8 +121,6 @@ private GrpcDirectGetWorkStream( this.getWorkThrottleTimer = getWorkThrottleTimer; this.workItemScheduler = workItemScheduler; this.workItemBuffers = new ConcurrentHashMap<>(); - // Use the same GetDataStream and CommitWorkStream instances to process all the work in this - // stream. this.heartbeatSender = Suppliers.memoize(heartbeatSender::get); this.workCommitter = Suppliers.memoize(workCommitter::get); this.getDataClient = Suppliers.memoize(getDataClient::get); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java index 5600a8f0f413b..0e9a0c6316ee0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetDataStream.java @@ -90,6 +90,7 @@ private GrpcGetDataStream( boolean sendKeyedGetDataRequests, Consumer> processHeartbeatResponses) { super( + "GetDataStream", startGetDataRpcFn, backoff, streamObserverFactory, @@ -199,6 +200,10 @@ public GlobalData requestGlobalData(GlobalDataRequest request) { @Override public void refreshActiveWork(Map> heartbeats) { + if (isShutdown()) { + throw new WindmillStreamShutdownException("Unable to refresh work for shutdown stream."); + } + StreamingGetDataRequest.Builder builder = StreamingGetDataRequest.newBuilder(); if (sendKeyedGetDataRequests) { long builderBytes = 0; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java index 5fc093ee32aa9..4b392e9190ed2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkStream.java @@ -73,6 +73,7 @@ private GrpcGetWorkStream( ThrottleTimer getWorkThrottleTimer, WorkItemReceiver receiver) { super( + "GetWorkStream", startGetWorkRpcFn, backoff, streamObserverFactory, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java index 6f734b7da9dcb..44e21a9b18edd 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/GrpcGetWorkerMetadataStream.java @@ -65,6 +65,7 @@ private GrpcGetWorkerMetadataStream( ThrottleTimer getWorkerMetadataThrottleTimer, Consumer serverMappingConsumer) { super( + "GetWorkerMetadataStream", startGetWorkerMetadataRpcFn, backoff, streamObserverFactory, diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java index 01fb6381cd4ae..b9573ff94cc9a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/grpc/StreamingEngineClient.java @@ -45,7 +45,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetWorkerMetadataStream; import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.DirectGetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker; import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.stubs.ChannelCachingStubFactory; import org.apache.beam.runners.dataflow.worker.windmill.client.throttling.ThrottleTimer; @@ -410,7 +410,7 @@ private WindmillStreamSender createAndStartWindmillStreamSenderFor( streamFactory, workItemScheduler, getDataStream -> - DirectGetDataClient.create( + StreamGetDataClient.create( getDataStream, this::getGlobalDataStream, getDataMetricTracker), workCommitterFactory); windmillStreamSender.startStreams(); diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java index 5c79fb1ee402b..c4dc375cdb020 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresher.java @@ -19,19 +19,24 @@ import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap.toImmutableMap; +import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; -import java.util.function.Consumer; import java.util.function.Supplier; import javax.annotation.concurrent.ThreadSafe; import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.ComputationState; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; -import org.apache.beam.runners.dataflow.worker.windmill.Windmill.HeartbeatRequest; import org.apache.beam.sdk.annotations.Internal; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.joda.time.Duration; import org.joda.time.Instant; import org.slf4j.Logger; @@ -48,14 +53,17 @@ @Internal public final class ActiveWorkRefresher { private static final Logger LOG = LoggerFactory.getLogger(ActiveWorkRefresher.class); + private static final String FAN_OUT_REFRESH_WORK_EXECUTOR_NAME = + "FanOutActiveWorkRefreshExecutor-%d"; private final Supplier clock; private final int activeWorkRefreshPeriodMillis; private final Supplier> computations; private final DataflowExecutionStateSampler sampler; private final int stuckCommitDurationMillis; + private final HeartbeatTracker heartbeatTracker; private final ScheduledExecutorService activeWorkRefreshExecutor; - private final Consumer> heartbeatSender; + private final ExecutorService fanOutActiveWorkRefreshExecutor; public ActiveWorkRefresher( Supplier clock, @@ -64,14 +72,23 @@ public ActiveWorkRefresher( Supplier> computations, DataflowExecutionStateSampler sampler, ScheduledExecutorService activeWorkRefreshExecutor, - Consumer> heartbeatSender) { + HeartbeatTracker heartbeatTracker) { this.clock = clock; this.activeWorkRefreshPeriodMillis = activeWorkRefreshPeriodMillis; this.stuckCommitDurationMillis = stuckCommitDurationMillis; this.computations = computations; this.sampler = sampler; this.activeWorkRefreshExecutor = activeWorkRefreshExecutor; - this.heartbeatSender = heartbeatSender; + this.heartbeatTracker = heartbeatTracker; + this.fanOutActiveWorkRefreshExecutor = + Executors.newCachedThreadPool( + new ThreadFactoryBuilder() + // Work refresh runs as a background process, don't let failures crash + // the worker. + .setUncaughtExceptionHandler( + (t, e) -> LOG.error("Unexpected failure in {}", t.getName(), e)) + .setNameFormat(FAN_OUT_REFRESH_WORK_EXECUTOR_NAME) + .build()); } @SuppressWarnings("FutureReturnValueIgnored") @@ -115,9 +132,41 @@ private void invalidateStuckCommits() { } } + /** Create {@link Heartbeats} and group them by {@link HeartbeatSender}. */ private void refreshActiveWork() { Instant refreshDeadline = clock.get().minus(Duration.millis(activeWorkRefreshPeriodMillis)); + Map heartbeatsBySender = + aggregateHeartbeatsBySender(refreshDeadline); + if (heartbeatsBySender.isEmpty()) { + return; + } + + if (heartbeatsBySender.size() == 1) { + // If there is a single HeartbeatSender, just use the calling thread to send heartbeats. + Map.Entry heartbeat = + Iterables.getOnlyElement(heartbeatsBySender.entrySet()); + sendHeartbeat(heartbeat); + } else { + // If there are multiple HeartbeatSenders, send out the heartbeats in parallel using the + // fanOutActiveWorkRefreshExecutor. + List> fanOutRefreshActiveWork = new ArrayList<>(); + for (Map.Entry heartbeat : heartbeatsBySender.entrySet()) { + fanOutRefreshActiveWork.add( + CompletableFuture.runAsync( + () -> sendHeartbeat(heartbeat), fanOutActiveWorkRefreshExecutor)); + } + + // Don't block until we kick off all the refresh active work RPCs. + @SuppressWarnings("rawtypes") + CompletableFuture parallelFanOutRefreshActiveWork = + CompletableFuture.allOf(fanOutRefreshActiveWork.toArray(new CompletableFuture[0])); + parallelFanOutRefreshActiveWork.join(); + } + } + + /** Aggregate the heartbeats across computations by HeartbeatSender for correct fan out. */ + private Map aggregateHeartbeatsBySender(Instant refreshDeadline) { Map heartbeatsBySender = new HashMap<>(); // Aggregate the heartbeats across computations by HeartbeatSender for correct fan out. @@ -125,22 +174,30 @@ private void refreshActiveWork() { for (RefreshableWork work : computationState.getRefreshableWork(refreshDeadline)) { heartbeatsBySender .computeIfAbsent(work.heartbeatSender(), ignored -> Heartbeats.builder()) - .addWork(work) - .addHeartbeatRequest(computationState.getComputationId(), createHeartbeatRequest(work)); + .add(computationState.getComputationId(), work, sampler); } } - heartbeatSender.accept( - heartbeatsBySender.entrySet().stream() - .collect(toImmutableMap(Map.Entry::getKey, e -> e.getValue().build()))); + return heartbeatsBySender.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, e -> e.getValue().build())); + } + + private void sendHeartbeat(Map.Entry heartbeat) { + try (AutoCloseable ignored = heartbeatTracker.trackHeartbeats(heartbeat.getValue().size())) { + HeartbeatSender sender = heartbeat.getKey(); + Heartbeats heartbeats = heartbeat.getValue(); + sender.sendHeartbeats(heartbeats); + } catch (Exception e) { + LOG.error( + "Unable to send {} heartbeats to {}.", + heartbeat.getValue().size(), + heartbeat.getKey(), + e); + } } - private HeartbeatRequest createHeartbeatRequest(RefreshableWork work) { - return HeartbeatRequest.newBuilder() - .setShardingKey(work.getShardedKey().shardingKey()) - .setWorkToken(work.id().workToken()) - .setCacheToken(work.id().cacheToken()) - .addAllLatencyAttribution(work.getHeartbeatLatencyAttributions(sampler)) - .build(); + @FunctionalInterface + public interface HeartbeatTracker { + AutoCloseable trackHeartbeats(int numHeartbeats); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java index a03ff4b430979..b1c42618b09cb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/FixedStreamHeartbeatSender.java @@ -19,6 +19,7 @@ import java.util.Objects; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; +import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.GetDataStream; import org.apache.beam.sdk.annotations.Internal; import org.slf4j.Logger; @@ -49,15 +50,24 @@ public static FixedStreamHeartbeatSender create(GetDataStream getDataStream) { @Override public void sendHeartbeats(Heartbeats heartbeats) { - if (getDataStream.isShutdown()) { + String threadName = Thread.currentThread().getName(); + try { + String backendWorkerToken = getDataStream.backendWorkerToken(); + if (!backendWorkerToken.isEmpty()) { + // Decorate the thread name w/ the backendWorkerToken for debugging. Resets the thread's + // name after sending the heartbeats succeeds or fails. + Thread.currentThread().setName(threadName + "-" + backendWorkerToken); + } + getDataStream.refreshActiveWork(heartbeats.heartbeatRequests().asMap()); + } catch (WindmillStream.WindmillStreamShutdownException e) { LOG.warn( "Trying to refresh work w/ {} heartbeats on stream={} after work has moved off of worker." + " heartbeats", getDataStream.backendWorkerToken(), heartbeats.heartbeatRequests().size()); heartbeats.work().forEach(RefreshableWork::setFailed); - } else { - getDataStream.refreshActiveWork(heartbeats.heartbeatRequests().asMap()); + } finally { + Thread.currentThread().setName(threadName); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/HeartbeatSender.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/HeartbeatSender.java index 3ee0090ebcaa8..06559344332ca 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/HeartbeatSender.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/HeartbeatSender.java @@ -17,7 +17,11 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.work.refresh; -/** Interface for sending heartbeats. */ +/** + * Interface for sending heartbeats. + * + * @implNote Batching/grouping of heartbeats is performed by HeartbeatSender equality. + */ @FunctionalInterface public interface HeartbeatSender { /** diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java index cff65ca183257..78e9864f4eed3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/Heartbeats.java @@ -18,6 +18,7 @@ package org.apache.beam.runners.dataflow.worker.windmill.work.refresh; import com.google.auto.value.AutoValue; +import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler; import org.apache.beam.runners.dataflow.worker.streaming.RefreshableWork; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; @@ -45,21 +46,32 @@ public abstract static class Builder { abstract ImmutableList.Builder workBuilder(); - public final Builder addWork(RefreshableWork work) { + public final Builder add( + String computationId, RefreshableWork work, DataflowExecutionStateSampler sampler) { workBuilder().add(work); + addHeartbeatRequest(computationId, createHeartbeatRequest(work, sampler)); return this; } + private Windmill.HeartbeatRequest createHeartbeatRequest( + RefreshableWork work, DataflowExecutionStateSampler sampler) { + return Windmill.HeartbeatRequest.newBuilder() + .setShardingKey(work.getShardedKey().shardingKey()) + .setWorkToken(work.id().workToken()) + .setCacheToken(work.id().cacheToken()) + .addAllLatencyAttribution(work.getHeartbeatLatencyAttributions(sampler)) + .build(); + } + abstract Builder setHeartbeatRequests( ImmutableListMultimap value); abstract ImmutableListMultimap.Builder heartbeatRequestsBuilder(); - public final Builder addHeartbeatRequest( + private void addHeartbeatRequest( String computationId, Windmill.HeartbeatRequest heartbeatRequest) { heartbeatRequestsBuilder().put(computationId, heartbeatRequest); - return this; } public abstract Heartbeats build(); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java index 5406a72927393..b3f7467cdbd34 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/FakeWindmillServer.java @@ -239,11 +239,6 @@ public String backendWorkerToken() { @Override public void shutdown() {} - @Override - public boolean isShutdown() { - return false; - } - @Override public void halfClose() { done.countDown(); @@ -315,11 +310,6 @@ public String backendWorkerToken() { @Override public void shutdown() {} - @Override - public boolean isShutdown() { - return false; - } - @Override public Windmill.KeyedGetDataResponse requestKeyedData( String computation, KeyedGetDataRequest request) { @@ -401,11 +391,6 @@ public String backendWorkerToken() { @Override public void shutdown() {} - @Override - public boolean isShutdown() { - return false; - } - @Override public RequestBatcher batcher() { return new RequestBatcher() { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index be1e1278a767b..0d2eb29975508 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -126,7 +126,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.Timer.Type; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WatermarkHold; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItemCommitRequest; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; @@ -316,20 +316,6 @@ private static ExecutableWork createMockWork( return createMockWork(shardedKey, workToken, computationId, ignored -> {}); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public KeyedGetDataResponse getStateData(String computation, KeyedGetDataRequest request) { - return KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; - } - private static ExecutableWork createMockWork( ShardedKey shardedKey, long workToken, Consumer processWorkFn) { return createMockWork(shardedKey, workToken, "computationId", processWorkFn); @@ -346,10 +332,7 @@ private static ExecutableWork createMockWork( .build(), Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( - computationId, - createMockGetDataClient(), - ignored -> {}, - mock(HeartbeatSender.class)), + computationId, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()), processWorkFn); @@ -3422,7 +3405,7 @@ public void testLatencyAttributionProtobufsPopulated() { Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( "computationId", - createMockGetDataClient(), + new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), clock, diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 9f8e4c2dfc140..2bd6621dd4f44 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -61,7 +61,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.streaming.sideinput.SideInputStateFetcher; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -135,26 +135,11 @@ private static Work createMockWork(Windmill.WorkItem workItem, Watermarks waterm workItem, watermarks, Work.createProcessingContext( - COMPUTATION_ID, createMockGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), + COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; - } - @Test public void testTimerInternalsSetTimer() { Windmill.WorkItemCommitRequest.Builder outputBuilder = diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java index 504b50daa3dce..98302c512256c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WorkerCustomSourcesTest.java @@ -95,7 +95,7 @@ import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader; import org.apache.beam.runners.dataflow.worker.util.common.worker.NativeReader.NativeReaderIterator; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -199,26 +199,11 @@ private static Work createMockWork(Windmill.WorkItem workItem, Watermarks waterm workItem, watermarks, Work.createProcessingContext( - COMPUTATION_ID, createMockGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), + COMPUTATION_ID, new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), Instant::now, Collections.emptyList()); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; - } - private static class SourceProducingSubSourcesInSplit extends MockSource { int numDesiredBundle; int sourceObjectSize; @@ -1014,7 +999,7 @@ public void testFailedWorkItemsAbort() throws Exception { Watermarks.builder().setInputDataWatermark(new Instant(0)).build(), Work.createProcessingContext( COMPUTATION_ID, - createMockGetDataClient(), + new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), Instant::now, diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java index 663edcbcdb75d..a373dffd1dc47 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ActiveWorkStateTest.java @@ -32,7 +32,7 @@ import java.util.Optional; import org.apache.beam.runners.dataflow.worker.streaming.ActiveWorkState.ActivateWorkResult; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -82,22 +82,7 @@ private static ExecutableWork expiredWork(Windmill.WorkItem workItem) { private static Work.ProcessingContext createWorkProcessingContext() { return Work.createProcessingContext( - "computationId", createMockGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)); - } - - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; + "computationId", new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)); } private static WorkId workId(long workToken, long cacheToken) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java index 658f12cf70ee0..1f70c24763255 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationStateCacheTest.java @@ -36,7 +36,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.config.ComputationConfig; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; @@ -70,7 +70,7 @@ private static ExecutableWork createWork(ShardedKey shardedKey, long workToken, Watermarks.builder().setInputDataWatermark(Instant.now()).build(), Work.createProcessingContext( "computationId", - createMockGetDataClient(), + new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), Instant::now, @@ -78,21 +78,6 @@ private static ExecutableWork createWork(ShardedKey shardedKey, long workToken, ignored -> {}); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; - } - @Before public void setUp() { computationStateCache = diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java index ef73d4b0ef27d..ad77958837a12 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/util/BoundedQueueExecutorTest.java @@ -32,7 +32,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Watermarks; import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -68,7 +68,7 @@ private static ExecutableWork createWork(Consumer executeWorkFn) { Watermarks.builder().setInputDataWatermark(Instant.now()).build(), Work.createProcessingContext( "computationId", - createMockGetDataClient(), + new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), Instant::now, @@ -76,34 +76,17 @@ private static ExecutableWork createWork(Consumer executeWorkFn) { executeWorkFn); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); + private Runnable createSleepProcessWorkFn(CountDownLatch start, CountDownLatch stop) { + return () -> { + start.countDown(); + try { + stop.await(); + } catch (Exception e) { + throw new RuntimeException(e); } }; } - private Runnable createSleepProcessWorkFn(CountDownLatch start, CountDownLatch stop) { - Runnable runnable = - () -> { - start.countDown(); - try { - stop.await(); - } catch (Exception e) { - throw new RuntimeException(e); - } - }; - return runnable; - } - @Before public void setUp() { this.executor = diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java index 7e5b350b48323..bdad382c9af22 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/WindmillStreamPoolTest.java @@ -260,15 +260,5 @@ public String backendWorkerToken() { public void shutdown() { halfClose(); } - - @Override - public boolean isShutdown() { - return closed; - } - - @Override - public Type streamType() { - return Type.GET_DATA; - } } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java index 37ab2c863c79b..51cd83d17fabf 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingApplianceWorkCommitterTest.java @@ -36,7 +36,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -66,7 +66,7 @@ private static Work createMockWork(long workToken) { Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( "computationId", - createMockGetDataClient(), + new FakeGetDataClient(), ignored -> { throw new UnsupportedOperationException(); }, @@ -75,21 +75,6 @@ private static Work createMockWork(long workToken) { Collections.emptyList()); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; - } - private static ComputationState createComputationState(String computationId) { return new ComputationState( computationId, diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java index 36d48d778e8cc..546a2883e3b20 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/commits/StreamingEngineWorkCommitterTest.java @@ -50,13 +50,12 @@ import org.apache.beam.runners.dataflow.worker.windmill.client.CloseableStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStream.CommitWorkStream; import org.apache.beam.runners.dataflow.worker.windmill.client.WindmillStreamPool; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.joda.time.Duration; import org.joda.time.Instant; -import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -68,7 +67,7 @@ public class StreamingEngineWorkCommitterTest { @Rule public ErrorCollector errorCollector = new ErrorCollector(); - private StreamingEngineWorkCommitter workCommitter; + private WorkCommitter workCommitter; private FakeWindmillServer fakeWindmillServer; private Supplier> commitWorkStreamFactory; @@ -83,7 +82,7 @@ private static Work createMockWork(long workToken) { Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( "computationId", - createMockGetDataClient(), + new FakeGetDataClient(), ignored -> { throw new UnsupportedOperationException(); }, @@ -92,21 +91,6 @@ private static Work createMockWork(long workToken) { Collections.emptyList()); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; - } - private static ComputationState createComputationState(String computationId) { return new ComputationState( computationId, @@ -135,14 +119,11 @@ public void setUp() throws IOException { ::getCloseableStream; } - @After - public void cleanUp() { - workCommitter.stop(); - } - - private StreamingEngineWorkCommitter createWorkCommitter( - Consumer onCommitComplete) { - return StreamingEngineWorkCommitter.create(commitWorkStreamFactory, 1, onCommitComplete); + private WorkCommitter createWorkCommitter(Consumer onCommitComplete) { + return StreamingEngineWorkCommitter.builder() + .setCommitWorkStreamFactory(commitWorkStreamFactory) + .setOnCommitComplete(onCommitComplete) + .build(); } @Test @@ -174,6 +155,8 @@ public void testCommit_sendsCommitsToStreamingEngine() { assertThat(request).isEqualTo(commit.request()); assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); } + + workCommitter.stop(); } @Test @@ -214,6 +197,8 @@ public void testCommit_handlesFailedCommits() { .containsEntry(commit.work().getWorkItem().getWorkToken(), commit.request()); } } + + workCommitter.stop(); } @Test @@ -266,6 +251,8 @@ public void testCommit_handlesCompleteCommits_commitStatusNotOK() { .contains(asCompleteCommit(commit, expectedCommitStatus.get(commit.work().id()))); } assertThat(completeCommits.size()).isEqualTo(commits.size()); + + workCommitter.stop(); } @Test @@ -310,11 +297,6 @@ public String backendWorkerToken() { @Override public void shutdown() {} - - @Override - public boolean isShutdown() { - return false; - } }; commitWorkStreamFactory = @@ -359,7 +341,12 @@ public void testMultipleCommitSendersSingleStream() { ::getCloseableStream; Set completeCommits = Collections.newSetFromMap(new ConcurrentHashMap<>()); workCommitter = - StreamingEngineWorkCommitter.create(commitWorkStreamFactory, 5, completeCommits::add); + StreamingEngineWorkCommitter.builder() + .setCommitWorkStreamFactory(commitWorkStreamFactory) + .setNumCommitSenders(5) + .setOnCommitComplete(completeCommits::add) + .build(); + List commits = new ArrayList<>(); for (int i = 1; i <= 500; i++) { Work work = createMockWork(i); @@ -384,5 +371,7 @@ public void testMultipleCommitSendersSingleStream() { assertThat(request).isEqualTo(commit.request()); assertThat(completeCommits).contains(asCompleteCommit(commit, Windmill.CommitStatus.OK)); } + + workCommitter.stop(); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/WorkRefreshClient.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FakeGetDataClient.java similarity index 56% rename from runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/WorkRefreshClient.java rename to runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FakeGetDataClient.java index 76f6147b07434..ca89e9647153d 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/WorkRefreshClient.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/FakeGetDataClient.java @@ -17,11 +17,23 @@ */ package org.apache.beam.runners.dataflow.worker.windmill.client.getdata; -import java.util.Map; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; -import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.Heartbeats; +import java.io.PrintWriter; +import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -/** Client for requesting work refresh via heartbeats. */ -public interface WorkRefreshClient { - void refreshActiveWork(Map heartbeats); +/** Fake {@link GetDataClient} implementation for testing. */ +public final class FakeGetDataClient implements GetDataClient { + @Override + public Windmill.KeyedGetDataResponse getStateData( + String computationId, Windmill.KeyedGetDataRequest request) throws GetDataException { + return Windmill.KeyedGetDataResponse.getDefaultInstance(); + } + + @Override + public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) + throws GetDataException { + return Windmill.GlobalData.getDefaultInstance(); + } + + @Override + public void printHtml(PrintWriter writer) {} } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTrackerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTrackerTest.java index b19e7f06896cb..d687434edff43 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTrackerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/client/getdata/ThrottlingGetDataMetricTrackerTest.java @@ -20,19 +20,15 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.assertFalse; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; -import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import org.apache.beam.runners.dataflow.worker.util.MemoryMonitor; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker.Type; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -47,15 +43,14 @@ public class ThrottlingGetDataMetricTrackerTest { private final ExecutorService getDataProcessor = Executors.newCachedThreadPool(); @Test - public void testTrackSingleCallWithThrottling_STATE() throws InterruptedException { - doNothing().when(memoryMonitor).waitForResources(eq(Type.STATE.debugName())); + public void testTrackFetchStateDataWithThrottling() throws InterruptedException { + doNothing().when(memoryMonitor).waitForResources(anyString()); CountDownLatch processCall = new CountDownLatch(1); CountDownLatch callProcessing = new CountDownLatch(1); CountDownLatch processingDone = new CountDownLatch(1); getDataProcessor.submit( () -> { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling(Type.STATE)) { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { callProcessing.countDown(); processCall.await(); } catch (Exception e) { @@ -65,7 +60,7 @@ public void testTrackSingleCallWithThrottling_STATE() throws InterruptedExceptio }); callProcessing.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsWhileProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(1); @@ -76,7 +71,7 @@ public void testTrackSingleCallWithThrottling_STATE() throws InterruptedExceptio // decremented processCall.countDown(); processingDone.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsAfterProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); @@ -84,15 +79,14 @@ public void testTrackSingleCallWithThrottling_STATE() throws InterruptedExceptio } @Test - public void testTrackSingleCallWithThrottling_SIDE_INPUT() throws InterruptedException { - doNothing().when(memoryMonitor).waitForResources(eq(Type.SIDE_INPUT.debugName())); + public void testTrackSideInputFetchWithThrottling() throws InterruptedException { + doNothing().when(memoryMonitor).waitForResources(anyString()); CountDownLatch processCall = new CountDownLatch(1); CountDownLatch callProcessing = new CountDownLatch(1); CountDownLatch processingDone = new CountDownLatch(1); getDataProcessor.submit( () -> { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling(Type.SIDE_INPUT)) { + try (AutoCloseable ignored = getDataMetricTracker.trackSideInputFetchWithThrottling()) { callProcessing.countDown(); processCall.await(); } catch (Exception e) { @@ -102,7 +96,7 @@ public void testTrackSingleCallWithThrottling_SIDE_INPUT() throws InterruptedExc }); callProcessing.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsWhileProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(0); @@ -113,96 +107,7 @@ public void testTrackSingleCallWithThrottling_SIDE_INPUT() throws InterruptedExc // decremented processCall.countDown(); processingDone.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsAfterProcessing = - getDataMetricTracker.getMetricsSnapshot(); - assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); - assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); - assertThat(metricsAfterProcessing.activeSideInputs()).isEqualTo(0); - } - - @Test - public void testTrackSingleCallWithThrottling_HEARTBEAT() throws InterruptedException { - doNothing().when(memoryMonitor).waitForResources(eq(Type.HEARTBEAT.debugName())); - CountDownLatch processCall = new CountDownLatch(1); - CountDownLatch callProcessing = new CountDownLatch(1); - CountDownLatch processingDone = new CountDownLatch(1); - getDataProcessor.submit( - () -> { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling(Type.HEARTBEAT)) { - callProcessing.countDown(); - processCall.await(); - } catch (Exception e) { - // Do nothing. - } - processingDone.countDown(); - }); - - callProcessing.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsWhileProcessing = - getDataMetricTracker.getMetricsSnapshot(); - - assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(0); - assertThat(metricsWhileProcessing.activeHeartbeats()).isEqualTo(1); - assertThat(metricsWhileProcessing.activeSideInputs()).isEqualTo(0); - - // Free the thread inside the AutoCloseable, wait for processingDone and check that metrics gets - // decremented - processCall.countDown(); - processingDone.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsAfterProcessing = - getDataMetricTracker.getMetricsSnapshot(); - assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); - assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); - assertThat(metricsAfterProcessing.activeSideInputs()).isEqualTo(0); - } - - @Test - public void testTrackSingleCall_multipleThreads() throws InterruptedException { - doNothing().when(memoryMonitor).waitForResources(anyString()); - // Issuing 5 calls (1 from each thread) - // 2 State Reads - // 2 SideInput Reads - // 1 Heartbeat - List callTypes = - Lists.newArrayList( - Type.STATE, Type.SIDE_INPUT, Type.STATE, Type.HEARTBEAT, Type.SIDE_INPUT); - CountDownLatch processCall = new CountDownLatch(callTypes.size()); - CountDownLatch callProcessing = new CountDownLatch(callTypes.size()); - CountDownLatch processingDone = new CountDownLatch(callTypes.size()); - for (Type callType : callTypes) { - getDataProcessor.submit( - () -> { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling(callType)) { - callProcessing.countDown(); - processCall.await(); - } catch (Exception e) { - // Do nothing. - } - processingDone.countDown(); - }); - } - - callProcessing.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsWhileProcessing = - getDataMetricTracker.getMetricsSnapshot(); - - // Asserting that metrics reflects: - // 2 State Reads - // 2 SideInput Reads - // 1 Heartbeat - assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(2); - assertThat(metricsWhileProcessing.activeSideInputs()).isEqualTo(2); - assertThat(metricsWhileProcessing.activeHeartbeats()).isEqualTo(1); - - // Free the thread inside the AutoCloseable, wait for processingDone and check that metrics gets - // decremented - for (int i = 0; i < callTypes.size(); i++) { - processCall.countDown(); - } - processingDone.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsAfterProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); @@ -217,8 +122,7 @@ public void testThrottledTrackSingleCallWithThrottling() throws InterruptedExcep CountDownLatch processingDone = new CountDownLatch(1); getDataProcessor.submit( () -> { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling(Type.STATE)) { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { callProcessing.countDown(); processCall.await(); } catch (Exception e) { @@ -228,7 +132,7 @@ public void testThrottledTrackSingleCallWithThrottling() throws InterruptedExcep }); assertFalse(callProcessing.await(10, TimeUnit.MILLISECONDS)); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsBeforeProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsBeforeProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsBeforeProcessing.activeStateReads()).isEqualTo(0); assertThat(metricsBeforeProcessing.activeHeartbeats()).isEqualTo(0); @@ -237,7 +141,7 @@ public void testThrottledTrackSingleCallWithThrottling() throws InterruptedExcep // Stop throttling. mockThrottler.countDown(); callProcessing.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsWhileProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(1); @@ -246,7 +150,7 @@ public void testThrottledTrackSingleCallWithThrottling() throws InterruptedExcep // decremented processCall.countDown(); processingDone.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsAfterProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); } @@ -263,8 +167,7 @@ public void testTrackSingleCall_exceptionThrown() throws InterruptedException { getDataProcessor.submit( () -> { try { - try (AutoCloseable ignored = - getDataMetricTracker.trackSingleCallWithThrottling(Type.STATE)) { + try (AutoCloseable ignored = getDataMetricTracker.trackStateDataFetchWithThrottling()) { callProcessing.countDown(); beforeException.await(); throw new RuntimeException("something bad happened"); @@ -277,7 +180,7 @@ public void testTrackSingleCall_exceptionThrown() throws InterruptedException { callProcessing.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsWhileProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsWhileProcessing.activeStateReads()).isEqualTo(1); @@ -285,7 +188,7 @@ public void testTrackSingleCall_exceptionThrown() throws InterruptedException { // In the midst of an exception, close() should still run. afterException.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsAfterProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsAfterProcessing.activeStateReads()).isEqualTo(0); } @@ -308,7 +211,7 @@ public void testTrackHeartbeats() throws InterruptedException { }); callProcessing.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsWhileProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsWhileProcessing.activeHeartbeats()).isEqualTo(5); @@ -317,7 +220,7 @@ public void testTrackHeartbeats() throws InterruptedException { // decremented processCall.countDown(); processingDone.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsAfterProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); } @@ -346,7 +249,7 @@ public void testTrackHeartbeats_exceptionThrown() throws InterruptedException { callProcessing.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsWhileProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsWhileProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsWhileProcessing.activeHeartbeats()).isEqualTo(numHeartbeats); @@ -354,7 +257,7 @@ public void testTrackHeartbeats_exceptionThrown() throws InterruptedException { // In the midst of an exception, close() should still run. afterException.await(); - ThrottlingGetDataMetricTracker.GetDataMetrics.ReadOnlySnapshot metricsAfterProcessing = + ThrottlingGetDataMetricTracker.ReadOnlySnapshot metricsAfterProcessing = getDataMetricTracker.getMetricsSnapshot(); assertThat(metricsAfterProcessing.activeHeartbeats()).isEqualTo(0); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java index ea90bb276a4bb..146b05bb7e35f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/processing/failures/WorkFailureProcessorTest.java @@ -35,7 +35,7 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; @@ -89,7 +89,7 @@ private static ExecutableWork createWork(Supplier clock, Consumer Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( "computationId", - createMockGetDataClient(), + new FakeGetDataClient(), ignored -> {}, mock(HeartbeatSender.class)), clock, @@ -97,21 +97,6 @@ private static ExecutableWork createWork(Supplier clock, Consumer processWorkFn); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; - } - private static ExecutableWork createWork(Consumer processWorkFn) { return createWork(Instant::now, processWorkFn); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java index dbd5959293167..9dce3392c60c5 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/windmill/work/refresh/ActiveWorkRefresherTest.java @@ -47,13 +47,12 @@ import org.apache.beam.runners.dataflow.worker.streaming.Work; import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; -import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient; +import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.FakeGetDataClient; import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache; import org.apache.beam.runners.direct.Clock; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ThreadFactoryBuilder; import org.joda.time.Duration; @@ -61,6 +60,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; @RunWith(JUnit4.class) public class ActiveWorkRefresherTest { @@ -94,27 +94,12 @@ private static ComputationState createComputationState( stateCache); } - private static GetDataClient createMockGetDataClient() { - return new GetDataClient() { - @Override - public Windmill.KeyedGetDataResponse getStateData( - String computation, Windmill.KeyedGetDataRequest request) { - return Windmill.KeyedGetDataResponse.getDefaultInstance(); - } - - @Override - public Windmill.GlobalData getSideInputData(Windmill.GlobalDataRequest request) { - return Windmill.GlobalData.getDefaultInstance(); - } - }; - } - private ActiveWorkRefresher createActiveWorkRefresher( Supplier clock, int activeWorkRefreshPeriodMillis, int stuckCommitDurationMillis, Supplier> computations, - Consumer> activeWorkRefresherFn) { + ActiveWorkRefresher.HeartbeatTracker heartbeatTracker) { return new ActiveWorkRefresher( clock, activeWorkRefreshPeriodMillis, @@ -122,7 +107,7 @@ private ActiveWorkRefresher createActiveWorkRefresher( computations, DataflowExecutionStateSampler.instance(), Executors.newSingleThreadScheduledExecutor(), - activeWorkRefresherFn); + heartbeatTracker); } private ExecutableWork createOldWork(int workIds, Consumer processWork) { @@ -142,7 +127,7 @@ private ExecutableWork createOldWork( .build(), Watermarks.builder().setInputDataWatermark(Instant.EPOCH).build(), Work.createProcessingContext( - "computationId", createMockGetDataClient(), ignored -> {}, heartbeatSender), + "computationId", new FakeGetDataClient(), ignored -> {}, heartbeatSender), A_LONG_TIME_AGO, ImmutableList.of()), processWork); @@ -177,7 +162,6 @@ public void testActiveWorkRefresh() throws InterruptedException { activeWorkForComputation.add(fakeWork); } - Map fanoutExpectedHeartbeats = new HashMap<>(); CountDownLatch heartbeatsSent = new CountDownLatch(1); TestClock fakeClock = new TestClock(Instant.now()); ActiveWorkRefresher activeWorkRefresher = @@ -186,40 +170,36 @@ public void testActiveWorkRefresh() throws InterruptedException { activeWorkRefreshPeriodMillis, 0, () -> computations, - heartbeats -> { - fanoutExpectedHeartbeats.putAll(heartbeats); - heartbeatsSent.countDown(); - }); + heartbeats -> heartbeatsSent::countDown); + ArgumentCaptor heartbeatsCaptor = ArgumentCaptor.forClass(Heartbeats.class); activeWorkRefresher.start(); fakeClock.advance(Duration.millis(activeWorkRefreshPeriodMillis * 2)); heartbeatsSent.await(); activeWorkRefresher.stop(); - + verify(heartbeatSender).sendHeartbeats(heartbeatsCaptor.capture()); + Heartbeats fanoutExpectedHeartbeats = heartbeatsCaptor.getValue(); assertThat(computationsAndWork.size()) - .isEqualTo( - Iterables.getOnlyElement(fanoutExpectedHeartbeats.values()).heartbeatRequests().size()); - for (Map.Entry fanOutExpectedHeartbeat : - fanoutExpectedHeartbeats.entrySet()) { - for (Map.Entry> expectedHeartbeat : - fanOutExpectedHeartbeat.getValue().heartbeatRequests().asMap().entrySet()) { - String computationId = expectedHeartbeat.getKey(); - Collection heartbeatRequests = expectedHeartbeat.getValue(); - List work = - computationsAndWork.get(computationId).stream() - .map(ExecutableWork::work) - .collect(Collectors.toList()); - // Compare the heartbeatRequest's and Work's workTokens, cacheTokens, and shardingKeys. - assertThat(heartbeatRequests) - .comparingElementsUsing( - Correspondence.from( - (Windmill.HeartbeatRequest h, Work w) -> - h.getWorkToken() == w.getWorkItem().getWorkToken() - && h.getCacheToken() == w.getWorkItem().getWorkToken() - && h.getShardingKey() == w.getWorkItem().getShardingKey(), - "heartbeatRequest's and Work's workTokens, cacheTokens, and shardingKeys should be equal.")) - .containsExactlyElementsIn(work); - } + .isEqualTo(fanoutExpectedHeartbeats.heartbeatRequests().size()); + + for (Map.Entry> expectedHeartbeat : + fanoutExpectedHeartbeats.heartbeatRequests().asMap().entrySet()) { + String computationId = expectedHeartbeat.getKey(); + Collection heartbeatRequests = expectedHeartbeat.getValue(); + List work = + computationsAndWork.get(computationId).stream() + .map(ExecutableWork::work) + .collect(Collectors.toList()); + // Compare the heartbeatRequest's and Work's workTokens, cacheTokens, and shardingKeys. + assertThat(heartbeatRequests) + .comparingElementsUsing( + Correspondence.from( + (Windmill.HeartbeatRequest h, Work w) -> + h.getWorkToken() == w.getWorkItem().getWorkToken() + && h.getCacheToken() == w.getWorkItem().getWorkToken() + && h.getShardingKey() == w.getWorkItem().getShardingKey(), + "heartbeatRequest's and Work's workTokens, cacheTokens, and shardingKeys should be equal.")) + .containsExactlyElementsIn(work); } activeWorkRefresher.stop(); @@ -265,7 +245,7 @@ public void testInvalidateStuckCommits() throws InterruptedException { 0, stuckCommitDurationMillis, computations.rowMap()::keySet, - ignored -> {}); + ignored -> () -> {}); activeWorkRefresher.start(); fakeClock.advance(Duration.millis(stuckCommitDurationMillis));