Commit a0cfccd

This closes apache#2420
amitsela committed Apr 9, 2017
2 parents 810db7f + 9e294dc commit a0cfccd
Showing 3 changed files with 47 additions and 235 deletions.

This file was deleted.

[First modified file; path not captured]
@@ -24,7 +24,6 @@
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;

 import com.google.common.base.Optional;
-import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import java.util.Collections;
@@ -334,8 +333,7 @@ public String toNativeString() {
     };
   }

-  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>
-      parDo() {
+  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
     return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
       @Override
       public void evaluate(
@@ -351,51 +349,31 @@ public void evaluate(
             context.getInput(transform).getWindowingStrategy();
         Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
         Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance();
-        Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
-            TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-        if (transform.getSideOutputTags().size() == 0) {
-          // Don't tag with the output and filter for a single-output ParDo, as it's additional
-          // identity transforms.
-          // Also see BEAM-1737 for failures when the two versions are condensed.
-          PCollection<OutputT> output =
-              (PCollection<OutputT>)
-                  Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
-          context.putDataset(
-              output,
-              new BoundedDataset<>(
-                  inRDD.mapPartitions(
-                      new DoFnFunction<>(
-                          aggAccum,
-                          metricsAccum,
-                          stepName,
-                          doFn,
-                          context.getRuntimeContext(),
-                          sideInputs,
-                          windowingStrategy))));
-        } else {
-          JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
-              inRDD
-                  .mapPartitionsToPair(
-                      new MultiDoFnFunction<>(
-                          aggAccum,
-                          metricsAccum,
-                          stepName,
-                          doFn,
-                          context.getRuntimeContext(),
-                          transform.getMainOutputTag(),
-                          TranslationUtils.getSideInputs(transform.getSideInputs(), context),
-                          windowingStrategy))
-                  .cache();
-          for (TaggedPValue output : context.getOutputs(transform)) {
-            @SuppressWarnings("unchecked")
-            JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
-                all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
-            @SuppressWarnings("unchecked")
-            // Object is the best we can do since different outputs can have different tags
-            JavaRDD<WindowedValue<Object>> values =
-                (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
-            context.putDataset(output.getValue(), new BoundedDataset<>(values));
-          }
-        }
+        JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
+            inRDD.mapPartitionsToPair(
+                new MultiDoFnFunction<>(
+                    aggAccum,
+                    metricsAccum,
+                    stepName,
+                    doFn,
+                    context.getRuntimeContext(),
+                    transform.getMainOutputTag(),
+                    TranslationUtils.getSideInputs(transform.getSideInputs(), context),
+                    windowingStrategy));
+        List<TaggedPValue> outputs = context.getOutputs(transform);
+        if (outputs.size() > 1) {
+          // cache the RDD if we're going to filter it more than once.
+          all.cache();
+        }
+        for (TaggedPValue output : outputs) {
+          @SuppressWarnings("unchecked")
+          JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
+              all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
+          @SuppressWarnings("unchecked")
+          // Object is the best we can do since different outputs can have different tags
+          JavaRDD<WindowedValue<Object>> values =
+              (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
+          context.putDataset(output.getValue(), new BoundedDataset<>(values));
+        }
       }

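The batch change above collapses the two ParDo paths into a single Spark idiom: every element is emitted once as a (tag, value) pair, each logical output is recovered by filtering on its tag, and the pair RDD is cached only when more than one filter pass will read it. A minimal self-contained sketch of that idiom in plain Spark (the string tags, sample data, and local-mode setup are illustrative stand-ins for Beam's TupleTags and datasets, not code from this commit):

import java.util.Arrays;
import java.util.List;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public class TagAndFilterBatch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("tag-and-filter");
    try (JavaSparkContext jsc = new JavaSparkContext(conf)) {
      JavaRDD<Integer> input = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6));

      // One pass over the input tags every element, mirroring the
      // (TupleTag, WindowedValue) pairs that MultiDoFnFunction emits.
      JavaPairRDD<String, Integer> all =
          input.mapToPair(x -> new Tuple2<>(x % 2 == 0 ? "even" : "odd", x));

      List<String> tags = Arrays.asList("even", "odd");
      if (tags.size() > 1) {
        // Each filter below re-evaluates the tagging pass unless the RDD is cached.
        all.cache();
      }
      for (String tag : tags) {
        // One logical output per tag, recovered by filtering the shared pair RDD.
        JavaRDD<Integer> values = all.filter(t -> t._1().equals(tag)).values();
        System.out.println(tag + ": " + values.collect());
      }
    }
  }
}

Without the cache() call, each filter() would recompute the tagging pass over the input; with a single output there is only one pass to begin with, which is why the commit makes caching conditional on outputs.size() > 1.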
[Second modified file; path not captured]
@@ -43,7 +43,6 @@
 import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet;
 import org.apache.beam.runners.spark.translation.BoundedDataset;
 import org.apache.beam.runners.spark.translation.Dataset;
-import org.apache.beam.runners.spark.translation.DoFnFunction;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
 import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
@@ -368,8 +367,7 @@ public String toNativeString() {
     };
   }

-  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>
-      multiDo() {
+  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
     return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
       public void evaluate(
           final ParDo.MultiOutput<InputT, OutputT> transform, final EvaluationContext context) {
@@ -388,76 +386,42 @@ public void evaluate(

         final String stepName = context.getCurrentTransform().getFullName();
-        if (transform.getSideOutputTags().size() == 0) {
-          // Don't tag with the output and filter for a single-output ParDo, as it's additional
-          // identity transforms.
-          // Also see BEAM-1737 for failures when the two versions are condensed.
-          JavaDStream<WindowedValue<OutputT>> outStream =
-              dStream.transform(
-                  new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() {
+        JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
+            dStream.transformToPair(
+                new Function<
+                    JavaRDD<WindowedValue<InputT>>,
+                    JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
                   @Override
-                  public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
-                      throws Exception {
-                    final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
+                  public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
+                      JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
                     final Accumulator<NamedAggregators> aggAccum =
                         AggregatorsAccumulator.getInstance();
                     final Accumulator<SparkMetricsContainer> metricsAccum =
                         MetricsAccumulator.getInstance();
                     final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
                         sideInputs =
                             TranslationUtils.getSideInputs(
-                                transform.getSideInputs(), jsc, pviews);
-                    return rdd.mapPartitions(
-                        new DoFnFunction<>(
+                                transform.getSideInputs(),
+                                JavaSparkContext.fromSparkContext(rdd.context()),
+                                pviews);
+                    return rdd.mapPartitionsToPair(
+                        new MultiDoFnFunction<>(
                             aggAccum,
                             metricsAccum,
                             stepName,
                             doFn,
                             runtimeContext,
+                            transform.getMainOutputTag(),
                             sideInputs,
                             windowingStrategy));
                   }
                 });

-          PCollection<OutputT> output =
-              (PCollection<OutputT>)
-                  Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
-          context.putDataset(
-              output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
-        } else {
-          JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
-              dStream
-                  .transformToPair(
-                      new Function<
-                          JavaRDD<WindowedValue<InputT>>,
-                          JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
-                        @Override
-                        public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
-                            JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
-                          String stepName = context.getCurrentTransform().getFullName();
-                          final Accumulator<NamedAggregators> aggAccum =
-                              AggregatorsAccumulator.getInstance();
-                          final Accumulator<SparkMetricsContainer> metricsAccum =
-                              MetricsAccumulator.getInstance();
-                          final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
-                              sideInputs =
-                                  TranslationUtils.getSideInputs(
-                                      transform.getSideInputs(),
-                                      JavaSparkContext.fromSparkContext(rdd.context()),
-                                      pviews);
-                          return rdd.mapPartitionsToPair(
-                              new MultiDoFnFunction<>(
-                                  aggAccum,
-                                  metricsAccum,
-                                  stepName,
-                                  doFn,
-                                  runtimeContext,
-                                  transform.getMainOutputTag(),
-                                  sideInputs,
-                                  windowingStrategy));
-                        }
-                      })
-                  .cache();
-          for (TaggedPValue output : context.getOutputs(transform)) {
+        List<TaggedPValue> outputs = context.getOutputs(transform);
+        if (outputs.size() > 1) {
+          // cache the DStream if we're going to filter it more than once.
+          all.cache();
+        }
+        for (TaggedPValue output : outputs) {
           @SuppressWarnings("unchecked")
           JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
               all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
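The streaming translator now applies the same idiom once per micro-batch: transformToPair lifts the per-RDD tagging function onto the DStream, and the tagged pair DStream is cached only when several tag filters will consume it. A queue-backed sketch under the same assumptions (illustrative tags and local mode, not the runner's actual wiring):

import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.streaming.Durations;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import scala.Tuple2;

public class TagAndFilterStreaming {
  public static void main(String[] args) throws InterruptedException {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("tag-and-filter-dstream");
    JavaStreamingContext jssc = new JavaStreamingContext(conf, Durations.seconds(1));

    // A queue-backed stream stands in for the runner's unbounded source.
    Queue<JavaRDD<Integer>> queue = new LinkedList<>();
    queue.add(jssc.sparkContext().parallelize(Arrays.asList(1, 2, 3, 4, 5, 6)));
    JavaDStream<Integer> input = jssc.queueStream(queue);

    // transformToPair applies a per-batch RDD function, as the translator does:
    // each element is emitted once, keyed by its output tag.
    JavaPairDStream<String, Integer> all =
        input.transformToPair(
            rdd -> rdd.mapToPair(x -> new Tuple2<>(x % 2 == 0 ? "even" : "odd", x)));

    List<String> tags = Arrays.asList("even", "odd");
    if (tags.size() > 1) {
      // Avoid recomputing the tagged stream once per filter below.
      all.cache();
    }
    for (String tag : tags) {
      // One logical output stream per tag.
      all.filter(t -> t._1().equals(tag)).map(Tuple2::_2).print();
    }

    jssc.start();
    jssc.awaitTerminationOrTimeout(5000);
    jssc.stop();
  }
}

Here queueStream stands in for the runner's unbounded source; each batch interval pulls one RDD from the queue and runs the same tag-and-filter plan over it.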
@@ -525,7 +489,7 @@ public JavaRDD<WindowedValue<KV<K, V>>> call(
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());
     EVALUATORS.put(GroupByKey.class, groupByKey());
     EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
-    EVALUATORS.put(ParDo.MultiOutput.class, multiDo());
+    EVALUATORS.put(ParDo.MultiOutput.class, parDo());
     EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
     EVALUATORS.put(CreateStream.class, createFromQueue());
     EVALUATORS.put(Window.Assign.class, window());
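The final hunk renames the registered evaluator so the streaming side matches the batch translator's parDo(). For context, the EVALUATORS map is a class-keyed registry: translation walks the pipeline and looks up an evaluator by the concrete transform class. A simplified sketch of that pattern (hypothetical types; the real TransformEvaluator also receives an EvaluationContext):

import java.util.HashMap;
import java.util.Map;

public final class EvaluatorRegistry {
  // Simplified stand-in for the runner's TransformEvaluator interface.
  interface TransformEvaluator<TransformT> {
    void evaluate(TransformT transform);
  }

  // Class-keyed registry: the translator looks up the evaluator for the
  // concrete transform class it encounters while walking the pipeline.
  private static final Map<Class<?>, TransformEvaluator<?>> EVALUATORS = new HashMap<>();

  static <TransformT> void register(
      Class<TransformT> transformClass, TransformEvaluator<TransformT> evaluator) {
    EVALUATORS.put(transformClass, evaluator);
  }

  @SuppressWarnings("unchecked")
  static <TransformT> TransformEvaluator<TransformT> lookup(Class<TransformT> transformClass) {
    // The unchecked cast is safe as long as register() is the only writer.
    return (TransformEvaluator<TransformT>) EVALUATORS.get(transformClass);
  }

  public static void main(String[] args) {
    register(String.class, s -> System.out.println("evaluating: " + s));
    lookup(String.class).evaluate("ParDo.MultiOutput");
  }
}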
