diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index 5ff94884e974..f10f3b91e7aa 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -73,7 +73,6 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; -import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Supplier; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.FluentIterable; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable; @@ -129,7 +128,7 @@ public class StreamingModeExecutionContext extends DataflowExecutionContext SideInput fetchSideInput( return fetchSideInputFromWindmill( view, sideInputWindow, - Preconditions.checkNotNull(stateFamily), + checkNotNull(stateFamily), state, - Preconditions.checkNotNull(scopedReadStateSupplier), + checkNotNull(scopedReadStateSupplier), tagCache); } @@ -329,15 +328,15 @@ private SideInput fetchSideInputFromWindmill( } public Iterable getSideInputNotifications() { - return work.getWorkItem().getGlobalDataIdNotificationsList(); + return getWorkItem().getGlobalDataIdNotificationsList(); } private List getFiredTimers() { - return work.getWorkItem().getTimers().getTimersList(); + return getWorkItem().getTimers().getTimersList(); } public @Nullable ByteString getSerializedKey() { - return work.getWorkItem().getKey(); + return work == null ? null : work.getWorkItem().getKey(); } public WindmillComputationKey getComputationKey() { @@ -345,11 +344,15 @@ public WindmillComputationKey getComputationKey() { } public long getWorkToken() { - return work.getWorkItem().getWorkToken(); + return getWorkItem().getWorkToken(); } public Windmill.WorkItem getWorkItem() { - return work.getWorkItem(); + return checkNotNull( + work, + "work is null. A call to StreamingModeExecutionContext.start(...) is required to set" + + " work for execution.") + .getWorkItem(); } public Windmill.WorkItemCommitRequest.Builder getOutputBuilder() { @@ -390,7 +393,7 @@ public void invalidateCache() { public UnboundedSource.@Nullable CheckpointMark getReaderCheckpoint( Coder coder) { try { - ByteString sourceStateState = work.getWorkItem().getSourceState().getState(); + ByteString sourceStateState = getWorkItem().getSourceState().getState(); if (sourceStateState.isEmpty()) { return null; } @@ -737,7 +740,7 @@ public void start( key, stateFamily, stateReader, - work.getWorkItem().getIsNewKey(), + getWorkItem().getIsNewKey(), cacheForKey.forFamily(stateFamily), scopedReadStateSupplier); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutorTest.java new file mode 100644 index 000000000000..9146ad02fddd --- /dev/null +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/streaming/ComputationWorkExecutorTest.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.runners.dataflow.worker.streaming; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.beam.runners.core.metrics.ExecutionStateTracker; +import org.apache.beam.runners.dataflow.worker.DataflowWorkExecutor; +import org.apache.beam.runners.dataflow.worker.StreamingModeExecutionContext; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ComputationWorkExecutorTest { + + private final DataflowWorkExecutor dataflowWorkExecutor = mock(DataflowWorkExecutor.class); + private final StreamingModeExecutionContext context = mock(StreamingModeExecutionContext.class); + private ComputationWorkExecutor computationWorkExecutor; + + @Before + public void setUp() { + computationWorkExecutor = + ComputationWorkExecutor.builder() + .setWorkExecutor(dataflowWorkExecutor) + .setContext(context) + .setExecutionStateTracker(mock(ExecutionStateTracker.class)) + .build(); + } + + @Test + public void testInvalidate_withoutCallToStart() { + // Call to invalidate w/o a call to start should not fail. + computationWorkExecutor.invalidate(); + } + + @Test + public void testInvalidate_handlesException() { + AtomicBoolean verifyContextInvalidated = new AtomicBoolean(false); + Throwable e = new RuntimeException("something bad happened 2"); + doThrow(e).when(dataflowWorkExecutor).close(); + doAnswer( + ignored -> { + verifyContextInvalidated.set(true); + return null; + }) + .when(context) + .invalidateCache(); + computationWorkExecutor.invalidate(); + assertTrue(verifyContextInvalidated.get()); + } +}