Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Jul 27, 2024
1 parent 3e7f124 commit d5acf0f
Show file tree
Hide file tree
Showing 39 changed files with 483 additions and 838 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ApplianceGetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.GetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamingEngineGetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.StreamPoolGetDataClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.ThrottlingGetDataMetricTracker;
import org.apache.beam.runners.dataflow.worker.windmill.client.getdata.WorkRefreshClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.ChannelzServlet;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcDispatcherClient;
import org.apache.beam.runners.dataflow.worker.windmill.client.grpc.GrpcWindmillServer;
Expand Down Expand Up @@ -217,12 +216,16 @@ private StreamingDataflowWorker(

this.workCommitter =
windmillServiceEnabled
? StreamingEngineWorkCommitter.create(
WindmillStreamPool.create(
numCommitThreads, COMMIT_STREAM_TIMEOUT, windmillServer::commitWorkStream)
::getCloseableStream,
numCommitThreads,
this::onCompleteCommit)
? StreamingEngineWorkCommitter.builder()
.setCommitWorkStreamFactory(
WindmillStreamPool.create(
numCommitThreads,
COMMIT_STREAM_TIMEOUT,
windmillServer::commitWorkStream)
::getCloseableStream)
.setNumCommitSenders(numCommitThreads)
.setOnCommitComplete(this::onCompleteCommit)
.build()
: StreamingApplianceWorkCommitter.create(
windmillServer::commitWork, this::onCompleteCommit);

Expand Down Expand Up @@ -252,31 +255,26 @@ private StreamingDataflowWorker(
ThrottlingGetDataMetricTracker getDataMetricTracker =
new ThrottlingGetDataMetricTracker(memoryMonitor);

WindmillStreamPool<GetDataStream> getDataStreamPool =
WindmillStreamPool.create(
Math.max(1, options.getWindmillGetDataStreamCount()),
GET_DATA_STREAM_TIMEOUT,
windmillServer::getDataStream);

// Register standard file systems.
FileSystems.setDefaultPipelineOptions(options);

int stuckCommitDurationMillis =
windmillServiceEnabled && options.getStuckCommitDurationMillis() > 0
? options.getStuckCommitDurationMillis()
: 0;

WorkRefreshClient workRefreshClient;
int stuckCommitDurationMillis;
if (windmillServiceEnabled) {
StreamingEngineGetDataClient streamingEngineGetDataClient =
new StreamingEngineGetDataClient(getDataMetricTracker, getDataStreamPool);
this.getDataClient = streamingEngineGetDataClient;
workRefreshClient = streamingEngineGetDataClient;
WindmillStreamPool<GetDataStream> getDataStreamPool =
WindmillStreamPool.create(
Math.max(1, options.getWindmillGetDataStreamCount()),
GET_DATA_STREAM_TIMEOUT,
windmillServer::getDataStream);
this.getDataClient = new StreamPoolGetDataClient(getDataMetricTracker, getDataStreamPool);
this.heartbeatSender =
new StreamPoolHeartbeatSender(
options.getUseSeparateWindmillHeartbeatStreams()
? WindmillStreamPool.create(
1, GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream)
: getDataStreamPool);
stuckCommitDurationMillis =
options.getStuckCommitDurationMillis() > 0 ? options.getStuckCommitDurationMillis() : 0;
} else {
ApplianceGetDataClient applianceGetDataClient =
new ApplianceGetDataClient(windmillServer, getDataMetricTracker);
this.getDataClient = applianceGetDataClient;
workRefreshClient = applianceGetDataClient;
this.getDataClient = new ApplianceGetDataClient(windmillServer, getDataMetricTracker);
this.heartbeatSender = new ApplianceHeartbeatSender(windmillServer::getData);
stuckCommitDurationMillis = 0;
}

this.activeWorkRefresher =
Expand All @@ -287,7 +285,7 @@ private StreamingDataflowWorker(
computationStateCache::getAllPresentComputations,
sampler,
executorSupplier.apply("RefreshWork"),
workRefreshClient::refreshActiveWork);
getDataMetricTracker::trackHeartbeats);

WorkerStatusPages workerStatusPages =
WorkerStatusPages.create(DEFAULT_STATUS_PORT, memoryMonitor);
Expand Down Expand Up @@ -333,14 +331,8 @@ private StreamingDataflowWorker(
ID_GENERATOR,
stageInfoMap);

this.heartbeatSender =
options.isEnableStreamingEngine()
? new StreamPoolHeartbeatSender(
options.getUseSeparateWindmillHeartbeatStreams()
? WindmillStreamPool.create(
1, GET_DATA_STREAM_TIMEOUT, windmillServer::getDataStream)
: getDataStreamPool)
: new ApplianceHeartbeatSender(windmillServer::getData);
// Register standard file systems.
FileSystems.setDefaultPipelineOptions(options);

LOG.debug("windmillServiceEnabled: {}", windmillServiceEnabled);
LOG.debug("WindmillServiceEndpoint: {}", options.getWindmillServiceEndpoint());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ public WorkItemCancelledException(long sharding_key) {
super("Work item cancelled for key " + sharding_key);
}

/**
 * Creates an exception from another throwable.
 *
 * @param e throwable handed unchanged to the superclass constructor; presumably it is recorded
 *     as this exception's cause (the superclass is declared outside this excerpt -- confirm)
 */
public WorkItemCancelledException(Throwable e) {
super(e);
}

/** Returns whether an exception was caused by a {@link WorkItemCancelledException}. */
public static boolean isWorkItemCancelledException(Throwable t) {
while (t != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,14 +206,17 @@ synchronized ImmutableListMultimap<ShardedKey, RefreshableWork> getReadOnlyActiv
.collect(
flatteningToImmutableListMultimap(
Entry::getKey,
e -> e.getValue().stream().map(ExecutableWork::work).map(Work::refreshableView)));
e ->
e.getValue().stream()
.map(ExecutableWork::work)
.map(work -> (RefreshableWork) work)));
}

