Skip to content

Commit

Permalink
Merge pull request apache#17 from bsidhom/multi-output-stage
Browse files Browse the repository at this point in the history
Add support for multiple outputs from executable stages
  • Loading branch information
axelmagn authored Mar 16, 2018
2 parents a2fff55 + 4b098ac commit 5a3e885
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;

import com.google.common.collect.Iterables;
import com.google.common.collect.BiMap;
import com.google.common.collect.ImmutableBiMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.IOException;
Expand All @@ -30,19 +31,20 @@
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.CombineTranslation;
import org.apache.beam.runners.core.construction.CreatePCollectionViewTranslation;
import org.apache.beam.runners.core.construction.ExecutableStageTranslation;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.ReadTranslation;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
import org.apache.beam.runners.flink.translation.functions.FlinkAssignWindows;
import org.apache.beam.runners.flink.translation.functions.FlinkDoFnFunction;
import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStageFunction;
import org.apache.beam.runners.flink.translation.functions.FlinkExecutableStagePruningFunction;
import org.apache.beam.runners.flink.translation.functions.FlinkMergingNonShuffleReduceFunction;
import org.apache.beam.runners.flink.translation.functions.FlinkMultiOutputPruningFunction;
import org.apache.beam.runners.flink.translation.functions.FlinkPartialReduceFunction;
Expand Down Expand Up @@ -689,24 +691,29 @@ private static class ExecutableStageTranslatorBatch<InputT,
@Override
public void translateNode(PTransform<PCollection<InputT>, PCollection<OutputT>> transform,
FlinkBatchTranslationContext context) {
// TODO: Fail on stateful DoFns for now.
// TODO: Support stateful DoFns by inserting group-by-keys where necessary.
// TODO: Fail on splittable DoFns.
// TODO: Special-case single outputs to avoid multiplexing PCollections.

TypeInformation<WindowedValue<OutputT>> typeInformation;

if (context.getCurrentTransform().getOutputs().size() == 0) {
typeInformation = new CoderTypeInformation(VoidCoder.of());
} else {
// TODO: Assert that all inputs and outputs are PCollections.
// TODO: Assert only a single output collection.
@SuppressWarnings("unchecked")
PCollection<OutputT> output = (PCollection<OutputT>)
Iterables.getOnlyElement(context.getCurrentTransform().getOutputs().values());
output.getCoder();
typeInformation =
new CoderTypeInformation<>(
WindowedValue.getFullCoder(
output.getCoder(),
output.getWindowingStrategy().getWindowFn().windowCoder()));
Map<TupleTag<?>, PValue> outputs = context.getOutputs(transform);
BiMap<TupleTag<?>, Integer> outputMap = createOutputMap(outputs.keySet());
// Collect all output Coders and create a UnionCoder for our tagged outputs.
List<Coder<?>> outputCoders = Lists.newArrayList();
// Enforce tuple tag sorting by union tag index.
for (TupleTag<?> tag : new TreeMap<>(outputMap.inverse()).values()) {
PValue taggedValue = outputs.get(tag);
checkState(
taggedValue instanceof PCollection,
"Within ParDo, got a non-PCollection output %s of type %s",
taggedValue,
taggedValue.getClass().getSimpleName());
PCollection<?> coll = (PCollection<?>) taggedValue;
outputCoders.add(coll.getCoder());
}
UnionCoder unionCoder = UnionCoder.of(outputCoders);
TypeInformation<RawUnionValue> typeInformation =
new CoderTypeInformation<>(unionCoder);

RunnerApi.ExecutableStagePayload stagePayload;
try {
Expand All @@ -715,22 +722,50 @@ public void translateNode(PTransform<PCollection<InputT>, PCollection<OutputT>>
} catch (IOException e) {
throw new RuntimeException(e);
}
FlinkExecutableStageFunction<InputT, OutputT> function =
new FlinkExecutableStageFunction<>(stagePayload, context.getComponents());
FlinkExecutableStageFunction<InputT> function =
new FlinkExecutableStageFunction<>(stagePayload, context.getComponents(),
outputMap);
DataSet<WindowedValue<InputT>> inputDataSet =
context.getInputDataSet(context.getInput(transform));
DataSet<WindowedValue<OutputT>> outputDataset =
new MapPartitionOperator<>(
inputDataSet,
DataSet<RawUnionValue> taggedDataset =
new MapPartitionOperator<>(inputDataSet,
typeInformation,
function,
transform.getName());
if (context.getCurrentTransform().getOutputs().size() > 0) {
context.setOutputDataSet(context.getOutput(transform), outputDataset);
} else {
outputDataset.output(new DiscardingOutputFormat());
for (Entry<TupleTag<?>, PValue> output : outputs.entrySet()) {
pruneOutput(taggedDataset, context, outputMap.get(output.getKey()),
(PCollection) output.getValue());
}
if (context.getCurrentTransform().getOutputs().isEmpty()) {
// TODO: Is this still necessary?
taggedDataset.output(new DiscardingOutputFormat());
}
}

private static <T> void pruneOutput(
DataSet<RawUnionValue> taggedDataset,
FlinkBatchTranslationContext context,
int unionTag,
PCollection<T> collection) {
TypeInformation<WindowedValue<T>> outputType = context.getTypeInfo(collection);
FlinkExecutableStagePruningFunction<T> pruningFunction =
new FlinkExecutableStagePruningFunction<>(unionTag);
FlatMapOperator<RawUnionValue, WindowedValue<T>> pruningOperator =
new FlatMapOperator<>(taggedDataset, outputType, pruningFunction, collection.getName());
context.setOutputDataSet(collection, pruningOperator);
}

private static BiMap<TupleTag<?>, Integer> createOutputMap(Iterable<TupleTag<?>> tags) {
ImmutableBiMap.Builder<TupleTag<?>, Integer> builder = ImmutableBiMap.builder();
int outputIndex = 0;
for (TupleTag<?> tag : tags) {
builder.put(tag, outputIndex);
outputIndex++;
}
return builder.build();
}


}

private static class FlattenPCollectionTranslatorBatch<T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,29 +43,35 @@
import org.apache.beam.sdk.fn.data.CloseableFnDataReceiver;
import org.apache.beam.sdk.fn.data.FnDataReceiver;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.MoreFutures;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;

/** ExecutableStage operator. */
public class FlinkExecutableStageFunction<InputT, OutputT> extends
RichMapPartitionFunction<WindowedValue<InputT>, WindowedValue<OutputT>> {
public class FlinkExecutableStageFunction<InputT> extends
RichMapPartitionFunction<WindowedValue<InputT>, RawUnionValue> {

private static final Logger logger =
Logger.getLogger(FlinkExecutableStageFunction.class.getName());

private final RunnerApi.ExecutableStagePayload payload;
private final RunnerApi.Components components;
private final Map<TupleTag<?>, Integer> outputMap;

private transient EnvironmentSession session;
private transient SdkHarnessClient client;
private transient ProcessBundleDescriptors.SimpleProcessBundleDescriptor processBundleDescriptor;

public FlinkExecutableStageFunction(RunnerApi.ExecutableStagePayload payload,
RunnerApi.Components components) {
RunnerApi.Components components,
Map<TupleTag<?>, Integer> outputMap) {
this.payload = payload;
this.components = components;
this.outputMap = outputMap;
}

@Override
Expand Down Expand Up @@ -97,7 +103,7 @@ public void open(Configuration parameters) throws Exception {

@Override
public void mapPartition(Iterable<WindowedValue<InputT>> input,
Collector<WindowedValue<OutputT>> collector) throws Exception {
Collector<RawUnionValue> collector) throws Exception {
checkState(client != null, "SDK client not prepared");
checkState(processBundleDescriptor != null,
"ProcessBundleDescriptor not prepared");
Expand All @@ -110,49 +116,67 @@ public void mapPartition(Iterable<WindowedValue<InputT>> input,
SdkHarnessClient.BundleProcessor<InputT> processor = client.getProcessor(
processBundleDescriptor.getProcessBundleDescriptor(), destination);
processor.getRegistrationFuture().toCompletableFuture().get();
// TODO: Support multiple output receivers and redirect them properly.
Map<BeamFnApi.Target, Coder<WindowedValue<?>>> outputCoders =
processBundleDescriptor.getOutputTargetCoders();
BeamFnApi.Target outputTarget = null;
SdkHarnessClient.RemoteOutputReceiver<WindowedValue<OutputT>> mainOutputReceiver = null;
if (outputCoders.size() > 0) {
outputTarget = Iterables.getOnlyElement(outputCoders.keySet());
Coder<?> outputCoder = Iterables.getOnlyElement(outputCoders.values());
mainOutputReceiver =
new SdkHarnessClient.RemoteOutputReceiver<WindowedValue<OutputT>>() {
@Override
public Coder<WindowedValue<OutputT>> getCoder() {
@SuppressWarnings("unchecked")
Coder<WindowedValue<OutputT>> result = (Coder<WindowedValue<OutputT>>) outputCoder;
return result;
}
ImmutableMap.Builder<BeamFnApi.Target,
SdkHarnessClient.RemoteOutputReceiver<?>> receiverBuilder = ImmutableMap.builder();
final Object collectorLock = new Object();
for (Map.Entry<BeamFnApi.Target, Coder<WindowedValue<?>>> entry : outputCoders.entrySet()) {
BeamFnApi.Target target = entry.getKey();
Coder<WindowedValue<?>> coder = entry.getValue();
SdkHarnessClient.RemoteOutputReceiver<WindowedValue<?>> receiver = new SdkHarnessClient.RemoteOutputReceiver<WindowedValue<?>>() {
@Override
public Coder<WindowedValue<?>> getCoder() {
return coder;
}

@Override
public FnDataReceiver<WindowedValue<?>> getReceiver() {
return new FnDataReceiver<WindowedValue<?>>() {
@Override
public FnDataReceiver<WindowedValue<OutputT>> getReceiver() {
return new FnDataReceiver<WindowedValue<OutputT>>() {
@Override
public void accept(WindowedValue<OutputT> input) throws Exception {
collector.collect(input);
}
};
public void accept(WindowedValue<?> input) throws Exception {
logger.finer(String.format("Receiving value: %s", input));
// TODO: Can this be called by multiple threads? Are calls guaranteed to at least be
// serial? If not, these calls may need to be synchronized.
// TODO: If this needs to be synchronized, consider requiring immutable maps.
synchronized (collectorLock) {
// TODO: Get the correct output union tag from the corresponding output tag.
// Plumb through output tags from process bundle descriptor.
// NOTE: The output map is guaranteed to be non-empty at this point, so we can
// always grab index 0.
// TODO: Plumb through TupleTag <-> Target mappings to get correct union tag here.
// For now, assume only one union tag.
int unionTag = Iterables.getOnlyElement(outputMap.values());
collector.collect(new RawUnionValue(unionTag, input));
}
}
};
}
};
receiverBuilder.put(target, receiver);
}
Map<BeamFnApi.Target,
SdkHarnessClient.RemoteOutputReceiver<?>> receiverMap;
if (outputTarget == null) {
receiverMap = ImmutableMap.of();
} else {
receiverMap = ImmutableMap.of(outputTarget, mainOutputReceiver);
}
SdkHarnessClient.RemoteOutputReceiver<?>> receiverMap =
receiverBuilder.build();

SdkHarnessClient.ActiveBundle<InputT> bundle = processor.newBundle(receiverMap);
try (CloseableFnDataReceiver<WindowedValue<InputT>> inputReceiver = bundle.getInputReceiver()) {
for (WindowedValue<InputT> value : input) {
logger.finer(String.format("Sending value: %s", value));
inputReceiver.accept(value);
}
}
bundle.getBundleResponse().toCompletableFuture().get();

// Await all outputs and active bundle completion. This is necessary because the Flink collector
// must not be accessed outside of mapPartition.
bundle.getOutputClients().values().forEach(client -> {
try {
client.awaitCompletion();
} catch (Exception e) {
throw new RuntimeException(e);
}
});
MoreFutures.get(bundle.getBundleResponse());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package org.apache.beam.runners.flink.translation.functions;

import org.apache.beam.sdk.transforms.join.RawUnionValue;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.util.Collector;

public class FlinkExecutableStagePruningFunction<T>
implements FlatMapFunction<RawUnionValue, WindowedValue<T>> {

private final int unionTag;

public FlinkExecutableStagePruningFunction(int unionTag) {
this.unionTag = unionTag;
}

@Override
public void flatMap(RawUnionValue rawUnionValue, Collector<WindowedValue<T>> collector) throws Exception {
if (rawUnionValue.getUnionTag() == unionTag) {
collector.collect((WindowedValue<T>) rawUnionValue.getValue());
}
}
}

0 comments on commit 5a3e885

Please sign in to comment.