Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
m-trieu committed Jun 18, 2024
1 parent 09a8345 commit 2ba55f7
Show file tree
Hide file tree
Showing 19 changed files with 173 additions and 178 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ public Windmill.KeyedGetDataResponse getStateData(
activeStateReads.getAndIncrement();
return getDataStream.requestKeyedData(computation, request);
} catch (Exception e) {
if (WindmillStreamClosedException.isWindmillStreamCancelledException(e)) {
if (WindmillStreamClosedException.wasCauseOf(e)) {
LOG.error("Tried to fetch keyed data from a closed stream. Work has been cancelled", e);
throw new WorkItemCancelledException(request.getShardingKey());
}
Expand Down Expand Up @@ -285,16 +285,16 @@ public Windmill.GlobalData getSideInputData(
try {
return getDataStream.requestGlobalData(request);
} catch (Exception e) {
if (WindmillStreamClosedException.wasCauseOf(e)) {
LOG.error("Tried to fetch global data from a closed stream. Work has been cancelled", e);
throw new WorkItemCancelledException("Failed to get side input.", e);
}
throw new RuntimeException("Failed to get side input: ", e);
} finally {
activeSideInputs.getAndDecrement();
}
}

/** Returns the pool of {@link GetDataStream}s used by this client to fetch data from Windmill. */
public WindmillStreamPool<GetDataStream> getGetDataStreamPool() {
return getDataStreamPool;
}

/**
* Attempts to refresh active work, fanning out to each {@link GetDataStream} in parallel.
*
Expand All @@ -307,16 +307,14 @@ public void refreshActiveWork(
}

try {
// There is 1 destination to send heartbeat requests.
if (heartbeats.size() == 1) {
// There is 1 destination to send heartbeat requests.
Map.Entry<HeartbeatSender, Map<String, List<HeartbeatRequest>>> heartbeat =
Iterables.getOnlyElement(heartbeats.entrySet());
HeartbeatSender sender = heartbeat.getKey();
sender.sendHeartbeats(heartbeat.getValue());
}

// There are multiple destinations to send heartbeat requests. Fan out requests in parallel.
else {
} else {
// There are multiple destinations to send heartbeat requests. Fan out requests in parallel.
refreshActiveWorkWithFanOut(heartbeats);
}
} finally {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ public WorkItemCancelledException(long shardingKey) {
super("Work item cancelled for key " + shardingKey);
}

/**
 * Creates a {@link WorkItemCancelledException} with a detail message and the underlying cause.
 *
 * @param message detail message describing why the work item was cancelled
 * @param t the exception that triggered the cancellation; preserved as the cause
 */
public WorkItemCancelledException(String message, Throwable t) {
super(message, t);
}

/** Returns whether an exception was caused by a {@link WorkItemCancelledException}. */
public static boolean isWorkItemCancelledException(@Nullable Throwable t) {
while (t != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;
import javax.annotation.concurrent.ThreadSafe;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill.WorkItem;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
Expand Down Expand Up @@ -175,17 +174,7 @@ synchronized void failWorkForKey(Multimap<Long, WorkId> failedWork) {
WorkItem workItem = queuedWork.work().getWorkItem();
if (workItem.getWorkToken() == failedWorkId.workToken()
&& workItem.getCacheToken() == failedWorkId.cacheToken()) {
LOG.debug(
"Failing work "
+ computationStateCache.getComputation()
+ " "
+ entry.getKey().shardingKey()
+ " "
+ failedWorkId.workToken()
+ " "
+ failedWorkId.cacheToken()
+ ". The work will be retried and is not lost.");
queuedWork.work().setFailed();
queuedWork.work().fail();
break;
}
}
Expand Down Expand Up @@ -305,16 +294,12 @@ private synchronized ImmutableMap<ShardedKey, WorkId> getStuckCommitsAt(
* cause a {@link java.util.ConcurrentModificationException} as it is not a thread-safe data
* structure.
*/
synchronized ImmutableListMultimap<ShardedKey, Work.RefreshableView> getReadOnlyActiveWork(
DataflowExecutionStateSampler sampler) {
synchronized ImmutableListMultimap<ShardedKey, RefreshableWork> getReadOnlyActiveWork() {
return activeWork.entrySet().stream()
.collect(
flatteningToImmutableListMultimap(
Entry::getKey,
e ->
e.getValue().stream()
.map(ExecutableWork::work)
.map(work -> work.refreshableView(sampler))));
e -> e.getValue().stream().map(ExecutableWork::work).map(Work::refreshableView)));
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import java.util.Optional;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.annotation.Nullable;
import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
import org.apache.beam.runners.dataflow.worker.util.BoundedQueueExecutor;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateCache;
import org.apache.beam.runners.dataflow.worker.windmill.work.budget.GetWorkBudget;
Expand Down Expand Up @@ -138,9 +137,8 @@ public void invalidateStuckCommits(Instant stuckCommitDeadline) {
stuckCommitDeadline, this::completeWorkAndScheduleNextWorkForKey);
}

public ImmutableListMultimap<ShardedKey, Work.RefreshableView> currentActiveWorkReadOnly(
DataflowExecutionStateSampler sampler) {
return activeWorkState.getReadOnlyActiveWork(sampler);
public ImmutableListMultimap<ShardedKey, RefreshableWork> currentActiveWorkReadOnly() {
return activeWorkState.getReadOnlyActiveWork();
}

private void execute(ExecutableWork executableWork) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,25 @@
*/
package org.apache.beam.runners.dataflow.worker.streaming;

import org.apache.beam.runners.dataflow.worker.DataflowExecutionStateSampler;
import org.apache.beam.runners.dataflow.worker.windmill.Windmill;
import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Instant;

/** View of {@link Work} that exposes an interface for work refreshing. */
@Internal
public interface RefreshableWork {

  /** Returns the underlying {@link Windmill.WorkItem} this work wraps. */
  Windmill.WorkItem getWorkItem();

  /** Returns the {@link WorkId} identifying this work. */
  WorkId id();

  /** Returns true if this work is eligible for a heartbeat/refresh at the given deadline. */
  boolean isRefreshable(Instant refreshDeadline);

  /** Returns whether this work has been marked failed and should no longer be refreshed. */
  boolean isFailed();

  /** Returns the {@link HeartbeatSender} used to send heartbeats for this work. */
  HeartbeatSender heartbeatSender();

  // NOTE(review): isHeartbeat presumably toggles heartbeat-specific attribution fields —
  // confirm against the Work implementation before relying on it.
  /** Returns per-state latency attributions for this work, sampled via the given sampler. */
  ImmutableList<Windmill.LatencyAttribution> getLatencyAttributions(
      boolean isHeartbeat, DataflowExecutionStateSampler sampler);
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,14 @@
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.Commit;
import org.apache.beam.runners.dataflow.worker.windmill.client.commits.WorkCommitter;
import org.apache.beam.runners.dataflow.worker.windmill.state.WindmillStateReader;
import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.DirectHeartbeatSender;
import org.apache.beam.runners.dataflow.worker.windmill.work.refresh.HeartbeatSender;
import org.apache.beam.sdk.annotations.Internal;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Represents the state of an attempt to process a {@link WorkItem} by executing user code.
Expand All @@ -59,7 +62,9 @@
*/
@NotThreadSafe
@Internal
public final class Work {
public final class Work implements RefreshableWork {
private static final Logger LOG = LoggerFactory.getLogger(Work.class);

private final ShardedKey shardedKey;
private final WorkItem workItem;
private final ProcessingContext processingContext;
Expand All @@ -79,7 +84,6 @@ private Work(
Supplier<Instant> clock) {
this.shardedKey = ShardedKey.create(workItem.getKey(), workItem.getShardingKey());
this.workItem = workItem;
this.processingContext = processingContext;
this.watermarks = watermarks;
this.clock = clock;
this.startTime = clock.get();
Expand All @@ -91,6 +95,15 @@ private Work(
+ Long.toHexString(workItem.getWorkToken());
this.currentState = TimedState.initialState(startTime);
this.isFailed = false;
this.processingContext =
processingContext.heartbeatSender() instanceof DirectHeartbeatSender
? processingContext
.toBuilder()
.setHeartbeatSender(
((DirectHeartbeatSender) processingContext.heartbeatSender())
.withStreamClosedHandler(() -> isFailed = true))
.build()
: processingContext;
}

public static Work create(
Expand Down Expand Up @@ -182,21 +195,27 @@ public void setState(State state) {
this.currentState = TimedState.create(state, now);
}

private boolean isRefreshable(Instant refreshDeadline) {
boolean isRefreshable = getStartTime().isBefore(refreshDeadline);
if (heartbeatSender().isInvalid()) {
setFailed();
return false;
}

return isRefreshable;
/**
 * Returns whether this work is eligible for refresh: work is refreshable only while its start
 * time precedes the given deadline.
 */
@Override
public boolean isRefreshable(Instant refreshDeadline) {
  Instant startedAt = getStartTime();
  return startedAt.isBefore(refreshDeadline);
}

/** Returns the {@link HeartbeatSender} carried by this work's processing context. */
@Override
public HeartbeatSender heartbeatSender() {
return processingContext.heartbeatSender();
}

public void setFailed() {
/**
 * Marks this work as failed so it is not committed. The work will be retried by the backend and
 * is not lost.
 */
public void fail() {
  // Parameterized logging: avoids eagerly building the message string when debug is disabled.
  LOG.debug(
      "Failing work {} {} {} {}. The work will be retried and is not lost.",
      processingContext.computationId(),
      shardedKey,
      id.workToken(),
      id.cacheToken());
  this.isFailed = true;
}

Expand All @@ -221,6 +240,7 @@ public WindmillStateReader createWindmillStateReader() {
return WindmillStateReader.forWork(this);
}

/** Returns the {@link WorkId} identifying this work attempt. */
@Override
public WorkId id() {
return id;
}
Expand All @@ -232,6 +252,7 @@ private void recordGetWorkStreamLatencies(Collection<LatencyAttribution> getWork
}
}

@Override
public ImmutableList<LatencyAttribution> getLatencyAttributions(
boolean isHeartbeat, DataflowExecutionStateSampler sampler) {
return Arrays.stream(LatencyAttribution.State.values())
Expand Down Expand Up @@ -272,6 +293,7 @@ private LatencyAttribution createLatencyAttribution(
.build();
}

/** Returns whether this work has been marked failed (see {@code fail()}). */
@Override
public boolean isFailed() {
return isFailed;
}
Expand All @@ -281,15 +303,9 @@ boolean isStuckCommittingAt(Instant stuckCommitDeadline) {
&& currentState.startTime().isBefore(stuckCommitDeadline);
}

/** Returns a read-only snapshot of this {@link Work} instance's state for work refreshing. */
RefreshableView refreshableView(DataflowExecutionStateSampler sampler) {
return RefreshableView.builder()
.setWorkId(id)
.setHeartbeatSender(heartbeatSender())
.setIsFailed(isFailed)
.setIsRefreshable(this::isRefreshable)
.setLatencyAttributions(getLatencyAttributions(/* isHeartbeat= */ true, sampler))
.build();
/**
 * Returns a {@link RefreshableWork} view of this {@link Work} for work refreshing. Since {@link
 * Work} implements {@link RefreshableWork}, this instance is returned directly.
 */
public RefreshableWork refreshableView() {
return this;
}

public enum State {
Expand Down Expand Up @@ -344,11 +360,13 @@ private static ProcessingContext create(
BiFunction<String, KeyedGetDataRequest, KeyedGetDataResponse> getKeyedDataFn,
Consumer<Commit> workCommitter,
HeartbeatSender heartbeatSender) {
return new AutoValue_Work_ProcessingContext(
computationId,
request -> Optional.ofNullable(getKeyedDataFn.apply(computationId, request)),
workCommitter,
heartbeatSender);
return new AutoValue_Work_ProcessingContext.Builder()
.setComputationId(computationId)
.setHeartbeatSender(heartbeatSender)
.setWorkCommitter(workCommitter)
.setKeyedDataFetcher(
request -> Optional.ofNullable(getKeyedDataFn.apply(computationId, request)))
.build();
}

/** Computation that the {@link Work} belongs to. */
Expand All @@ -365,50 +383,21 @@ private static ProcessingContext create(
public abstract Consumer<Commit> workCommitter();

public abstract HeartbeatSender heartbeatSender();
}

@AutoValue
public abstract static class RefreshableView {

private static RefreshableView.Builder builder() {
return new AutoValue_Work_RefreshableView.Builder();
}

abstract WorkId workId();

public final long workToken() {
return workId().workToken();
}

public final long cacheToken() {
return workId().cacheToken();
}

abstract Function<Instant, Boolean> isRefreshable();

public final boolean isRefreshable(Instant refreshDeadline) {
return isRefreshable().apply(refreshDeadline);
}

public abstract HeartbeatSender heartbeatSender();

public abstract boolean isFailed();

public abstract ImmutableList<LatencyAttribution> latencyAttributions();
abstract Builder toBuilder();

@AutoValue.Builder
abstract static class Builder {
abstract Builder setWorkId(WorkId value);
abstract Builder setComputationId(String value);

abstract Builder setIsRefreshable(Function<Instant, Boolean> value);
abstract Builder setKeyedDataFetcher(
Function<KeyedGetDataRequest, Optional<KeyedGetDataResponse>> value);

abstract Builder setHeartbeatSender(HeartbeatSender value);
abstract Builder setWorkCommitter(Consumer<Commit> value);

abstract Builder setIsFailed(boolean value);

abstract Builder setLatencyAttributions(ImmutableList<LatencyAttribution> value);
abstract Builder setHeartbeatSender(HeartbeatSender value);

abstract RefreshableView build();
abstract ProcessingContext build();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ public static WorkId of(Windmill.WorkItem workItem) {
.build();
}

abstract long cacheToken();
public abstract long cacheToken();

abstract long workToken();
public abstract long workToken();

@AutoValue.Builder
public abstract static class Builder {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public WindmillStreamClosedException(String message) {
}

/** Returns whether an exception was caused by a {@link WindmillStreamClosedException}. */
public static boolean isWindmillStreamCancelledException(Throwable t) {
public static boolean wasCauseOf(Throwable t) {
while (t != null) {
if (t instanceof WindmillStreamClosedException) {
return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ private void drainCommitQueue() {
}

private void failCommit(Commit commit) {
commit.work().setFailed();
commit.work().fail();
onCommitComplete.accept(CompleteCommit.forFailedWork(commit));
}

Expand Down
Loading

0 comments on commit 2ba55f7

Please sign in to comment.