synchronized ImmutableList<RefreshableWork> getRefreshableWork(Instant refreshDeadline) {
return activeWork.values().stream()
.flatMap(Deque::stream)
.map(ExecutableWork::work)
.filter(work -> work.isRefreshable(refreshDeadline))
.filter(work -> !work.isFailed() && work.getStartTime().isBefore(refreshDeadline))
.collect(toImmutableList());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Instant;

/** View of {@link Work} that exposes an interface for work refreshing. */
@Internal
Expand All @@ -32,8 +31,6 @@ public interface RefreshableWork {

ShardedKey getShardedKey();

boolean isRefreshable(Instant refreshDeadline);

HeartbeatSender heartbeatSender();

ImmutableList<Windmill.LatencyAttribution> getHeartbeatLatencyAttributions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,6 @@ private static LatencyAttribution.Builder createLatencyAttributionWithActiveLate
return latencyAttribution;
}

public RefreshableWork refreshableView() {
return this;
}

/** Returns the value of {@code workItem}. */
public WorkItem getWorkItem() {
return workItem;
}
Expand Down Expand Up @@ -209,11 +205,6 @@ public String getLatencyTrackingId() {
return latencyTrackingId;
}

@Override
public boolean isRefreshable(Instant refreshDeadline) {
return !isFailed && getStartTime().isBefore(refreshDeadline);
}

@Override
public HeartbeatSender heartbeatSender() {
return processingContext.heartbeatSender();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalData;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.GlobalDataRequest;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
Expand All @@ -52,6 +53,7 @@

/** Class responsible for fetching side input state from the streaming backend. */
@NotThreadSafe
@Internal
public class SideInputStateFetcher {
private static final Logger LOG = LoggerFactory.getLogger(SideInputStateFetcher.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ public abstract class AbstractWindmillStream<RequestT, ResponseT> implements Win
private final Supplier<StreamObserver<RequestT>> requestObserverSupplier;
// Indicates if the current stream in requestObserver is closed by calling close() method
private final AtomicBoolean streamClosed;
private @Nullable StreamObserver<RequestT> requestObserver;
private final String backendWorkerToken;
private @Nullable StreamObserver<RequestT> requestObserver;

protected AbstractWindmillStream(
String debugStreamType,
Function<StreamObserver<ResponseT>, StreamObserver<RequestT>> clientFactory,
BackOff backoff,
StreamObserverFactory streamObserverFactory,
Expand All @@ -100,7 +101,7 @@ protected AbstractWindmillStream(
Executors.newSingleThreadExecutor(
new ThreadFactoryBuilder()
.setDaemon(true)
.setNameFormat(createThreadName(streamType(), backendWorkerToken))
.setNameFormat(createThreadName(debugStreamType, backendWorkerToken))
.build());
this.backoff = backoff;
this.streamRegistry = streamRegistry;
Expand All @@ -122,10 +123,10 @@ protected AbstractWindmillStream(
clientFactory, new AbstractWindmillStream<RequestT, ResponseT>.ResponseObserver());
}

private static String createThreadName(Type streamType, String backendWorkerToken) {
private static String createThreadName(String streamType, String backendWorkerToken) {
return !backendWorkerToken.isEmpty()
? String.format("%s-%s-WindmillStream-thread", streamType.name(), backendWorkerToken)
: String.format("%s-WindmillStream-thread", streamType.name());
? String.format("%s-%s-WindmillStream-thread", streamType, backendWorkerToken)
: String.format("%s-WindmillStream-thread", streamType);
}

private static long debugDuration(long nowMs, long startMs) {
Expand All @@ -151,6 +152,11 @@ private static long debugDuration(long nowMs, long startMs) {
*/
protected abstract void startThrottleTimer();

/**
 * Reflects that {@link #shutdown()} was explicitly called.
 *
 * @return the current value of the {@code isShutdown} flag, which {@link #shutdown()} flips from
 *     {@code false} to {@code true}
 */
protected boolean isShutdown() {
return isShutdown.get();
}

private StreamObserver<RequestT> requestObserver() {
if (requestObserver == null) {
throw new NullPointerException(
Expand Down Expand Up @@ -274,15 +280,11 @@ public String backendWorkerToken() {
@Override
public void shutdown() {
if (isShutdown.compareAndSet(false, true)) {
halfClose();
requestObserver()
.onError(new WindmillStreamShutdownException("Explicit call to shutdown stream."));
}
}

@Override
public boolean isShutdown() {
return isShutdown.get();
}

private void setLastError(String error) {
lastError.set(error);
lastErrorTime.set(DateTime.now());
Expand Down Expand Up @@ -313,7 +315,7 @@ public void onCompleted() {

private void onStreamFinished(@Nullable Throwable t) {
synchronized (this) {
if (clientClosed.get() && !hasPendingRequests()) {
if (isShutdown.get() || (clientClosed.get() && !hasPendingRequests())) {
streamRegistry.remove(AbstractWindmillStream.this);
finishLatch.countDown();
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,6 @@ public interface WindmillStream {
*/
void shutdown();

/** Reflects that {@link #shutdown()} was explicitly called. */
boolean isShutdown();

Type streamType();

enum Type {
GET_WORKER_METADATA,
GET_WORK,
GET_DATA,
COMMIT_WORK,
}

/** Handle representing a stream of GetWork responses. */
@ThreadSafe
interface GetWorkStream extends WindmillStream {
Expand All @@ -72,11 +60,6 @@ interface GetWorkStream extends WindmillStream {

/** Returns the remaining in-flight {@link GetWorkBudget}. */
GetWorkBudget remainingBudget();

@Override
default Type streamType() {
return Type.GET_WORK;
}
}

/** Interface for streaming GetDataRequests to Windmill. */
Expand All @@ -93,11 +76,6 @@ Windmill.KeyedGetDataResponse requestKeyedData(
void refreshActiveWork(Map<String, Collection<HeartbeatRequest>> heartbeats);

void onHeartbeatResponse(List<Windmill.ComputationHeartbeatResponse> responses);

@Override
default Type streamType() {
return Type.GET_DATA;
}
}

/** Interface for streaming CommitWorkRequests to Windmill. */
Expand All @@ -109,11 +87,6 @@ interface CommitWorkStream extends WindmillStream {
*/
CommitWorkStream.RequestBatcher batcher();

@Override
default Type streamType() {
return Type.COMMIT_WORK;
}

@NotThreadSafe
interface RequestBatcher extends Closeable {
/**
Expand All @@ -140,10 +113,11 @@ default void close() {

/** Interface for streaming GetWorkerMetadata requests to Windmill. */
@ThreadSafe
interface GetWorkerMetadataStream extends WindmillStream {
@Override
default Type streamType() {
return Type.GET_WORKER_METADATA;
interface GetWorkerMetadataStream extends WindmillStream {}

/**
 * Unchecked exception describing a stream shutdown.
 *
 * <p>NOTE(review): presumably used to signal that a stream was shut down explicitly rather than
 * failing on its own; confirm against the call sites.
 */
class WindmillStreamShutdownException extends RuntimeException {
/**
 * Creates the exception with a detail message.
 *
 * @param message detail message passed to {@link RuntimeException#RuntimeException(String)}
 */
public WindmillStreamShutdownException(String message) {
super(message);
}
}
}
Loading

0 comments on commit d5acf0f

Please sign in to comment.