Skip to content

Commit

Permalink
Portable runner fixes (#32247)
Browse files Browse the repository at this point in the history
* Change jobIdBytes to String

* Store latest State in an AtomicReference

* Store latest metrics in an AtomicReference

* Pass job service blocking stub without close wrapper

* Remove use of CloseableResource from JobServicePipelineResult
  • Loading branch information
damondouglas authored Aug 20, 2024
1 parent 6582e7a commit bd65ee9
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 183 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,196 +17,182 @@
*/
package org.apache.beam.runners.portability;

import java.util.Iterator;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.beam.model.jobmanagement.v1.JobApi;
import org.apache.beam.model.jobmanagement.v1.JobApi.CancelJobRequest;
import org.apache.beam.model.jobmanagement.v1.JobApi.CancelJobResponse;
import org.apache.beam.model.jobmanagement.v1.JobApi.GetJobStateRequest;
import org.apache.beam.model.jobmanagement.v1.JobApi.JobMessage;
import org.apache.beam.model.jobmanagement.v1.JobApi.JobMessagesRequest;
import org.apache.beam.model.jobmanagement.v1.JobApi.JobMessagesResponse;
import org.apache.beam.model.jobmanagement.v1.JobApi.JobStateEvent;
import org.apache.beam.model.jobmanagement.v1.JobServiceGrpc.JobServiceBlockingStub;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.metrics.MetricResults;
import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListenableFuture;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.ListeningScheduledExecutorService;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.util.concurrent.MoreExecutors;
import org.joda.time.Duration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@SuppressWarnings({"keyfor", "nullness"}) // TODO(https://github.com/apache/beam/issues/20497)
class JobServicePipelineResult implements PipelineResult, AutoCloseable {

private static final long POLL_INTERVAL_MS = 3_000;

private static final Logger LOG = LoggerFactory.getLogger(JobServicePipelineResult.class);
private final ListeningScheduledExecutorService executorService =
MoreExecutors.listeningDecorator(Executors.newSingleThreadScheduledExecutor());

private final String jobId;
private final JobServiceBlockingStub jobService;
private final AtomicReference<State> latestState = new AtomicReference<>(State.UNKNOWN);
private final Runnable cleanup;
private final AtomicReference<PortableMetrics> jobMetrics =
new AtomicReference<>(PortableMetrics.of(JobApi.MetricResults.getDefaultInstance()));
private CompletableFuture<State> terminalStateFuture = new CompletableFuture<>();
private CompletableFuture<MetricResults> metricResultsCompletableFuture =
new CompletableFuture<>();

private final ByteString jobId;
private final int jobServerTimeout;
private final CloseableResource<JobServiceBlockingStub> jobService;
private @Nullable State terminalState;
private final @Nullable Runnable cleanup;
private org.apache.beam.model.jobmanagement.v1.JobApi.MetricResults jobMetrics;

JobServicePipelineResult(
ByteString jobId,
int jobServerTimeout,
CloseableResource<JobServiceBlockingStub> jobService,
Runnable cleanup) {
JobServicePipelineResult(String jobId, JobServiceBlockingStub jobService, Runnable cleanup) {
this.jobId = jobId;
this.jobServerTimeout = jobServerTimeout;
this.jobService = jobService;
this.terminalState = null;
this.cleanup = cleanup;
}

@Override
public State getState() {
if (terminalState != null) {
return terminalState;
if (latestState.get().isTerminal()) {
return latestState.get();
}
JobServiceBlockingStub stub =
jobService.get().withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS);
JobStateEvent response =
stub.getState(GetJobStateRequest.newBuilder().setJobIdBytes(jobId).build());
return getJavaState(response.getState());
jobService.getState(GetJobStateRequest.newBuilder().setJobId(jobId).build());
State state = State.valueOf(response.getState().name());
latestState.set(state);
return state;
}

@Override
public State cancel() {
JobServiceBlockingStub stub = jobService.get();
if (latestState.get().isTerminal()) {
return latestState.get();
}
CancelJobResponse response =
stub.cancel(CancelJobRequest.newBuilder().setJobIdBytes(jobId).build());
return getJavaState(response.getState());
jobService.cancel(CancelJobRequest.newBuilder().setJobId(jobId).build());
State state = State.valueOf(response.getState().name());
latestState.set(state);
return state;
}

@Nullable
@Override
public State waitUntilFinish(Duration duration) {
if (duration.compareTo(Duration.millis(1)) <= 0) {
// Equivalent to infinite timeout.
return waitUntilFinish();
} else {
CompletableFuture<State> result = CompletableFuture.supplyAsync(this::waitUntilFinish);
try {
return result.get(duration.getMillis(), TimeUnit.MILLISECONDS);
} catch (TimeoutException e) {
// Null result indicates a timeout.
return null;
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
} catch (ExecutionException e) {
throw new RuntimeException(e);
}
if (latestState.get().isTerminal()) {
return latestState.get();
}
try {
return pollForTerminalState().get(duration.getMillis(), TimeUnit.MILLISECONDS);
} catch (InterruptedException | ExecutionException | TimeoutException e) {
throw new RuntimeException(e);
}
}

@Override
public State waitUntilFinish() {
if (terminalState != null) {
return terminalState;
if (latestState.get().isTerminal()) {
return latestState.get();
}
try {
waitForTerminalState();
propagateErrors();
return terminalState;
} finally {
close();
return pollForTerminalState().get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

@Override
public MetricResults metrics() {
return PortableMetrics.of(jobMetrics);
void setTerminalStateFuture(CompletableFuture<State> terminalStateFuture) {
this.terminalStateFuture = terminalStateFuture;
}

@Override
public void close() {
try (CloseableResource<JobServiceBlockingStub> jobService = this.jobService) {
JobApi.GetJobMetricsRequest metricsRequest =
JobApi.GetJobMetricsRequest.newBuilder().setJobIdBytes(jobId).build();
jobMetrics = jobService.get().getJobMetrics(metricsRequest).getMetrics();
if (cleanup != null) {
cleanup.run();
}
} catch (Exception e) {
LOG.warn("Error cleaning up job service", e);
}
CompletableFuture<State> getTerminalStateFuture() {
return this.terminalStateFuture;
}

private void waitForTerminalState() {
JobServiceBlockingStub stub =
jobService.get().withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS);
GetJobStateRequest request = GetJobStateRequest.newBuilder().setJobIdBytes(jobId).build();
JobStateEvent response = stub.getState(request);
State lastState = getJavaState(response.getState());
while (!lastState.isTerminal()) {
try {
Thread.sleep(POLL_INTERVAL_MS);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
response = stub.withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS).getState(request);
lastState = getJavaState(response.getState());
}
terminalState = lastState;
void setMetricResultsCompletableFuture(
CompletableFuture<MetricResults> metricResultsCompletableFuture) {
this.metricResultsCompletableFuture = metricResultsCompletableFuture;
}

private void propagateErrors() {
if (terminalState != State.DONE) {
JobMessagesRequest messageStreamRequest =
JobMessagesRequest.newBuilder().setJobIdBytes(jobId).build();
Iterator<JobMessagesResponse> messageStreamIterator =
jobService
.get()
.withDeadlineAfter(jobServerTimeout, TimeUnit.SECONDS)
.getMessageStream(messageStreamRequest);
while (messageStreamIterator.hasNext()) {
JobMessage messageResponse = messageStreamIterator.next().getMessageResponse();
if (messageResponse.getImportance() == JobMessage.MessageImportance.JOB_MESSAGE_ERROR) {
throw new RuntimeException(
"The Runner experienced the following error during execution:\n"
+ messageResponse.getMessageText());
}
}
}
CompletableFuture<MetricResults> getMetricResultsCompletableFuture() {
return this.metricResultsCompletableFuture;
}

private static State getJavaState(JobApi.JobState.Enum protoState) {
switch (protoState) {
case UNSPECIFIED:
return State.UNKNOWN;
case STOPPED:
return State.STOPPED;
case RUNNING:
return State.RUNNING;
case DONE:
return State.DONE;
case FAILED:
return State.FAILED;
case CANCELLED:
return State.CANCELLED;
case UPDATED:
return State.UPDATED;
case DRAINING:
// TODO: Determine the correct mappings for the states below.
return State.UNKNOWN;
case DRAINED:
return State.UNKNOWN;
case STARTING:
return State.RUNNING;
case CANCELLING:
return State.CANCELLED;
default:
LOG.warn("Unrecognized state from server: {}", protoState);
return State.UNKNOWN;
}
CompletableFuture<State> pollForTerminalState() {
CompletableFuture<State> completableFuture = new CompletableFuture<>();
ListenableFuture<?> future =
executorService.scheduleAtFixedRate(
() -> {
State state = getState();
LOG.info("Job: {} latest state: {}", jobId, state);
latestState.set(state);
if (state.isTerminal()) {
completableFuture.complete(state);
}
},
0L,
POLL_INTERVAL_MS,
TimeUnit.MILLISECONDS);
return completableFuture.whenComplete(
(state, throwable) -> {
checkState(
state.isTerminal(),
"future should have completed with a terminal state, got: %s",
state);
future.cancel(true);
LOG.info("Job: {} reached terminal state: {}", jobId, state);
if (throwable != null) {
throw new RuntimeException(throwable);
}
});
}

CompletableFuture<MetricResults> pollForMetrics() {
CompletableFuture<MetricResults> completableFuture = new CompletableFuture<>();
ListenableFuture<?> future =
executorService.scheduleAtFixedRate(
() -> {
if (latestState.get().isTerminal()) {
completableFuture.complete(jobMetrics.get());
return;
}
JobApi.GetJobMetricsRequest metricsRequest =
JobApi.GetJobMetricsRequest.newBuilder().setJobId(jobId).build();
JobApi.MetricResults results = jobService.getJobMetrics(metricsRequest).getMetrics();
jobMetrics.set(PortableMetrics.of(results));
},
0L,
1L,
TimeUnit.SECONDS);
return completableFuture.whenComplete(
((metricResults, throwable) -> {
checkState(
latestState.get().isTerminal(),
"future should have completed with a terminal state, got: %s",
latestState.get());
future.cancel(true);
LOG.info("Job: {} latest metrics: {}", jobId, metricResults.toString());
}));
}

@Override
public MetricResults metrics() {
return jobMetrics.get();
}

@Override
public void close() {
cleanup.run();
}
}
Loading

0 comments on commit bd65ee9

Please sign in to comment.