diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
index 33f67c92c7004..bad19da608548 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java
@@ -20,7 +20,8 @@
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkState;
 
-import com.google.common.collect.Iterables;
+import com.google.common.collect.BiMap;
+import com.google.common.collect.ImmutableBiMap;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import java.io.IOException;
@@ -30,6 +31,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.TreeMap;
 import javax.annotation.Nullable;
 import org.apache.beam.model.pipeline.v1.RunnerApi;
 import org.apache.beam.runners.core.construction.CombineTranslation;
@@ -37,12 +39,12 @@
 import org.apache.beam.runners.core.construction.ExecutableStageTranslation;
 import org.apache.beam.runners.core.construction.PTransformTranslation;
 import org.apache.beam.runners.core.construction.ParDoTranslation;
-import org.apache.beam.runners.core.construction.PipelineTranslation;
 import org.apache.beam.runners.core.construction.ReadTranslation;
 import org.apache.beam.runners.core.construction.graph.ExecutableStage;
 import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows;
 import org.apache.beam.runners.flink.translation.functions.FlinkDoFnFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageFunction;
+import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStagePruningFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkMergingNonShuffleReduceFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputPruningFunction;
 import org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction;
@@ -689,24 +691,29 @@ private static class ExecutableStageTranslatorBatch<InputT, OutputT>
     public void translateNode(
         PTransform<PCollection<InputT>, PCollection<OutputT>> transform,
         FlinkBatchTranslationContext context) {
+      // TODO: Fail on stateful DoFns for now.
+      // TODO: Support stateful DoFns by inserting group-by-keys where necessary.
+      // TODO: Fail on splittable DoFns.
+      // TODO: Special-case single outputs to avoid multiplexing PCollections.
 
-      TypeInformation<WindowedValue<OutputT>> typeInformation;
-
-      if (context.getCurrentTransform().getOutputs().size() == 0) {
-        typeInformation = new CoderTypeInformation(VoidCoder.of());
-      } else {
-        // TODO: Assert that all inputs and outputs are PCollections.
-        // TODO: Assert only a single output collection.
-        @SuppressWarnings("unchecked")
-        PCollection<OutputT> output =
-            (PCollection<OutputT>)
-                Iterables.getOnlyElement(context.getCurrentTransform().getOutputs().values());
-        output.getCoder();
-        typeInformation =
-            new CoderTypeInformation<>(
-                WindowedValue.getFullCoder(
-                    output.getCoder(),
-                    output.getWindowingStrategy().getWindowFn().windowCoder()));
+      Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
+      BiMap<TupleTag<?>, Integer> outputMap = createOutputMap(outputs.keySet());
+      // Collect all output Coders and create a UnionCoder for our tagged outputs.
+      List<Coder<?>> outputCoders = Lists.newArrayList();
+      // Enforce tuple tag sorting by union tag index.
+      for (TupleTag<?> tag : new TreeMap<>(outputMap.inverse()).values()) {
+        PValue taggedValue = outputs.get(tag);
+        checkState(
+            taggedValue instanceof PCollection,
+            "Within ParDo, got a non-PCollection output %s of type %s",
+            taggedValue,
+            taggedValue.getClass().getSimpleName());
+        PCollection<?> coll = (PCollection<?>) taggedValue;
+        outputCoders.add(coll.getCoder());
       }
+      UnionCoder unionCoder = UnionCoder.of(outputCoders);
+      TypeInformation<RawUnionValue> typeInformation = new CoderTypeInformation<>(unionCoder);
 
       RunnerApi.ExecutableStagePayload stagePayload;
       try {
@@ -715,22 +722,50 @@ public void translateNode(PTransform<PCollection<InputT>, PCollection<OutputT>>
       } catch (IOException e) {
         throw new RuntimeException(e);
       }
-      FlinkExecutableStageFunction<InputT, OutputT> function =
-          new FlinkExecutableStageFunction<>(stagePayload, context.getComponents());
+      FlinkExecutableStageFunction<InputT> function =
+          new FlinkExecutableStageFunction<>(stagePayload, context.getComponents(),
+              outputMap);
 
       DataSet<WindowedValue<InputT>> inputDataSet =
           context.getInputDataSet(context.getInput(transform));
 
-      DataSet<WindowedValue<OutputT>> outputDataset =
-          new MapPartitionOperator<>(
-              inputDataSet, typeInformation, function, transform.getName());
-      if (context.getCurrentTransform().getOutputs().size() > 0) {
-        context.setOutputDataSet(context.getOutput(transform), outputDataset);
-      } else {
-        outputDataset.output(new DiscardingOutputFormat());
+      DataSet<RawUnionValue> taggedDataset =
+          new MapPartitionOperator<>(inputDataSet, typeInformation, function,
+              transform.getName());
+
+      for (Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
+        pruneOutput(taggedDataset, context, outputMap.get(output.getKey()),
+            (PCollection<?>) output.getValue());
+      }
+      if (context.getCurrentTransform().getOutputs().isEmpty()) {
+        // TODO: Is this still necessary?
+        taggedDataset.output(new DiscardingOutputFormat());
       }
     }
+
+    private static <T> void pruneOutput(
+        DataSet<RawUnionValue> taggedDataset,
+        FlinkBatchTranslationContext context,
+        int unionTag,
+        PCollection<T> collection) {
+      TypeInformation<WindowedValue<T>> outputType = context.getTypeInfo(collection);
+      FlinkExecutableStagePruningFunction<T> pruningFunction =
+          new FlinkExecutableStagePruningFunction<>(unionTag);
+      FlatMapOperator<RawUnionValue, WindowedValue<T>> pruningOperator =
+          new FlatMapOperator<>(taggedDataset, outputType, pruningFunction, collection.getName());
+      context.setOutputDataSet(collection, pruningOperator);
+    }
+
+    private static BiMap<TupleTag<?>, Integer> createOutputMap(Iterable<TupleTag<?>> tags) {
+      ImmutableBiMap.Builder<TupleTag<?>, Integer> builder = ImmutableBiMap.builder();
+      int outputIndex = 0;
+      for (TupleTag<?> tag : tags) {
+        builder.put(tag, outputIndex);
+        outputIndex++;
+      }
+      return builder.build();
+    }
   }
 
   private static class FlattenPCollectionTranslatorBatch<T>
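The `createOutputMap`/`TreeMap` interplay in the translator is easier to see in isolation. The sketch below is not part of the patch; `OutputMapDemo` is a made-up name, and it assumes the Beam SDK and Guava are on the classpath. It shows how inverting the tag-to-index `BiMap` and copying it into a `TreeMap` yields tags sorted by union index, which is what keeps the collected coders aligned with the `RawUnionValue` tags.

```java
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import java.util.TreeMap;
import org.apache.beam.sdk.values.TupleTag;

public class OutputMapDemo {
  public static void main(String[] args) {
    // Mirror createOutputMap: assign union tag indices in tag-set iteration order.
    ImmutableBiMap.Builder<TupleTag<?>, Integer> builder = ImmutableBiMap.builder();
    builder.put(new TupleTag<String>("main"), 0);
    builder.put(new TupleTag<Long>("side"), 1);
    BiMap<TupleTag<?>, Integer> outputMap = builder.build();

    // Inverting the BiMap and copying it into a TreeMap sorts entries by union
    // tag index, so downstream coders line up with RawUnionValue tags.
    for (TupleTag<?> tag : new TreeMap<>(outputMap.inverse()).values()) {
      System.out.println(outputMap.get(tag) + " -> " + tag.getId()); // 0 -> main, 1 -> side
    }
  }
}
```

Using an `ImmutableBiMap` here buys an invertible, insertion-ordered mapping for free, which is why the translator can hand the same map to both the tagging function and the pruning step.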
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
index db07bb6912998..d7f6322ffa8d4 100644
--- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStageFunction.java
@@ -43,29 +43,35 @@
 import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver;
 import org.apache.beam.sdk.fn.data.FnDataReceiver;
 import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.transforms.join.RawUnionValue;
+import org.apache.beam.sdk.util.MoreFutures;
 import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.beam.sdk.values.TupleTag;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.util.Collector;
 
 /** ExecutableStage operator. */
-public class FlinkExecutableStageFunction<InputT, OutputT> extends
-    RichMapPartitionFunction<WindowedValue<InputT>, WindowedValue<OutputT>> {
+public class FlinkExecutableStageFunction<InputT> extends
+    RichMapPartitionFunction<WindowedValue<InputT>, RawUnionValue> {
 
   private static final Logger logger =
       Logger.getLogger(FlinkExecutableStageFunction.class.getName());
 
   private final RunnerApi.ExecutableStagePayload payload;
   private final RunnerApi.Components components;
+  private final Map<TupleTag<?>, Integer> outputMap;
 
   private transient EnvironmentSession session;
   private transient SdkHarnessClient client;
   private transient ProcessBundleDescriptors.SimpleProcessBundleDescriptor processBundleDescriptor;
 
   public FlinkExecutableStageFunction(RunnerApi.ExecutableStagePayload payload,
-      RunnerApi.Components components) {
+      RunnerApi.Components components,
+      Map<TupleTag<?>, Integer> outputMap) {
     this.payload = payload;
     this.components = components;
+    this.outputMap = outputMap;
   }
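The function's output type changes from `WindowedValue<OutputT>` to `RawUnionValue`: every element is wrapped with the integer union tag of the output it belongs to. A minimal sketch of that wrapping, illustrative only (`UnionTagDemo` is a hypothetical name; it assumes the Beam SDK on the classpath):

```java
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.WindowedValue;

public class UnionTagDemo {
  public static void main(String[] args) {
    // Wrap a windowed element with the union tag of the output it belongs to,
    // as FlinkExecutableStageFunction now does for every element it emits.
    WindowedValue<String> element = WindowedValue.valueInGlobalWindow("hello");
    RawUnionValue tagged = new RawUnionValue(0, element);
    System.out.println(tagged.getUnionTag() + ": " + tagged.getValue());
  }
}
```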
@@ -97,7 +103,7 @@ public void open(Configuration parameters) throws Exception {
 
   @Override
   public void mapPartition(Iterable<WindowedValue<InputT>> input,
-      Collector<WindowedValue<OutputT>> collector) throws Exception {
+      Collector<RawUnionValue> collector) throws Exception {
     checkState(client != null, "SDK client not prepared");
     checkState(processBundleDescriptor != null,
         "ProcessBundleDescriptor not prepared");
@@ -110,49 +116,67 @@ public void mapPartition(Iterable<WindowedValue<InputT>> input,
     SdkHarnessClient.BundleProcessor<InputT> processor = client.getProcessor(
         processBundleDescriptor.getProcessBundleDescriptor(), destination);
     processor.getRegistrationFuture().toCompletableFuture().get();
-    // TODO: Support multiple output receivers and redirect them properly.
     Map<BeamFnApi.Target, Coder<WindowedValue<?>>> outputCoders =
         processBundleDescriptor.getOutputTargetCoders();
-    BeamFnApi.Target outputTarget = null;
-    SdkHarnessClient.RemoteOutputReceiver<WindowedValue<?>> mainOutputReceiver = null;
-    if (outputCoders.size() > 0) {
-      outputTarget = Iterables.getOnlyElement(outputCoders.keySet());
-      Coder<?> outputCoder = Iterables.getOnlyElement(outputCoders.values());
-      mainOutputReceiver =
-          new SdkHarnessClient.RemoteOutputReceiver<WindowedValue<?>>() {
-            @Override
-            public Coder<WindowedValue<?>> getCoder() {
-              @SuppressWarnings("unchecked")
-              Coder<WindowedValue<?>> result = (Coder<WindowedValue<?>>) outputCoder;
-              return result;
-            }
-            @Override
-            public FnDataReceiver<WindowedValue<?>> getReceiver() {
-              return new FnDataReceiver<WindowedValue<?>>() {
-                @Override
-                public void accept(WindowedValue<?> input) throws Exception {
-                  collector.collect(input);
-                }
-              };
-            }
-          };
-    }
+    ImmutableMap.Builder<BeamFnApi.Target, SdkHarnessClient.RemoteOutputReceiver<?>>
+        receiverBuilder = ImmutableMap.builder();
+    final Object collectorLock = new Object();
+    for (Map.Entry<BeamFnApi.Target, Coder<WindowedValue<?>>> entry : outputCoders.entrySet()) {
+      BeamFnApi.Target target = entry.getKey();
+      Coder<WindowedValue<?>> coder = entry.getValue();
+      SdkHarnessClient.RemoteOutputReceiver<WindowedValue<?>> receiver =
+          new SdkHarnessClient.RemoteOutputReceiver<WindowedValue<?>>() {
+            @Override
+            public Coder<WindowedValue<?>> getCoder() {
+              return coder;
+            }
+
+            @Override
+            public FnDataReceiver<WindowedValue<?>> getReceiver() {
+              return new FnDataReceiver<WindowedValue<?>>() {
+                @Override
+                public void accept(WindowedValue<?> input) throws Exception {
+                  logger.finer(String.format("Receiving value: %s", input));
+                  // TODO: Can this be called by multiple threads? Are calls guaranteed to at
+                  // least be serial? If not, these calls may need to be synchronized.
+                  // TODO: If this needs to be synchronized, consider requiring immutable maps.
+                  synchronized (collectorLock) {
+                    // TODO: Plumb TupleTag <-> Target mappings through the process bundle
+                    // descriptor so the correct union tag can be looked up per output. For
+                    // now, assume a single output; getOnlyElement enforces exactly one tag.
+                    int unionTag = Iterables.getOnlyElement(outputMap.values());
+                    collector.collect(new RawUnionValue(unionTag, input));
+                  }
+                }
+              };
+            }
+          };
+      receiverBuilder.put(target, receiver);
+    }
 
-    Map<BeamFnApi.Target, SdkHarnessClient.RemoteOutputReceiver<?>> receiverMap;
-    if (outputTarget == null) {
-      receiverMap = ImmutableMap.of();
-    } else {
-      receiverMap = ImmutableMap.of(outputTarget, mainOutputReceiver);
-    }
+    Map<BeamFnApi.Target, SdkHarnessClient.RemoteOutputReceiver<?>> receiverMap =
+        receiverBuilder.build();
     SdkHarnessClient.ActiveBundle<InputT> bundle = processor.newBundle(receiverMap);
     try (CloseableFnDataReceiver<WindowedValue<InputT>> inputReceiver =
         bundle.getInputReceiver()) {
       for (WindowedValue<InputT> value : input) {
+        logger.finer(String.format("Sending value: %s", value));
         inputReceiver.accept(value);
       }
     }
-    bundle.getBundleResponse().toCompletableFuture().get();
+
+    // Await all outputs and active bundle completion. This is necessary because the Flink
+    // collector must not be accessed outside of mapPartition.
+    bundle.getOutputClients().values().forEach(outputClient -> {
+      try {
+        outputClient.awaitCompletion();
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    });
+    MoreFutures.get(bundle.getBundleResponse());
   }
 
   @Override
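The new completion logic waits for every output client before returning, because Flink forbids touching the `Collector` once `mapPartition` has exited. A rough stand-in for that pattern using plain `CompletableFuture`s (`AwaitOutputsDemo` is a made-up name; the real code awaits the SDK harness output clients and then the bundle response):

```java
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CompletableFuture;

public class AwaitOutputsDemo {
  public static void main(String[] args) {
    // Stand-ins for the bundle's per-target output clients. The function above
    // awaits each of them before mapPartition returns, so no receiver callback
    // can touch the Flink collector after the method has exited.
    List<CompletableFuture<Void>> outputClients = Arrays.asList(
        CompletableFuture.runAsync(() -> System.out.println("output A drained")),
        CompletableFuture.runAsync(() -> System.out.println("output B drained")));
    outputClients.forEach(outputClient -> {
      try {
        outputClient.get(); // plays the role of awaitCompletion() in the patch
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    });
    System.out.println("bundle complete");
  }
}
```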
diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStagePruningFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStagePruningFunction.java
new file mode 100644
index 0000000000000..0f8b0e040ecc3
--- /dev/null
+++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkExecutableStagePruningFunction.java
@@ -0,0 +1,23 @@
+package org.apache.beam.runners.flink.translation.functions;
+
+import org.apache.beam.sdk.transforms.join.RawUnionValue;
+import org.apache.beam.sdk.util.WindowedValue;
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.util.Collector;
+
+public class FlinkExecutableStagePruningFunction<T>
+    implements FlatMapFunction<RawUnionValue, WindowedValue<T>> {
+
+  private final int unionTag;
+
+  public FlinkExecutableStagePruningFunction(int unionTag) {
+    this.unionTag = unionTag;
+  }
+
+  @Override
+  public void flatMap(RawUnionValue rawUnionValue, Collector<WindowedValue<T>> collector)
+      throws Exception {
+    if (rawUnionValue.getUnionTag() == unionTag) {
+      collector.collect((WindowedValue<T>) rawUnionValue.getValue());
+    }
+  }
+}
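Taken together, `FlinkExecutableStageFunction` multiplexes all stage outputs into one tagged stream and `FlinkExecutableStagePruningFunction` demultiplexes it, one flat-map per output. The round trip, stripped of Flink, looks roughly like this (`PruningDemo` is a hypothetical illustration assuming the Beam SDK on the classpath):

```java
import java.util.Arrays;
import java.util.List;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.WindowedValue;

public class PruningDemo {
  public static void main(String[] args) {
    // A tagged stream, as produced by FlinkExecutableStageFunction.
    List<RawUnionValue> tagged = Arrays.asList(
        new RawUnionValue(0, WindowedValue.valueInGlobalWindow("to-main")),
        new RawUnionValue(1, WindowedValue.valueInGlobalWindow("to-side")));

    // One pruning pass per output: keep only values carrying our tag, which is
    // what FlinkExecutableStagePruningFunction.flatMap does for each PCollection.
    int unionTag = 0;
    for (RawUnionValue value : tagged) {
      if (value.getUnionTag() == unionTag) {
        System.out.println("kept: " + ((WindowedValue<?>) value.getValue()).getValue());
      }
    }
  }
}
```

Running one pruning operator per output trades a pass over the tagged stream for each `PCollection` against the simplicity of a single multiplexed upstream operator.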