diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
deleted file mode 100644
index 11761b6b161d2..0000000000000
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/DoFnFunction.java
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.beam.runners.spark.translation;
-
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import org.apache.beam.runners.core.DoFnRunner;
-import org.apache.beam.runners.core.DoFnRunners;
-import org.apache.beam.runners.spark.aggregators.NamedAggregators;
-import org.apache.beam.runners.spark.aggregators.SparkAggregators;
-import org.apache.beam.runners.spark.metrics.SparkMetricsContainer;
-import org.apache.beam.runners.spark.util.SideInputBroadcast;
-import org.apache.beam.runners.spark.util.SparkSideInputReader;
-import org.apache.beam.sdk.transforms.DoFn;
-import org.apache.beam.sdk.util.WindowedValue;
-import org.apache.beam.sdk.util.WindowingStrategy;
-import org.apache.beam.sdk.values.KV;
-import org.apache.beam.sdk.values.TupleTag;
-import org.apache.spark.Accumulator;
-import org.apache.spark.api.java.function.FlatMapFunction;
-
-
-/**
- * Beam's Do functions correspond to Spark's FlatMap functions.
- *
- * @param <InputT> Input element type.
- * @param <OutputT> Output element type.
- */
-public class DoFnFunction<InputT, OutputT>
-    implements FlatMapFunction<Iterator<WindowedValue<InputT>>, WindowedValue<OutputT>> {
-
-  private final Accumulator<NamedAggregators> aggregatorsAccum;
-  private final Accumulator<SparkMetricsContainer> metricsAccum;
-  private final String stepName;
-  private final DoFn<InputT, OutputT> doFn;
-  private final SparkRuntimeContext runtimeContext;
-  private final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs;
-  private final WindowingStrategy<?, ?> windowingStrategy;
-
-  /**
-   * @param aggregatorsAccum  The Spark {@link Accumulator} that backs the Beam Aggregators.
-   * @param doFn              The {@link DoFn} to be wrapped.
-   * @param runtimeContext    The {@link SparkRuntimeContext}.
-   * @param sideInputs        Side inputs used in this {@link DoFn}.
-   * @param windowingStrategy Input {@link WindowingStrategy}.
-   */
-  public DoFnFunction(
-      Accumulator<NamedAggregators> aggregatorsAccum,
-      Accumulator<SparkMetricsContainer> metricsAccum,
-      String stepName,
-      DoFn<InputT, OutputT> doFn,
-      SparkRuntimeContext runtimeContext,
-      Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs,
-      WindowingStrategy<?, ?> windowingStrategy) {
-    this.aggregatorsAccum = aggregatorsAccum;
-    this.metricsAccum = metricsAccum;
-    this.stepName = stepName;
-    this.doFn = doFn;
-    this.runtimeContext = runtimeContext;
-    this.sideInputs = sideInputs;
-    this.windowingStrategy = windowingStrategy;
-  }
-
-  @Override
-  public Iterable<WindowedValue<OutputT>> call(
-      Iterator<WindowedValue<InputT>> iter) throws Exception {
-    DoFnOutputManager outputManager = new DoFnOutputManager();
-
-    DoFnRunner<InputT, OutputT> doFnRunner =
-        DoFnRunners.simpleRunner(
-            runtimeContext.getPipelineOptions(),
-            doFn,
-            new SparkSideInputReader(sideInputs),
-            outputManager,
-            new TupleTag<OutputT>() {
-            },
-            Collections.<TupleTag<?>>emptyList(),
-            new SparkProcessContext.NoOpStepContext(),
-            new SparkAggregators.Factory(runtimeContext, aggregatorsAccum),
-            windowingStrategy);
-
-    DoFnRunner<InputT, OutputT> doFnRunnerWithMetrics =
-        new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum);
-
-    return new SparkProcessContext<>(doFn, doFnRunnerWithMetrics, outputManager)
-        .processPartition(iter);
-  }
-
-  private class DoFnOutputManager
-      implements SparkProcessContext.SparkOutputManager<WindowedValue<OutputT>> {
-
-    private final List<WindowedValue<OutputT>> outputs = new LinkedList<>();
-
-    @Override
-    public void clear() {
-      outputs.clear();
-    }
-
-    @Override
-    public Iterator<WindowedValue<OutputT>> iterator() {
-      return outputs.iterator();
-    }
-
-    @Override
-    @SuppressWarnings("unchecked")
-    public synchronized void output(TupleTag<?> tag, WindowedValue<?> output) {
-      outputs.add((WindowedValue<OutputT>) output);
-    }
-  }
-
-}
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
index 6290bbaf85ea0..7894c4eac1da2 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java
@@ -24,7 +24,6 @@
 import static org.apache.beam.runners.spark.translation.TranslationUtils.rejectStateAndTimers;
 
 import com.google.common.base.Optional;
-import com.google.common.collect.Iterables;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import java.util.Collections;
@@ -334,8 +333,7 @@ public String toNativeString() {
     };
   }
 
-  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>
-      parDo() {
+  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
     return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
       @Override
       public void evaluate(
@@ -351,51 +349,31 @@ public void evaluate(
               context.getInput(transform).getWindowingStrategy();
       Accumulator<NamedAggregators> aggAccum = AggregatorsAccumulator.getInstance();
       Accumulator<SparkMetricsContainer> metricsAccum = MetricsAccumulator.getInstance();
-      Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>> sideInputs =
-          TranslationUtils.getSideInputs(transform.getSideInputs(), context);
-      if (transform.getSideOutputTags().size() == 0) {
-        // Don't tag with the output and filter for a single-output ParDo, as it's additional
-        // identity transforms.
-        // Also see BEAM-1737 for failures when the two versions are condensed.
-        PCollection<OutputT> output =
-            (PCollection<OutputT>)
-                Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
-        context.putDataset(
-            output,
-            new BoundedDataset<>(
-                inRDD.mapPartitions(
-                    new DoFnFunction<>(
-                        aggAccum,
-                        metricsAccum,
-                        stepName,
-                        doFn,
-                        context.getRuntimeContext(),
-                        sideInputs,
-                        windowingStrategy))));
-      } else {
-        JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
-            inRDD
-                .mapPartitionsToPair(
-                    new MultiDoFnFunction<>(
-                        aggAccum,
-                        metricsAccum,
-                        stepName,
-                        doFn,
-                        context.getRuntimeContext(),
-                        transform.getMainOutputTag(),
-                        TranslationUtils.getSideInputs(transform.getSideInputs(), context),
-                        windowingStrategy))
-                .cache();
-        for (TaggedPValue output : context.getOutputs(transform)) {
-          @SuppressWarnings("unchecked")
-          JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
-              all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
-          @SuppressWarnings("unchecked")
-          // Object is the best we can do since different outputs can have different tags
-          JavaRDD<WindowedValue<Object>> values =
-              (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
-          context.putDataset(output.getValue(), new BoundedDataset<>(values));
-        }
+      JavaPairRDD<TupleTag<?>, WindowedValue<?>> all =
+          inRDD.mapPartitionsToPair(
+              new MultiDoFnFunction<>(
+                  aggAccum,
+                  metricsAccum,
+                  stepName,
+                  doFn,
+                  context.getRuntimeContext(),
+                  transform.getMainOutputTag(),
+                  TranslationUtils.getSideInputs(transform.getSideInputs(), context),
+                  windowingStrategy));
+      List<TaggedPValue> outputs = context.getOutputs(transform);
+      if (outputs.size() > 1) {
+        // cache the RDD if we're going to filter it more than once.
+        all.cache();
+      }
+      for (TaggedPValue output : outputs) {
+        @SuppressWarnings("unchecked")
+        JavaPairRDD<TupleTag<?>, WindowedValue<?>> filtered =
+            all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
+        @SuppressWarnings("unchecked")
+        // Object is the best we can do since different outputs can have different tags
+        JavaRDD<WindowedValue<Object>> values =
+            (JavaRDD<WindowedValue<Object>>) (JavaRDD<?>) filtered.values();
+        context.putDataset(output.getValue(), new BoundedDataset<>(values));
       }
     }
diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
index 2d2854f0635b9..d4c6c9de3bea4 100644
--- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
+++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java
@@ -43,7 +43,6 @@
 import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet;
 import org.apache.beam.runners.spark.translation.BoundedDataset;
 import org.apache.beam.runners.spark.translation.Dataset;
-import org.apache.beam.runners.spark.translation.DoFnFunction;
 import org.apache.beam.runners.spark.translation.EvaluationContext;
 import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
 import org.apache.beam.runners.spark.translation.MultiDoFnFunction;
@@ -368,8 +367,7 @@ public String toNativeString() {
     };
   }
 
-  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>
-      multiDo() {
+  private static <InputT, OutputT> TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>> parDo() {
    return new TransformEvaluator<ParDo.MultiOutput<InputT, OutputT>>() {
      public void evaluate(
          final ParDo.MultiOutput<InputT, OutputT> transform, final EvaluationContext context) {
@@ -388,16 +386,14 @@ public void evaluate(
         final String stepName = context.getCurrentTransform().getFullName();
 
-        if (transform.getSideOutputTags().size() == 0) {
-          // Don't tag with the output and filter for a single-output ParDo, as it's additional
-          // identity transforms.
-          // Also see BEAM-1737 for failures when the two versions are condensed.
-          JavaDStream<WindowedValue<OutputT>> outStream =
-              dStream.transform(
-                  new Function<JavaRDD<WindowedValue<InputT>>, JavaRDD<WindowedValue<OutputT>>>() {
+        JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
+            dStream.transformToPair(
+                new Function<
+                    JavaRDD<WindowedValue<InputT>>,
+                    JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
                   @Override
-                  public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
-                      throws Exception {
-                    final JavaSparkContext jsc = new JavaSparkContext(rdd.context());
+                  public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
+                      JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
                     final Accumulator<NamedAggregators> aggAccum =
                         AggregatorsAccumulator.getInstance();
                     final Accumulator<SparkMetricsContainer> metricsAccum =
@@ -405,59 +401,27 @@ public JavaRDD<WindowedValue<OutputT>> call(JavaRDD<WindowedValue<InputT>> rdd)
                     final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
                         sideInputs =
                             TranslationUtils.getSideInputs(
-                                transform.getSideInputs(), jsc, pviews);
-                    return rdd.mapPartitions(
-                        new DoFnFunction<>(
+                                transform.getSideInputs(),
+                                JavaSparkContext.fromSparkContext(rdd.context()),
+                                pviews);
+                    return rdd.mapPartitionsToPair(
+                        new MultiDoFnFunction<>(
                             aggAccum,
                             metricsAccum,
                             stepName,
                             doFn,
                             runtimeContext,
+                            transform.getMainOutputTag(),
                             sideInputs,
                             windowingStrategy));
                   }
                 });
-
-          PCollection<OutputT> output =
-              (PCollection<OutputT>)
-                  Iterables.getOnlyElement(context.getOutputs(transform)).getValue();
-          context.putDataset(
-              output, new UnboundedDataset<>(outStream, unboundedDataset.getStreamSources()));
-        } else {
-          JavaPairDStream<TupleTag<?>, WindowedValue<?>> all =
-              dStream
-                  .transformToPair(
-                      new Function<
-                          JavaRDD<WindowedValue<InputT>>,
-                          JavaPairRDD<TupleTag<?>, WindowedValue<?>>>() {
-                        @Override
-                        public JavaPairRDD<TupleTag<?>, WindowedValue<?>> call(
-                            JavaRDD<WindowedValue<InputT>> rdd) throws Exception {
-                          String stepName = context.getCurrentTransform().getFullName();
-                          final Accumulator<NamedAggregators> aggAccum =
-                              AggregatorsAccumulator.getInstance();
-                          final Accumulator<SparkMetricsContainer> metricsAccum =
-                              MetricsAccumulator.getInstance();
-                          final Map<TupleTag<?>, KV<WindowingStrategy<?, ?>, SideInputBroadcast<?>>>
-                              sideInputs =
-                                  TranslationUtils.getSideInputs(
-                                      transform.getSideInputs(),
-                                      JavaSparkContext.fromSparkContext(rdd.context()),
-                                      pviews);
-                          return rdd.mapPartitionsToPair(
-                              new MultiDoFnFunction<>(
-                                  aggAccum,
-                                  metricsAccum,
-                                  stepName,
-                                  doFn,
-                                  runtimeContext,
-                                  transform.getMainOutputTag(),
-                                  sideInputs,
-                                  windowingStrategy));
-                        }
-                      })
-                  .cache();
-          for (TaggedPValue output : context.getOutputs(transform)) {
+        List<TaggedPValue> outputs = context.getOutputs(transform);
+        if (outputs.size() > 1) {
+          // cache the DStream if we're going to filter it more than once.
+          all.cache();
+        }
+        for (TaggedPValue output : outputs) {
           @SuppressWarnings("unchecked")
           JavaPairDStream<TupleTag<?>, WindowedValue<?>> filtered =
               all.filter(new TranslationUtils.TupleTagFilter(output.getTag()));
@@ -525,7 +489,7 @@ public JavaRDD<WindowedValue<KV<K, Iterable<InputT>>>> call(
     EVALUATORS.put(Read.Unbounded.class, readUnbounded());
     EVALUATORS.put(GroupByKey.class, groupByKey());
     EVALUATORS.put(Combine.GroupedValues.class, combineGrouped());
-    EVALUATORS.put(ParDo.MultiOutput.class, multiDo());
+    EVALUATORS.put(ParDo.MultiOutput.class, parDo());
     EVALUATORS.put(ConsoleIO.Write.Unbound.class, print());
     EVALUATORS.put(CreateStream.class, createFromQueue());
     EVALUATORS.put(Window.Assign.class, window());
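
Note: both translators now route every ParDo.MultiOutput through MultiDoFnFunction, which tags each emitted element with its TupleTag; the per-output datasets are then recovered by filtering on that tag, and the tagged RDD/DStream is cached only when there is more than one output to filter. The following is a minimal, self-contained sketch of that tag-then-filter pattern using plain Spark Java types rather than Beam's internals; the class name TagFilterSketch and the string tags "even"/"odd" are illustrative stand-ins, not code from this patch.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

import scala.Tuple2;

public class TagFilterSketch {
  public static void main(String[] args) {
    JavaSparkContext jsc = new JavaSparkContext("local[2]", "tag-filter-sketch");

    // Stand-in for the ParDo translation: every element is emitted exactly once
    // as a (tag, value) pair, regardless of how many outputs the transform has.
    JavaRDD<Integer> input = jsc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
    JavaPairRDD<String, Integer> all =
        input.mapToPair(x -> new Tuple2<>(x % 2 == 0 ? "even" : "odd", x));

    // The output tags this "transform" declares; with a single tag, the filter
    // below runs only once, so caching would just add serialization overhead.
    List<String> outputTags = Arrays.asList("even", "odd");
    if (outputTags.size() > 1) {
      // cache only when the pair RDD will be filtered more than once,
      // mirroring the `outputs.size() > 1` check in the patch.
      all.cache();
    }

    // One filter pass per declared output, each yielding that output's dataset.
    for (String tag : outputTags) {
      JavaRDD<Integer> values = all.filter(t -> t._1().equals(tag)).values();
      System.out.println(tag + ": " + values.collect());
    }

    jsc.stop();
  }
}

Under these assumptions, the design trade-off is the one the patch comments name: unifying single- and multi-output ParDo on one code path costs a tag on every element, but the conditional cache keeps the single-output case from paying for a second traversal it never performs.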