Skip to content

Commit

Permalink
Merge pull request apache#50 from bsidhom/fix-batch-state-handler
Browse files Browse the repository at this point in the history
Fix compile errors after refactoring side inputs in ExecutableStages.
  • Loading branch information
bsidhom authored Mar 30, 2018
2 parents 0d7c75b + fff4fa7 commit 83eff8f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,11 +280,14 @@ private <InputT> void translateExecutableStage(
function,
transform.getUniqueName());

for (Map.Entry<String, String> sideInput : stagePayload.getSideInputsMap().entrySet()) {
// register under the global PCollection name, only ExecutableStageFunction needs to
// know the mapping from local name to global name
for (RunnerApi.ExecutableStagePayload.SideInputId sideInputId : stagePayload.getSideInputsList()) {
String collectionId = components.getTransformsOrThrow(sideInputId.getTransformId())
.getInputsOrThrow(sideInputId.getLocalName());
// Register under the global PCollection name. Only ExecutableStageFunction needs to know the
// mapping from local name to global name and how to translate the broadcast data to a state
// API view.
taggedDataset.withBroadcastSet(
context.getDataSetOrThrow(sideInput.getValue()), sideInput.getValue());
context.getDataSetOrThrow(collectionId), collectionId);
}

for (String collectionId : outputs.values()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,31 @@
*/
package org.apache.beam.runners.flink.translation.functions;

import static com.google.common.base.Preconditions.checkState;
import static org.apache.beam.runners.core.construction.UrnUtils.validateCommonUrn;

import com.google.common.base.Preconditions;
import com.google.protobuf.ByteString;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import org.apache.beam.collect.ImmutableMap;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse.Builder;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.Components;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.ExecutableStagePayload.SideInputId;
import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec;
import org.apache.beam.model.pipeline.v1.RunnerApi.PCollection;
import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec;
import org.apache.beam.runners.core.construction.CoderTranslation;
import org.apache.beam.runners.core.construction.RehydratedComponents;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.core.construction.graph.PipelineNode;
import org.apache.beam.runners.core.construction.graph.PipelineNode.PCollectionNode;
import org.apache.beam.runners.core.construction.graph.SideInputReference;
import org.apache.beam.runners.fnexecution.state.StateRequestHandler;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
Expand All @@ -47,14 +56,30 @@
*/
class FlinkBatchStateRequestHandler implements StateRequestHandler {

private final RunnerApi.ExecutableStagePayload payload;
private final RunnerApi.Components components;
// Map from side input id to global PCollection id.
private final Map<SideInputId, PCollectionNode> sideInputToCollection;
private final Components components;
private final RuntimeContext runtimeContext;

public FlinkBatchStateRequestHandler(
ExecutableStagePayload payload,
Components components, RuntimeContext runtimeContext) {
this.payload = payload;
public static FlinkBatchStateRequestHandler forStage(ExecutableStage stage, Components components,
RuntimeContext runtimeContext) {
ImmutableMap.Builder<SideInputId, PCollectionNode> sideInputBuilder = ImmutableMap.builder();
for (SideInputReference sideInput : stage.getSideInputReferences()) {
sideInputBuilder.put(
SideInputId.newBuilder()
.setTransformId(sideInput.transformId())
.setLocalName(sideInput.localName())
.build(),
sideInput.getCollection());
}
return new FlinkBatchStateRequestHandler(sideInputBuilder.build(), components, runtimeContext);
}

private FlinkBatchStateRequestHandler(
Map<SideInputId, PCollectionNode> sideInputToCollection,
Components components,
RuntimeContext runtimeContext) {
this.sideInputToCollection = sideInputToCollection;
this.components = components;
this.runtimeContext = runtimeContext;
}
Expand All @@ -69,14 +94,18 @@ public CompletionStage<Builder> handle(
BeamFnApi.StateKey.MultimapSideInput multimapSideInput =
request.getStateKey().getMultimapSideInput();

String sideInputKey =
multimapSideInput.getSideInputId() + multimapSideInput.getPtransformId();
String globalPCollectionName =
payload.getSideInputsOrThrow(sideInputKey);
String transformId = multimapSideInput.getPtransformId();
String localName = multimapSideInput.getSideInputId();

List<Object> broadcastVariable = runtimeContext.getBroadcastVariable(globalPCollectionName);
PCollectionNode collectionNode = sideInputToCollection.get(
SideInputId.newBuilder()
.setTransformId(transformId)
.setLocalName(localName)
.build());
checkState(collectionNode != null, "No side input for %s/%s", transformId, localName);
List<Object> broadcastVariable = runtimeContext.getBroadcastVariable(collectionNode.getId());

PCollection pCollection = components.getPcollectionsOrThrow(globalPCollectionName);
PCollection pCollection = collectionNode.getPCollection();

String windowingStrategyId = pCollection.getWindowingStrategyId();
String windowCoderId =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ public class FlinkExecutableStageFunction<InputT> extends
private final Map<String, Integer> outputMap;
private final Struct pipelineOptions;

private transient ExecutableStage executableStage;
private transient EnvironmentSession session;
private transient SdkHarnessClient client;
private transient ExecutableProcessBundleDescriptor processBundleDescriptor;
Expand All @@ -83,7 +84,7 @@ public FlinkExecutableStageFunction(

@Override
public void open(Configuration parameters) throws Exception {
ExecutableStage stage = ExecutableStage.fromPayload(payload, components);
executableStage = ExecutableStage.fromPayload(payload, components);
SdkHarnessManager manager = SingletonSdkHarnessManager.getInstance();
ProvisionApi.ProvisionInfo provisionInfo = ProvisionApi.ProvisionInfo.newBuilder()
// TODO: Set this from job metadata.
Expand All @@ -100,7 +101,7 @@ public void open(Configuration parameters) throws Exception {
logger.info(String.format("Data endpoint: %s", dataEndpoint.getUrl()));
String id = new BigInteger(32, ThreadLocalRandom.current()).toString(36);
processBundleDescriptor = ProcessBundleDescriptors.fromExecutableStage(
id, stage, components, dataEndpoint, stateEndpoint);
id, executableStage, components, dataEndpoint, stateEndpoint);
logger.info(String.format("Process bundle descriptor: %s", processBundleDescriptor));
}

Expand Down Expand Up @@ -177,7 +178,7 @@ public void accept(WindowedValue<?> input) throws Exception {
receiverBuilder.build();

StateRequestHandler stateRequestHandler =
new FlinkBatchStateRequestHandler(payload, components, getRuntimeContext());
FlinkBatchStateRequestHandler.forStage(executableStage, components, getRuntimeContext());

try (SdkHarnessClient.ActiveBundle<InputT> bundle =
processor.newBundle(receiverMap, stateRequestHandler)) {
Expand Down

0 comments on commit 83eff8f

Please sign in to comment.