From 64262a61402fad67d9ad8a66eaf6322593d3b5dc Mon Sep 17 00:00:00 2001
From: Salman Raza <50444219+salmanVD@users.noreply.github.com>
Date: Sat, 24 Aug 2019 22:06:33 +0500
Subject: [PATCH] Merge pull request #9275: [BEAM-6858] Support side inputs injected into a DoFn

---
 .../apex/translation/ParDoTranslator.java | 33 ++++---
 .../operators/ApexParDoOperator.java | 26 +++--
 .../apex/translation/ParDoTranslatorTest.java | 1 +
 .../core/construction/ParDoTranslation.java | 39 +++++++-
 .../core/construction/SplittableParDo.java | 7 +-
 .../SplittableParDoNaiveBounded.java | 5 +
 .../construction/ParDoTranslationTest.java | 4 +-
 .../apache/beam/runners/core/DoFnRunners.java | 12 ++-
 ...oundedSplittableProcessElementInvoker.java | 5 +
 .../beam/runners/core/SimpleDoFnRunner.java | 28 +++++-
 .../runners/core/SimpleDoFnRunnerTest.java | 27 +++--
 .../runners/core/StatefulDoFnRunnerTest.java | 3 +-
 .../runners/direct/DirectGraphVisitor.java | 3 +-
 .../beam/runners/direct/ParDoEvaluator.java | 13 ++-
 .../runners/direct/ParDoEvaluatorFactory.java | 9 +-
 .../direct/ParDoMultiOverrideFactory.java | 24 ++++-
 ...ttableProcessElementsEvaluatorFactory.java | 10 +-
 .../direct/StatefulParDoEvaluatorFactory.java | 3 +-
 .../runners/direct/ParDoEvaluatorTest.java | 2 +
 .../StatefulParDoEvaluatorFactoryTest.java | 6 +-
 .../flink/FlinkBatchTransformTranslators.java | 8 +-
 .../FlinkStreamingTransformTranslators.java | 24 +++--
 .../functions/FlinkDoFnFunction.java | 8 +-
 .../functions/FlinkStatefulDoFnFunction.java | 8 +-
 .../wrappers/streaming/DoFnOperator.java | 9 +-
 .../ExecutableStageDoFnOperator.java | 3 +-
 .../streaming/SplittableDoFnOperator.java | 3 +-
 .../streaming/WindowDoFnOperator.java | 3 +-
 .../flink/FlinkPipelineOptionsTest.java | 6 +-
 .../wrappers/streaming/DoFnOperatorTest.java | 57 +++++++----
 .../ParDoMultiOutputTranslator.java | 8 +-
 .../translators/functions/DoFnFunction.java | 6 +-
 .../translators/utils/DoFnRunnerFactory.java | 8 +-
 .../dataflow/DataflowPipelineTranslator.java | 42 +++++---
 .../dataflow/PrimitiveParDoSingleFactory.java | 15 ++-
 .../PrimitiveParDoSingleFactoryTest.java | 2 +-
 .../worker/CombineValuesFnFactory.java | 13 ++-
 .../dataflow/worker/DoFnRunnerFactory.java | 3 +-
 .../worker/SimpleDoFnRunnerFactory.java | 6 +-
 .../dataflow/worker/SimpleParDoFn.java | 7 +-
 .../worker/SplittableProcessFnFactory.java | 9 +-
 .../dataflow/worker/UserParDoFnFactory.java | 2 +
 .../worker/DefaultParDoFnFactoryTest.java | 4 +-
 .../worker/DoFnInstanceManagersTest.java | 16 ++-
 .../IntrinsicMapTaskExecutorFactoryTest.java | 4 +-
 .../dataflow/worker/SimpleParDoFnTest.java | 24 +++--
 .../worker/StreamingDataflowWorkerTest.java | 3 +-
 .../StreamingSideInputDoFnRunnerTest.java | 3 +-
 .../worker/UserParDoFnFactoryTest.java | 3 +-
 .../org/apache/beam/runners/jet/Utils.java | 4 +-
 .../jet/processors/AbstractParDoP.java | 7 +-
 .../beam/runners/jet/processors/ParDoP.java | 6 +-
 .../jet/processors/StatefulParDoP.java | 6 +-
 .../beam/runners/samza/runtime/DoFnOp.java | 8 +-
 .../runners/samza/runtime/GroupByKeyOp.java | 3 +-
 .../samza/runtime/SamzaDoFnRunners.java | 7 +-
 .../ParDoBoundMultiTranslator.java | 18 +++-
 .../spark/translation/MultiDoFnFunction.java | 9 +-
 .../translation/TransformTranslator.java | 8 +-
 .../spark/translation/TranslationUtils.java | 5 +-
 .../StreamingTransformTranslator.java | 9 +-
 .../org/apache/beam/sdk/transforms/DoFn.java | 8 ++
 .../beam/sdk/transforms/DoFnTester.java | 5 +
 .../org/apache/beam/sdk/transforms/ParDo.java | 82 +++++++++++-----
 .../reflect/ByteBuddyDoFnInvokerFactory.java
| 12 +++ .../sdk/transforms/reflect/DoFnInvoker.java | 11 +++ .../sdk/transforms/reflect/DoFnSignature.java | 39 ++++++++ .../transforms/reflect/DoFnSignatures.java | 23 ++++- .../org/apache/beam/sdk/util/DoFnInfo.java | 24 +++-- .../util/DoFnWithExecutionInformation.java | 12 ++- .../beam/sdk/values/PCollectionView.java | 1 - .../beam/sdk/values/PCollectionViews.java | 1 - .../apache/beam/sdk/transforms/ParDoTest.java | 98 +++++++++++++++++-- .../reflect/DoFnSignaturesTest.java | 9 +- .../beam/fn/harness/FnApiDoFnRunner.java | 10 ++ 75 files changed, 769 insertions(+), 223 deletions(-) diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java index 871c83c5bdb67..cf5d11bdf4340 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/ParDoTranslator.java @@ -22,6 +22,7 @@ import com.datatorrent.api.Operator; import com.datatorrent.api.Operator.OutputPort; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -41,6 +42,7 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -76,11 +78,14 @@ public void translate(ParDo.MultiOutput transform, TranslationC Map, PValue> outputs = context.getOutputs(); PCollection input = context.getInput(); - List> sideInputs = transform.getSideInputs(); + Iterable> sideInputs = transform.getSideInputs().values(); DoFnSchemaInformation doFnSchemaInformation; doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + Map, Coder> outputCoders = outputs.entrySet().stream() .filter(e -> e.getValue() instanceof PCollection) @@ -97,6 +102,7 @@ public void translate(ParDo.MultiOutput transform, TranslationC input.getCoder(), outputCoders, doFnSchemaInformation, + sideInputMapping, context.getStateBackend()); Map, OutputPort> ports = Maps.newHashMapWithExpectedSize(outputs.size()); @@ -124,7 +130,7 @@ public void translate(ParDo.MultiOutput transform, TranslationC } context.addOperator(operator, ports); context.addStream(context.getInput(), operator.input); - if (!sideInputs.isEmpty()) { + if (!Iterables.isEmpty(sideInputs)) { addSideInputs(operator.sideInput1, sideInputs, context); } } @@ -139,7 +145,7 @@ public void translate( Map, PValue> outputs = context.getOutputs(); PCollection input = context.getInput(); - List> sideInputs = transform.getSideInputs(); + Iterable> sideInputs = transform.getSideInputs(); Map, Coder> outputCoders = outputs.entrySet().stream() @@ -160,6 +166,7 @@ public void translate( input.getCoder(), outputCoders, DoFnSchemaInformation.create(), + Collections.emptyMap(), context.getStateBackend()); Map, OutputPort> ports = Maps.newHashMapWithExpectedSize(outputs.size()); @@ -188,7 +195,7 @@ public void translate( context.addOperator(operator, ports); context.addStream(context.getInput(), operator.input); - if (!sideInputs.isEmpty()) { + if (!Iterables.isEmpty(sideInputs)) { 
addSideInputs(operator.sideInput1, sideInputs, context); } } @@ -196,29 +203,29 @@ public void translate( static void addSideInputs( Operator.InputPort sideInputPort, - List> sideInputs, + Iterable> sideInputs, TranslationContext context) { Operator.InputPort[] sideInputPorts = {sideInputPort}; - if (sideInputs.size() > sideInputPorts.length) { + if (Iterables.size(sideInputs) > sideInputPorts.length) { PCollection unionCollection = unionSideInputs(sideInputs, context); context.addStream(unionCollection, sideInputPorts[0]); } else { // the number of ports for side inputs is fixed and each port can only take one input. - for (int i = 0; i < sideInputs.size(); i++) { - context.addStream(context.getViewInput(sideInputs.get(i)), sideInputPorts[i]); + for (int i = 0; i < Iterables.size(sideInputs); i++) { + context.addStream(context.getViewInput(Iterables.get(sideInputs, i)), sideInputPorts[i]); } } } private static PCollection unionSideInputs( - List> sideInputs, TranslationContext context) { - checkArgument(sideInputs.size() > 1, "requires multiple side inputs"); + Iterable> sideInputs, TranslationContext context) { + checkArgument(Iterables.size(sideInputs) > 1, "requires multiple side inputs"); // flatten and assign union tag List> sourceCollections = new ArrayList<>(); Map, Integer> unionTags = new HashMap<>(); - PCollection firstSideInput = context.getViewInput(sideInputs.get(0)); - for (int i = 0; i < sideInputs.size(); i++) { - PCollectionView sideInput = sideInputs.get(i); + PCollection firstSideInput = context.getViewInput(Iterables.get(sideInputs, 0)); + for (int i = 0; i < Iterables.size(sideInputs); i++) { + PCollectionView sideInput = Iterables.get(sideInputs, i); PCollection sideInputCollection = context.getViewInput(sideInput); if (!sideInputCollection .getWindowingStrategy() diff --git a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java index a490b8a0dbe8c..9d4b110e9738f 100644 --- a/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java +++ b/runners/apex/src/main/java/org/apache/beam/runners/apex/translation/operators/ApexParDoOperator.java @@ -87,6 +87,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps; import org.joda.time.Duration; import org.joda.time.Instant; @@ -115,7 +116,7 @@ public class ApexParDoOperator extends BaseOperator private final WindowingStrategy windowingStrategy; @Bind(JavaSerializer.class) - private final List> sideInputs; + private final Iterable> sideInputs; @Bind(JavaSerializer.class) private final Coder> windowedInputCoder; @@ -129,6 +130,9 @@ public class ApexParDoOperator extends BaseOperator @Bind(JavaSerializer.class) private final DoFnSchemaInformation doFnSchemaInformation; + @Bind(JavaSerializer.class) + private final Map> sideInputMapping; + private StateInternalsProxy currentKeyStateInternals; private final ApexTimerInternals currentKeyTimerInternals; @@ -152,10 +156,11 @@ public ApexParDoOperator( TupleTag mainOutputTag, List> additionalOutputTags, WindowingStrategy windowingStrategy, - List> sideInputs, + Iterable> sideInputs, Coder inputCoder, Map, Coder> 
outputCoders, DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping, ApexStateBackend stateBackend) { this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions); this.doFn = doFn; @@ -186,6 +191,7 @@ public ApexParDoOperator( TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder()); this.currentKeyTimerInternals = new ApexTimerInternals<>(timerCoder); this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; if (doFn instanceof ProcessFn) { // we know that it is keyed on byte[] @@ -219,6 +225,7 @@ private ApexParDoOperator() { this.outputCoders = Collections.emptyMap(); this.currentKeyTimerInternals = null; this.doFnSchemaInformation = null; + this.sideInputMapping = null; } public final transient DefaultInputPort>> input = @@ -260,7 +267,7 @@ public void process(ApexStreamTuple>> t) { LOG.debug("\nsideInput {} {}\n", sideInputIndex, t.getValue()); } - PCollectionView sideInput = sideInputs.get(sideInputIndex); + PCollectionView sideInput = Iterables.get(sideInputs, sideInputIndex); sideInputHandler.addSideInputValue(sideInput, t.getValue()); List> newPushedBack = new ArrayList<>(); @@ -408,7 +415,7 @@ private void processWatermark(ApexStreamTuple.WatermarkTuple mark) { currentInputWatermark); } } - if (sideInputs.isEmpty()) { + if (Iterables.isEmpty(sideInputs)) { outputWatermark(mark); return; } @@ -439,8 +446,9 @@ public void setup(OperatorContext context) { ApexStreamTuple.Logging.isDebugEnabled( pipelineOptions.get().as(ApexPipelineOptions.class), this); SideInputReader sideInputReader = NullSideInputReader.of(sideInputs); - if (!sideInputs.isEmpty()) { - sideInputHandler = new SideInputHandler(sideInputs, sideInputStateInternals); + if (!Iterables.isEmpty(sideInputs)) { + sideInputHandler = + new SideInputHandler(Lists.newArrayList(sideInputs), sideInputStateInternals); sideInputReader = sideInputHandler; } @@ -476,7 +484,8 @@ public TimerInternals timerInternals() { inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); doFnInvoker = DoFnInvokers.invokerFor(doFn); doFnInvoker.invokeSetup(); @@ -501,7 +510,8 @@ public TimerInternals timerInternals() { } pushbackDoFnRunner = - SimplePushbackSideInputDoFnRunner.create(doFnRunner, sideInputs, sideInputHandler); + SimplePushbackSideInputDoFnRunner.create( + doFnRunner, Lists.newArrayList(sideInputs), sideInputHandler); if (doFn instanceof ProcessFn) { diff --git a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java index 87c5acec64f79..cda57aafa8321 100644 --- a/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java +++ b/runners/apex/src/test/java/org/apache/beam/runners/apex/translation/ParDoTranslatorTest.java @@ -214,6 +214,7 @@ public void testSerialization() throws Exception { VarIntCoder.of(), Collections.emptyMap(), DoFnSchemaInformation.create(), + Collections.emptyMap(), new ApexStateInternals.ApexStateBackend()); operator.setup(null); operator.beginWindow(0); diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java index 067a651e0b509..600538d94587f 100644 --- 
a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ParDoTranslation.java @@ -172,7 +172,7 @@ public static ParDoPayload translateParDo( .map(TupleTag::getId) .collect(Collectors.toSet()); Set sideInputs = - parDo.getSideInputs().stream() + parDo.getSideInputs().values().stream() .map(s -> s.getTagInternal().getId()) .collect(Collectors.toSet()); Set timerInputs = signature.timerDeclarations().keySet(); @@ -209,7 +209,11 @@ public static ParDoPayload translateParDo( @Override public SdkFunctionSpec translateDoFn(SdkComponents newComponents) { return ParDoTranslation.translateDoFn( - parDo.getFn(), parDo.getMainOutputTag(), doFnSchemaInformation, newComponents); + parDo.getFn(), + parDo.getMainOutputTag(), + parDo.getSideInputs(), + doFnSchemaInformation, + newComponents); } @Override @@ -221,7 +225,7 @@ public List translateParameters() { @Override public Map translateSideInputs(SdkComponents components) { Map sideInputs = new HashMap<>(); - for (PCollectionView sideInput : parDo.getSideInputs()) { + for (PCollectionView sideInput : parDo.getSideInputs().values()) { sideInputs.put( sideInput.getTagInternal().getId(), translateView(sideInput, components)); } @@ -315,6 +319,28 @@ public static TupleTag getMainOutputTag(ParDoPayload payload) return doFnWithExecutionInformationFromProto(payload.getDoFn()).getMainOutputTag(); } + public static Map> getSideInputMapping( + AppliedPTransform application) { + try { + return getSideInputMapping(getParDoPayload(application)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static Map> getSideInputMapping( + RunnerApi.PTransform pTransform) { + try { + return getSideInputMapping(getParDoPayload(pTransform)); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static Map> getSideInputMapping(ParDoPayload payload) { + return doFnWithExecutionInformationFromProto(payload.getDoFn()).getSideInputMapping(); + } + public static TupleTag getMainOutputTag(AppliedPTransform application) throws IOException { PTransform transform = application.getTransform(); @@ -359,7 +385,8 @@ public static List> getSideInputs(AppliedPTransform throws IOException { PTransform transform = application.getTransform(); if (transform instanceof ParDo.MultiOutput) { - return ((ParDo.MultiOutput) transform).getSideInputs(); + return ((ParDo.MultiOutput) transform) + .getSideInputs().values().stream().collect(Collectors.toList()); } SdkComponents sdkComponents = SdkComponents.create(application.getPipeline().getOptions()); @@ -551,6 +578,7 @@ private static RunnerApi.TimeDomain.Enum translateTimeDomain(TimeDomain timeDoma public static SdkFunctionSpec translateDoFn( DoFn fn, TupleTag tag, + Map> sideInputMapping, DoFnSchemaInformation doFnSchemaInformation, SdkComponents components) { return SdkFunctionSpec.newBuilder() @@ -561,7 +589,8 @@ public static SdkFunctionSpec translateDoFn( .setPayload( ByteString.copyFrom( SerializableUtils.serializeToByteArray( - DoFnWithExecutionInformation.of(fn, tag, doFnSchemaInformation)))) + DoFnWithExecutionInformation.of( + fn, tag, sideInputMapping, doFnSchemaInformation)))) .build()) .build(); } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java index 
63e89cd21b96d..237318817acd8 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDo.java @@ -21,6 +21,7 @@ import com.google.auto.service.AutoService; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.ThreadLocalRandom; @@ -368,7 +369,11 @@ public FunctionSpec translate( public SdkFunctionSpec translateDoFn(SdkComponents newComponents) { // Schemas not yet supported on splittable DoFn. return ParDoTranslation.translateDoFn( - fn, pke.getMainOutputTag(), DoFnSchemaInformation.create(), newComponents); + fn, + pke.getMainOutputTag(), + Collections.emptyMap(), + DoFnSchemaInformation.create(), + newComponents); } @Override diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java index bce429aa1a52b..340889d9fd727 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/SplittableParDoNaiveBounded.java @@ -237,6 +237,11 @@ public InputT element(DoFn doFn) { return element; } + @Override + public Object sideInput(String tagId) { + throw new UnsupportedOperationException(); + } + @Override public Object schemaElement(int index) { throw new UnsupportedOperationException(); diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java index 6ea3bde27b2ca..caff55aad827c 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ParDoTranslationTest.java @@ -122,7 +122,7 @@ public void testToProto() throws Exception { assertThat(ParDoTranslation.getDoFn(payload), equalTo(parDo.getFn())); assertThat(ParDoTranslation.getMainOutputTag(payload), equalTo(parDo.getMainOutputTag())); - for (PCollectionView view : parDo.getSideInputs()) { + for (PCollectionView view : parDo.getSideInputs().values()) { payload.getSideInputsOrThrow(view.getTagInternal().getId()); } } @@ -148,7 +148,7 @@ public void toTransformProto() throws Exception { // Decode ParDoPayload parDoPayload = ParDoPayload.parseFrom(protoTransform.getSpec().getPayload()); - for (PCollectionView view : parDo.getSideInputs()) { + for (PCollectionView view : parDo.getSideInputs().values()) { SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId()); PCollectionView restoredView = PCollectionViewTranslation.viewFromProto( diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java index 3d929d7afc6a0..6496561a8d2d3 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java @@ -61,7 +61,8 @@ public static DoFnRunner simpleRunner( @Nullable Coder 
inputCoder, Map, Coder> outputCoders, WindowingStrategy windowingStrategy, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { return new SimpleDoFnRunner<>( options, fn, @@ -73,7 +74,8 @@ public static DoFnRunner simpleRunner( inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); } /** @@ -118,7 +120,8 @@ ProcessFnRunner newProcessFnRunner( @Nullable Coder>> inputCoder, Map, Coder> outputCoders, WindowingStrategy windowingStrategy, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { return new ProcessFnRunner<>( simpleRunner( options, @@ -131,7 +134,8 @@ ProcessFnRunner newProcessFnRunner( inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation), + doFnSchemaInformation, + sideInputMapping), views, sideInputReader); } diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java index 6f0a63d926541..67982e10ef4cb 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/OutputAndTimeBoundedSplittableProcessElementInvoker.java @@ -120,6 +120,11 @@ public InputT element(DoFn doFn) { return processContext.element(); } + @Override + public Object sideInput(String tagId) { + throw new UnsupportedOperationException("Not supported in SplittableDoFn"); + } + @Override public Object schemaElement(int index) { throw new UnsupportedOperationException("Not supported in SplittableDoFn"); diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java index 27d0cc79d9725..2b105fb4220f8 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/SimpleDoFnRunner.java @@ -106,6 +106,8 @@ public class SimpleDoFnRunner implements DoFnRunner> sideInputMapping; + /** Constructor. 
*/ public SimpleDoFnRunner( PipelineOptions options, @@ -118,7 +120,8 @@ public SimpleDoFnRunner( @Nullable Coder inputCoder, Map, Coder> outputCoders, WindowingStrategy windowingStrategy, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.options = options; this.fn = fn; this.signature = DoFnSignatures.getSignature(fn.getClass()); @@ -151,6 +154,7 @@ public SimpleDoFnRunner( this.windowCoder = untypedCoder; this.allowedLateness = windowingStrategy.getAllowedLateness(); this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } @Override @@ -301,6 +305,12 @@ public InputT element(DoFn doFn) { "Element parameters are not supported outside of @ProcessElement method."); } + @Override + public InputT sideInput(String tagId) { + throw new UnsupportedOperationException( + "SideInput parameters are not supported outside of @ProcessElement method."); + } + @Override public Object schemaElement(int index) { throw new UnsupportedOperationException( @@ -415,6 +425,12 @@ public InputT element(DoFn doFn) { "Cannot access element outside of @ProcessElement method."); } + @Override + public InputT sideInput(String tagId) { + throw new UnsupportedOperationException( + "Cannot access sideInput outside of @ProcessElement method."); + } + @Override public Object schemaElement(int index) { throw new UnsupportedOperationException( @@ -631,6 +647,11 @@ public InputT element(DoFn doFn) { return element(); } + @Override + public Object sideInput(String tagId) { + return sideInput(sideInputMapping.get(tagId)); + } + @Override public Object schemaElement(int index) { SerializableFunction converter = doFnSchemaInformation.getElementConverters().get(index); @@ -781,6 +802,11 @@ public InputT element(DoFn doFn) { throw new UnsupportedOperationException("Element parameters are not supported."); } + @Override + public Object sideInput(String tagId) { + throw new UnsupportedOperationException("SideInput parameters are not supported."); + } + @Override public Object schemaElement(int index) { throw new UnsupportedOperationException("Element parameters are not supported."); diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/SimpleDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/SimpleDoFnRunnerTest.java index c50ddfcb285ed..b790314fd27c1 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/SimpleDoFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/SimpleDoFnRunnerTest.java @@ -88,7 +88,8 @@ public void testProcessElementExceptionsWrappedAsUserCodeException() { null, Collections.emptyMap(), WindowingStrategy.of(new GlobalWindows()), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); thrown.expect(UserCodeException.class); thrown.expectCause(is(fn.exceptionToThrow)); @@ -111,7 +112,8 @@ public void testOnTimerExceptionsWrappedAsUserCodeException() { null, Collections.emptyMap(), WindowingStrategy.of(new GlobalWindows()), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); thrown.expect(UserCodeException.class); thrown.expectCause(is(fn.exceptionToThrow)); @@ -141,7 +143,8 @@ public void testTimerSet() { null, Collections.emptyMap(), WindowingStrategy.of(new GlobalWindows()), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); // Setting the timer needs the current time, 
as it is set relative Instant currentTime = new Instant(42); @@ -172,7 +175,8 @@ public void testStartBundleExceptionsWrappedAsUserCodeException() { null, Collections.emptyMap(), WindowingStrategy.of(new GlobalWindows()), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); thrown.expect(UserCodeException.class); thrown.expectCause(is(fn.exceptionToThrow)); @@ -195,7 +199,8 @@ public void testFinishBundleExceptionsWrappedAsUserCodeException() { null, Collections.emptyMap(), WindowingStrategy.of(new GlobalWindows()), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); thrown.expect(UserCodeException.class); thrown.expectCause(is(fn.exceptionToThrow)); @@ -222,7 +227,8 @@ public void testOnTimerCalled() { null, Collections.emptyMap(), WindowingStrategy.of(windowFn), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); Instant currentTime = new Instant(42); Duration offset = Duration.millis(37); @@ -264,7 +270,8 @@ public void testBackwardsInTimeNoSkew() { null, Collections.emptyMap(), WindowingStrategy.of(new GlobalWindows()), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); runner.startBundle(); // An element output at the current timestamp is fine. @@ -303,7 +310,8 @@ public void testSkew() { null, Collections.emptyMap(), WindowingStrategy.of(new GlobalWindows()), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); runner.startBundle(); // Outputting between "now" and "now - allowed skew" succeeds. @@ -343,7 +351,8 @@ public void testInfiniteSkew() { null, Collections.emptyMap(), WindowingStrategy.of(new GlobalWindows()), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); runner.startBundle(); runner.processElement( diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java index 06bbba26e4322..85b3c0b6cec23 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StatefulDoFnRunnerTest.java @@ -207,7 +207,8 @@ private DoFnRunner, Integer> getDoFnRunner( null, Collections.emptyMap(), WINDOWING_STRATEGY, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); } private static void advanceInputWatermark( diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java index d48c2410888a7..df4bc55139716 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/DirectGraphVisitor.java @@ -117,7 +117,8 @@ public void visitPrimitiveTransform(TransformHierarchy.Node node) { } } if (node.getTransform() instanceof ParDo.MultiOutput) { - consumedViews.addAll(((ParDo.MultiOutput) node.getTransform()).getSideInputs()); + consumedViews.addAll( + ((ParDo.MultiOutput) node.getTransform()).getSideInputs().values()); } else if (node.getTransform() instanceof ViewOverrideFactory.WriteView) { viewWriters.put( ((WriteView) node.getTransform()).getView(), node.toAppliedPTransform(getPipeline())); diff --git 
a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java index 19e90030b8182..31eb80b43a0fd 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluator.java @@ -62,7 +62,8 @@ PushbackSideInputDoFnRunner createRunner( @Nullable Coder inputCoder, Map, Coder> outputCoders, WindowingStrategy windowingStrategy, - DoFnSchemaInformation doFnSchemaInformation); + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping); } public static DoFnRunnerFactory defaultRunnerFactory() { @@ -77,7 +78,8 @@ public static DoFnRunnerFactory defaultRunner schemaCoder, outputCoders, windowingStrategy, - doFnSchemaInformation) -> { + doFnSchemaInformation, + sideInputMapping) -> { DoFnRunner underlying = DoFnRunners.simpleRunner( options, @@ -90,7 +92,8 @@ public static DoFnRunnerFactory defaultRunner schemaCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); return SimplePushbackSideInputDoFnRunner.create(underlying, sideInputs, sideInputReader); }; } @@ -109,6 +112,7 @@ public static ParDoEvaluator create( List> additionalOutputTags, Map, PCollection> outputs, DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping, DoFnRunnerFactory runnerFactory) { BundleOutputManager outputManager = createOutputManager(evaluationContext, key, outputs); @@ -133,7 +137,8 @@ public static ParDoEvaluator create( inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); return create(runner, stepContext, application, outputManager); } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java index 36aa295f9f7d2..9d2521577796f 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoEvaluatorFactory.java @@ -83,7 +83,8 @@ public TransformEvaluator forApplication( ParDoTranslation.getSideInputs(application), (TupleTag) ParDoTranslation.getMainOutputTag(application), ParDoTranslation.getAdditionalOutputTags(application).getAll(), - ParDoTranslation.getSchemaInformation(application)); + ParDoTranslation.getSchemaInformation(application), + ParDoTranslation.getSideInputMapping(application)); return evaluator; } @@ -107,7 +108,8 @@ DoFnLifecycleManagerRemovingTransformEvaluator createEvaluator( List> sideInputs, TupleTag mainOutputTag, List> additionalOutputTags, - DoFnSchemaInformation doFnSchemaInformation) + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) throws Exception { String stepName = evaluationContext.getStepName(application); DirectStepContext stepContext = @@ -126,6 +128,7 @@ DoFnLifecycleManagerRemovingTransformEvaluator createEvaluator( stepContext, fnManager.get(), doFnSchemaInformation, + sideInputMapping, fnManager), fnManager); } @@ -140,6 +143,7 @@ ParDoEvaluator createParDoEvaluator( DirectStepContext stepContext, DoFn fn, DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping, DoFnLifecycleManager fnManager) throws Exception { try { @@ -157,6 +161,7 @@ ParDoEvaluator createParDoEvaluator( additionalOutputTags, 
pcollections(application.getOutputs()), doFnSchemaInformation, + sideInputMapping, runnerFactory); } catch (Exception e) { try { diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java index 65b93556aa91e..f99d07fddb554 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/ParDoMultiOverrideFactory.java @@ -109,7 +109,8 @@ private PTransform, PCollectionTuple> getReplaceme ParDoTranslation.getMainOutputTag(application), ParDoTranslation.getAdditionalOutputTags(application), ParDoTranslation.getSideInputs(application), - ParDoTranslation.getSchemaInformation(application)); + ParDoTranslation.getSchemaInformation(application), + ParDoTranslation.getSideInputMapping(application)); } else { return application.getTransform(); } @@ -128,18 +129,21 @@ static class GbkThenStatefulParDo private final TupleTag mainOutputTag; private final List> sideInputs; private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; public GbkThenStatefulParDo( DoFn, OutputT> doFn, TupleTag mainOutputTag, TupleTagList additionalOutputTags, List> sideInputs, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.doFn = doFn; this.additionalOutputTags = additionalOutputTags; this.mainOutputTag = mainOutputTag; this.sideInputs = sideInputs; this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } @Override @@ -202,7 +206,12 @@ public PCollectionTuple expand(PCollection> input) { .apply( "Stateful ParDo", new StatefulParDo<>( - doFn, mainOutputTag, additionalOutputTags, sideInputs, doFnSchemaInformation)); + doFn, + mainOutputTag, + additionalOutputTags, + sideInputs, + doFnSchemaInformation, + sideInputMapping)); } } @@ -215,18 +224,21 @@ static class StatefulParDo private final TupleTag mainOutputTag; private final List> sideInputs; private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; public StatefulParDo( DoFn, OutputT> doFn, TupleTag mainOutputTag, TupleTagList additionalOutputTags, List> sideInputs, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.doFn = doFn; this.mainOutputTag = mainOutputTag; this.additionalOutputTags = additionalOutputTags; this.sideInputs = sideInputs; this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } public DoFn, OutputT> getDoFn() { @@ -249,6 +261,10 @@ public DoFnSchemaInformation getSchemaInformation() { return doFnSchemaInformation; } + public Map> getSideInputMapping() { + return sideInputMapping; + } + @Override public Map, PValue> getAdditionalInputs() { return PCollectionViews.toAdditionalInputs(sideInputs); diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java index 5162120e31432..b0ff2d36bd050 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java +++ 
b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/SplittableProcessElementsEvaluatorFactory.java @@ -20,6 +20,7 @@ import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument; import java.util.Collection; +import java.util.Collections; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import org.apache.beam.runners.core.DoFnRunners; @@ -124,7 +125,8 @@ private TransformEvaluator>> crea application.getTransform().getSideInputs(), application.getTransform().getMainOutputTag(), application.getTransform().getAdditionalOutputTags().getAll(), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); final ParDoEvaluator>> pde = evaluator.getParDoEvaluator(); final ProcessFn processFn = @@ -188,7 +190,8 @@ public void outputWindowedValue( inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation) -> { + doFnSchemaInformation, + sideInputMapping) -> { ProcessFn processFn = (ProcessFn) fn; return DoFnRunners.newProcessFnRunner( processFn, @@ -202,7 +205,8 @@ public void outputWindowedValue( inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); }; } } diff --git a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java index fc25a7951887c..366ca05455fb7 100644 --- a/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java +++ b/runners/direct-java/src/main/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactory.java @@ -137,7 +137,8 @@ private TransformEvaluator>> createEvaluator( application.getTransform().getSideInputs(), application.getTransform().getMainOutputTag(), application.getTransform().getAdditionalOutputTags().getAll(), - application.getTransform().getSchemaInformation()); + application.getTransform().getSchemaInformation(), + application.getTransform().getSideInputMapping()); return new StatefulParDoEvaluator<>(delegateEvaluator); } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java index 6331a2c91a27c..90be7fdfc58b7 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/ParDoEvaluatorTest.java @@ -25,6 +25,7 @@ import java.util.ArrayList; import java.util.Collection; +import java.util.Collections; import java.util.List; import javax.annotation.Nullable; import org.apache.beam.runners.core.ReadyCheckingSideInputReader; @@ -161,6 +162,7 @@ private ParDoEvaluator createEvaluator( additionalOutputTags, ImmutableMap.of(mainOutputTag, output), DoFnSchemaInformation.create(), + Collections.emptyMap(), ParDoEvaluator.defaultRunnerFactory()); } diff --git a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java index d17512c548ff0..171d9dd1fb0e7 100644 --- a/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java +++ b/runners/direct-java/src/test/java/org/apache/beam/runners/direct/StatefulParDoEvaluatorFactoryTest.java @@ -142,7 
+142,8 @@ public void process(ProcessContext c) {} mainOutput, TupleTagList.empty(), Collections.emptyList(), - DoFnSchemaInformation.create())) + DoFnSchemaInformation.create(), + Collections.emptyMap())) .get(mainOutput) .setCoder(VarIntCoder.of()); @@ -251,7 +252,8 @@ public void process(ProcessContext c) {} mainOutput, TupleTagList.empty(), Collections.singletonList(sideInput), - DoFnSchemaInformation.create())) + DoFnSchemaInformation.create(), + Collections.emptyMap())) .get(mainOutput) .setCoder(VarIntCoder.of()); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java index 76861e8dec202..bc418416d1dd1 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkBatchTransformTranslators.java @@ -502,12 +502,14 @@ public void translateNode( TupleTag mainOutputTag; DoFnSchemaInformation doFnSchemaInformation; + Map> sideInputMapping; try { mainOutputTag = ParDoTranslation.getMainOutputTag(context.getCurrentTransform()); } catch (IOException e) { throw new RuntimeException(e); } doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); + sideInputMapping = ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); Map, Integer> outputMap = Maps.newHashMap(); // put the main output at index 0, FlinkMultiOutputDoFnFunction expects this outputMap.put(mainOutputTag, 0); @@ -590,7 +592,8 @@ public void translateNode( mainOutputTag, inputCoder, outputCoderMap, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); // Based on the fact that the signature is stateful, DoFnSignatures ensures // that it is also keyed. 
@@ -611,7 +614,8 @@ public void translateNode( mainOutputTag, context.getInput(transform).getCoder(), outputCoderMap, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); outputDataSet = new MapPartitionOperator<>(inputDataSet, typeInformation, doFnWrapper, fullName); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java index 97b93fae58a81..88ed53af9e2d3 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingTransformTranslators.java @@ -425,7 +425,8 @@ DoFnOperator createDoFnOperator( Coder keyCoder, KeySelector, ?> keySelector, Map> transformedSideInputs, - DoFnSchemaInformation doFnSchemaInformation); + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping); } static void translateParDo( @@ -437,6 +438,7 @@ static void translateParDo( TupleTag mainOutputTag, List> additionalOutputTags, DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping, FlinkStreamingTranslationContext context, DoFnOperatorFactory doFnOperatorFactory) { @@ -516,7 +518,8 @@ static void translateParDo( keyCoder, keySelector, new HashMap<>() /* side-input mapping */, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); outputStream = inputDataStream.transform(transformName, outputTypeInformation, doFnOperator); @@ -543,7 +546,8 @@ static void translateParDo( keyCoder, keySelector, transformedSideInputs.f0, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); if (stateful) { // we have to manually construct the two-input transform because we're not @@ -621,6 +625,9 @@ public void translateNode( throw new RuntimeException(e); } + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + TupleTagList additionalOutputTags; try { additionalOutputTags = @@ -641,6 +648,7 @@ public void translateNode( mainOutputTag, additionalOutputTags.getAll(), doFnSchemaInformation, + sideInputMapping, context, (doFn1, stepName, @@ -658,7 +666,8 @@ public void translateNode( keyCoder, keySelector, transformedSideInputs, - doFnSchemaInformation1) -> + doFnSchemaInformation1, + sideInputMapping1) -> new DoFnOperator<>( doFn1, stepName, @@ -675,7 +684,8 @@ public void translateNode( context1.getPipelineOptions(), keyCoder, keySelector, - doFnSchemaInformation1)); + doFnSchemaInformation1, + sideInputMapping1)); } } @@ -700,6 +710,7 @@ public void translateNode( transform.getMainOutputTag(), transform.getAdditionalOutputTags().getAll(), DoFnSchemaInformation.create(), + Collections.emptyMap(), context, (doFn, stepName, @@ -717,7 +728,8 @@ public void translateNode( keyCoder, keySelector, transformedSideInputs, - doFnSchemaInformation) -> + doFnSchemaInformation, + sideInputMapping) -> new SplittableDoFnOperator<>( doFn, stepName, diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java index 84016fb44b816..6004008648841 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkDoFnFunction.java @@ -67,6 +67,7 @@ public class 
FlinkDoFnFunction private final Coder inputCoder; private final Map, Coder> outputCoderMap; private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; private transient DoFnInvoker doFnInvoker; @@ -80,7 +81,8 @@ public FlinkDoFnFunction( TupleTag mainOutputTag, Coder inputCoder, Map, Coder> outputCoderMap, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.doFn = doFn; this.stepName = stepName; @@ -92,6 +94,7 @@ public FlinkDoFnFunction( this.inputCoder = inputCoder; this.outputCoderMap = outputCoderMap; this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } @Override @@ -123,7 +126,8 @@ public void mapPartition( inputCoder, outputCoderMap, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); if ((serializedOptions.get().as(FlinkPipelineOptions.class)).getEnableMetrics()) { doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, getRuntimeContext()); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java index 9615727d2994e..9a8e085fbaf02 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/functions/FlinkStatefulDoFnFunction.java @@ -69,6 +69,7 @@ public class FlinkStatefulDoFnFunction private final Coder> inputCoder; private final Map, Coder> outputCoderMap; private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; private transient DoFnInvoker doFnInvoker; public FlinkStatefulDoFnFunction( @@ -81,7 +82,8 @@ public FlinkStatefulDoFnFunction( TupleTag mainOutputTag, Coder> inputCoder, Map, Coder> outputCoderMap, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.dofn = dofn; this.stepName = stepName; @@ -93,6 +95,7 @@ public FlinkStatefulDoFnFunction( this.inputCoder = inputCoder; this.outputCoderMap = outputCoderMap; this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } @Override @@ -149,7 +152,8 @@ public TimerInternals timerInternals() { inputCoder, outputCoderMap, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); if ((serializedOptions.get().as(FlinkPipelineOptions.class)).getEnableMetrics()) { doFnRunner = new DoFnRunnerWithMetricsUpdate<>(stepName, doFnRunner, getRuntimeContext()); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java index fffb0e0140422..6fd5e85a4e129 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperator.java @@ -179,6 +179,8 @@ public class DoFnOperator extends AbstractStreamOperator> sideInputMapping; + /** If true, we must process elements only after a checkpoint is finished. 
*/ private final boolean requiresStableInput; @@ -217,7 +219,8 @@ public DoFnOperator( PipelineOptions options, Coder keyCoder, KeySelector, ?> keySelector, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.doFn = doFn; this.stepName = stepName; this.windowedInputCoder = inputWindowedCoder; @@ -246,6 +249,7 @@ public DoFnOperator( this.maxBundleTimeMills = flinkOptions.getMaxBundleTimeMills(); Preconditions.checkArgument(maxBundleTimeMills > 0, "Bundle time must be at least 1"); this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; this.requiresStableInput = // WindowDoFnOperator does not use a DoFn @@ -418,7 +422,8 @@ public void open() throws Exception { inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); if (requiresStableInput) { // put this in front of the root FnRunner before any additional wrappers diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java index 8d2c1108b6f34..beda7747c7fd5 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/ExecutableStageDoFnOperator.java @@ -162,7 +162,8 @@ public ExecutableStageDoFnOperator( options, keyCoder, keySelector, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); this.payload = payload; this.jobInfo = jobInfo; this.contextFactory = contextFactory; diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java index b75ebc39f6020..4f8e0d752a348 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/SplittableDoFnOperator.java @@ -94,7 +94,8 @@ public SplittableDoFnOperator( options, keyCoder, keySelector, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); } @Override diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java index 858f614731c11..e2058bb08da58 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java @@ -80,7 +80,8 @@ public WindowDoFnOperator( options, keyCoder, keySelector, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); this.systemReduceFn = systemReduceFn; } diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java index 4daf6c7ebf3cf..e5e297a17413c 100644 --- 
a/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/FlinkPipelineOptionsTest.java @@ -112,7 +112,8 @@ public void parDoBaseClassPipelineOptionsNullTest() { null, null, /* key coder */ null /* key selector */, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); } /** Tests that PipelineOptions are present after serialization. */ @@ -138,7 +139,8 @@ public void parDoBaseClassPipelineOptionsSerializationTest() throws Exception { options, null, /* key coder */ null /* key selector */, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); final byte[] serialized = SerializationUtils.serialize(doFnOperator); diff --git a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java index 27b9c08dcd469..ec4a2b90d34d0 100644 --- a/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java +++ b/runners/flink/src/test/java/org/apache/beam/runners/flink/translation/wrappers/streaming/DoFnOperatorTest.java @@ -141,7 +141,8 @@ public void testSingleOutput() throws Exception { PipelineOptionsFactory.as(FlinkPipelineOptions.class), null, null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -202,7 +203,8 @@ public void testMultiOutputOutput() throws Exception { PipelineOptionsFactory.as(FlinkPipelineOptions.class), null, null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -316,7 +318,8 @@ public void onProcessingTime(OnTimerContext context) { PipelineOptionsFactory.as(FlinkPipelineOptions.class), VarIntCoder.of(), /* key coder */ WindowedValue::getValue, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new KeyedOneInputStreamOperatorTestHarness<>( @@ -409,7 +412,8 @@ public void processElement(ProcessContext context) { PipelineOptionsFactory.as(FlinkPipelineOptions.class), VarIntCoder.of(), /* key coder */ WindowedValue::getValue, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new KeyedOneInputStreamOperatorTestHarness<>( @@ -517,7 +521,8 @@ public void onTimer(OnTimerContext context, @StateId(stateId) ValueState PipelineOptionsFactory.as(FlinkPipelineOptions.class), StringUtf8Coder.of(), /* key coder */ kvWindowedValue -> kvWindowedValue.getValue().getKey(), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); KeyedOneInputStreamOperatorTestHarness< String, WindowedValue>, WindowedValue>> @@ -622,7 +627,8 @@ void testSideInputs(boolean keyed) throws Exception { PipelineOptionsFactory.as(FlinkPipelineOptions.class), keyCoder, keyed ? 
WindowedValue::getValue : null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); TwoInputStreamOperatorTestHarness, RawUnionValue, WindowedValue> testHarness = new TwoInputStreamOperatorTestHarness<>(doFnOperator); @@ -793,7 +799,8 @@ public void nonKeyedParDoSideInputCheckpointing() throws Exception { PipelineOptionsFactory.as(FlinkPipelineOptions.class), null, null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); return new TwoInputStreamOperatorTestHarness<>(doFnOperator); }); @@ -831,7 +838,8 @@ public void keyedParDoSideInputCheckpointing() throws Exception { PipelineOptionsFactory.as(FlinkPipelineOptions.class), keyCoder, WindowedValue::getValue, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); return new KeyedTwoInputStreamOperatorTestHarness<>( doFnOperator, @@ -928,7 +936,8 @@ public void nonKeyedParDoPushbackDataCheckpointing() throws Exception { PipelineOptionsFactory.as(FlinkPipelineOptions.class), null, null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); return new TwoInputStreamOperatorTestHarness<>(doFnOperator); }); @@ -967,7 +976,8 @@ public void keyedParDoPushbackDataCheckpointing() throws Exception { PipelineOptionsFactory.as(FlinkPipelineOptions.class), keyCoder, WindowedValue::getValue, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); return new KeyedTwoInputStreamOperatorTestHarness<>( doFnOperator, @@ -1142,7 +1152,8 @@ OneInputStreamOperatorTestHarness, WindowedValue> creat PipelineOptionsFactory.as(FlinkPipelineOptions.class), VarIntCoder.of() /* key coder */, keySelector, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); return new KeyedOneInputStreamOperatorTestHarness<>(doFnOperator, keySelector, keyCoderInfo); } @@ -1189,7 +1200,8 @@ public void finishBundle(FinishBundleContext context) { options, null, null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); @@ -1238,7 +1250,8 @@ public void finishBundle(FinishBundleContext context) { options, null, null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); OneInputStreamOperatorTestHarness, WindowedValue> newHarness = new OneInputStreamOperatorTestHarness<>(newDoFnOperator); @@ -1331,7 +1344,8 @@ public void finishBundle(FinishBundleContext context) { options, keyCoder, keySelector, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); OneInputStreamOperatorTestHarness< WindowedValue>, WindowedValue>> @@ -1387,7 +1401,8 @@ public void finishBundle(FinishBundleContext context) { options, keyCoder, keySelector, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); testHarness = new KeyedOneInputStreamOperatorTestHarness( @@ -1470,7 +1485,8 @@ public void finishBundle(FinishBundleContext context) { options, null, null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); DoFnOperator doFnOperator = doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = @@ -1585,7 +1601,8 @@ public void finishBundle(FinishBundleContext context) { options, keyCoder, keySelector, - 
DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); DoFnOperator, KV> doFnOperator = doFnOperatorSupplier.get(); OneInputStreamOperatorTestHarness< @@ -1698,7 +1715,8 @@ public void processElement(ProcessContext context) { options, keyCoder, keySelector, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); } /** @@ -1744,7 +1762,8 @@ public void finishBundle(FinishBundleContext context) { options, null, null, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); OneInputStreamOperatorTestHarness, WindowedValue> testHarness = new OneInputStreamOperatorTestHarness<>(doFnOperator); diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java index 4ee52dc93de11..fb5202b44b74a 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/ParDoMultiOutputTranslator.java @@ -52,7 +52,7 @@ public class ParDoMultiOutputTranslator public void translate(ParDo.MultiOutput transform, TranslationContext context) { PCollection inputT = (PCollection) context.getInput(); JavaStream> inputStream = context.getInputStream(inputT); - Collection> sideInputs = transform.getSideInputs(); + Collection> sideInputs = transform.getSideInputs().values(); Map> tagsToSideInputs = TranslatorUtils.getTagsToSideInputs(sideInputs); @@ -77,6 +77,9 @@ public void translate(ParDo.MultiOutput transform, TranslationC DoFnSchemaInformation doFnSchemaInformation; doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + JavaStream outputStream = TranslatorUtils.toList(unionStream) .flatMap( @@ -89,7 +92,8 @@ public void translate(ParDo.MultiOutput transform, TranslationC mainOutput, outputCoders, sideOutputs, - doFnSchemaInformation), + doFnSchemaInformation, + sideInputMapping), transform.getName()); for (Map.Entry, PValue> output : outputs.entrySet()) { JavaStream> taggedStream = diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java index 30a86619d5b9f..b55af69ac1687 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/functions/DoFnFunction.java @@ -74,7 +74,8 @@ public DoFnFunction( TupleTag mainOutput, Map, Coder> outputCoders, List> sideOutputs, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.doFn = doFn; this.outputManager = new DoFnOutputManager(); this.doFnRunnerFactory = @@ -88,7 +89,8 @@ public DoFnFunction( new NoOpStepContext(), outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); this.sideInputs = sideInputs; this.tagsToSideInputs = sideInputTagMapping; this.mainOutput = mainOutput; diff --git a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java 
b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java index 83c4b0896fe9b..7a6dab830bb5a 100644 --- a/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java +++ b/runners/gearpump/src/main/java/org/apache/beam/runners/gearpump/translators/utils/DoFnRunnerFactory.java @@ -50,6 +50,7 @@ public class DoFnRunnerFactory implements Serializable { private final List> sideOutputTags; private final StepContext stepContext; private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; Map, Coder> outputCoders; private final WindowingStrategy windowingStrategy; @@ -63,7 +64,8 @@ public DoFnRunnerFactory( StepContext stepContext, Map, Coder> outputCoders, WindowingStrategy windowingStrategy, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.fn = doFn; this.serializedOptions = new SerializablePipelineOptions(pipelineOptions); this.sideInputs = sideInputs; @@ -74,6 +76,7 @@ public DoFnRunnerFactory( this.outputCoders = outputCoders; this.windowingStrategy = windowingStrategy; this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } public PushbackSideInputDoFnRunner createRunner( @@ -91,7 +94,8 @@ public PushbackSideInputDoFnRunner createRunner( null, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); return SimplePushbackSideInputDoFnRunner.create(underlying, sideInputs, sideInputReader); } } diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java index 3e3c6434edeb6..c0c5e553036e2 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowPipelineTranslator.java @@ -919,14 +919,18 @@ private void translateMultiHelper( DoFnSchemaInformation doFnSchemaInformation; doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); - + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); Map, Coder> outputCoders = context.getOutputs(transform).entrySet().stream() .collect( Collectors.toMap( Map.Entry::getKey, e -> ((PCollection) e.getValue()).getCoder())); translateInputs( - stepContext, context.getInput(transform), transform.getSideInputs(), context); + stepContext, + context.getInput(transform), + transform.getSideInputs().values(), + context); translateOutputs(context.getOutputs(transform), stepContext); String ptransformId = context.getSdkComponents().getPTransformIdOrThrow(context.getCurrentTransform()); @@ -935,12 +939,13 @@ private void translateMultiHelper( ptransformId, transform.getFn(), context.getInput(transform).getWindowingStrategy(), - transform.getSideInputs(), + transform.getSideInputs().values(), context.getInput(transform).getCoder(), context, transform.getMainOutputTag(), outputCoders, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); // TODO: Move this logic into translateFn once the legacy ProcessKeyedElements is // removed. 
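The same wiring step recurs in every runner translator this patch touches: look up the side-input mapping next to the DoFn schema information and pass it through to DoFnRunners.simpleRunner as a new trailing argument. Below is a sketch of that pattern with full generic types; the helper method and its parameter names are illustrative, and only the two ParDoTranslation lookups and the extended simpleRunner call reflect the patch itself.

import java.util.List;
import java.util.Map;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.StepContext;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;

/** Sketch of the translator-side pattern; the helper and its parameters are illustrative. */
class SideInputMappingPattern {
  static <InputT, OutputT> DoFnRunner<InputT, OutputT> buildRunner(
      AppliedPTransform<?, ?, ?> appliedTransform,
      PipelineOptions options,
      DoFn<InputT, OutputT> fn,
      SideInputReader sideInputReader,
      DoFnRunners.OutputManager outputManager,
      TupleTag<OutputT> mainOutputTag,
      List<TupleTag<?>> additionalOutputTags,
      StepContext stepContext,
      Coder<InputT> inputCoder,
      Map<TupleTag<?>, Coder<?>> outputCoders,
      WindowingStrategy<?, ?> windowingStrategy) {
    // Both lookups come from ParDoTranslation; getSideInputMapping is introduced by this patch.
    DoFnSchemaInformation doFnSchemaInformation =
        ParDoTranslation.getSchemaInformation(appliedTransform);
    Map<String, PCollectionView<?>> sideInputMapping =
        ParDoTranslation.getSideInputMapping(appliedTransform);
    return DoFnRunners.simpleRunner(
        options,
        fn,
        sideInputReader,
        outputManager,
        mainOutputTag,
        additionalOutputTags,
        stepContext,
        inputCoder,
        outputCoders,
        windowingStrategy,
        doFnSchemaInformation,
        sideInputMapping); // the mapping is the new trailing argument
  }
}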
@@ -972,7 +977,8 @@ private void translateSingleHelper( DoFnSchemaInformation doFnSchemaInformation; doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); - + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); StepTranslationContext stepContext = context.addStep(transform, "ParallelDo"); Map, Coder> outputCoders = context.getOutputs(transform).entrySet().stream() @@ -981,7 +987,10 @@ private void translateSingleHelper( Map.Entry::getKey, e -> ((PCollection) e.getValue()).getCoder())); translateInputs( - stepContext, context.getInput(transform), transform.getSideInputs(), context); + stepContext, + context.getInput(transform), + transform.getSideInputs().values(), + context); stepContext.addOutput( transform.getMainOutputTag().getId(), context.getOutput(transform)); String ptransformId = @@ -991,12 +1000,13 @@ private void translateSingleHelper( ptransformId, transform.getFn(), context.getInput(transform).getWindowingStrategy(), - transform.getSideInputs(), + transform.getSideInputs().values(), context.getInput(transform).getCoder(), context, transform.getMainOutputTag(), outputCoders, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); // TODO: Move this logic into translateFn once the legacy ProcessKeyedElements is // removed. @@ -1058,7 +1068,8 @@ private void translateTyped( DoFnSchemaInformation doFnSchemaInformation; doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); - + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); StepTranslationContext stepContext = context.addStep(transform, "SplittableProcessKeyed"); Map, Coder> outputCoders = @@ -1081,7 +1092,8 @@ private void translateTyped( context, transform.getMainOutputTag(), outputCoders, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); stepContext.addInput( PropertyNames.RESTRICTION_CODER, @@ -1093,7 +1105,7 @@ private void translateTyped( private static void translateInputs( StepTranslationContext stepContext, PCollection input, - List> sideInputs, + Iterable> sideInputs, TranslationContext context) { stepContext.addInput(PropertyNames.PARALLEL_INPUT, input); translateSideInputs(stepContext, sideInputs, context); @@ -1102,7 +1114,7 @@ private static void translateInputs( // Used for ParDo private static void translateSideInputs( StepTranslationContext stepContext, - List> sideInputs, + Iterable> sideInputs, TranslationContext context) { Map nonParInputs = new HashMap<>(); @@ -1125,7 +1137,8 @@ private static void translateFn( TranslationContext context, TupleTag mainOutput, Map, Coder> outputCoders, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); if (signature.usesState() || signature.usesTimers()) { @@ -1151,7 +1164,8 @@ private static void translateFn( inputCoder, outputCoders, mainOutput, - doFnSchemaInformation)))); + doFnSchemaInformation, + sideInputMapping)))); } // Setting USES_KEYED_STATE will cause an ungrouped shuffle, which works diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java index dc16bfac4a7c1..390f0f0e6c0d8 100644 --- 
a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactory.java @@ -118,13 +118,13 @@ public TupleTag getMainOutputTag() { return onlyOutputTag; } - public List> getSideInputs() { + public Map> getSideInputs() { return original.getSideInputs(); } @Override public Map, PValue> getAdditionalInputs() { - return PCollectionViews.toAdditionalInputs(getSideInputs()); + return PCollectionViews.toAdditionalInputs(getSideInputs().values()); } @Override @@ -178,7 +178,7 @@ private static RunnerApi.ParDoPayload payloadForParDoSingle( Set allInputs = transform.getInputs().keySet().stream().map(TupleTag::getId).collect(Collectors.toSet()); Set sideInputs = - parDo.getSideInputs().stream() + parDo.getSideInputs().values().stream() .map(s -> s.getTagInternal().getId()) .collect(Collectors.toSet()); Set timerInputs = signature.timerDeclarations().keySet(); @@ -195,7 +195,11 @@ private static RunnerApi.ParDoPayload payloadForParDoSingle( @Override public RunnerApi.SdkFunctionSpec translateDoFn(SdkComponents newComponents) { return ParDoTranslation.translateDoFn( - parDo.getFn(), parDo.getMainOutputTag(), doFnSchemaInformation, newComponents); + parDo.getFn(), + parDo.getMainOutputTag(), + parDo.getSideInputs(), + doFnSchemaInformation, + newComponents); } @Override @@ -206,7 +210,8 @@ public List translateParameters() { @Override public Map translateSideInputs(SdkComponents components) { - return ParDoTranslation.translateSideInputs(parDo.getSideInputs(), components); + return ParDoTranslation.translateSideInputs( + parDo.getSideInputs().values().stream().collect(Collectors.toList()), components); } @Override diff --git a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java index a37af7b723b28..c55bffc10694f 100644 --- a/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java +++ b/runners/google-cloud-dataflow-java/src/test/java/org/apache/beam/runners/dataflow/PrimitiveParDoSingleFactoryTest.java @@ -115,7 +115,7 @@ public void getReplacementTransformGetSideInputs() { factory.getReplacementTransform(application); ParDoSingle parDoSingle = (ParDoSingle) replacementTransform.getTransform(); - assertThat(parDoSingle.getSideInputs(), containsInAnyOrder(sideStrings, sideLong)); + assertThat(parDoSingle.getSideInputs().values(), containsInAnyOrder(sideStrings, sideLong)); } @Test diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/CombineValuesFnFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/CombineValuesFnFactory.java index 203532439794e..5b865de75218f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/CombineValuesFnFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/CombineValuesFnFactory.java @@ -92,6 +92,7 @@ public ParDoFn create( executionContext.getStepContext(operationContext), operationContext, doFnInfo.getDoFnSchemaInformation(), + doFnInfo.getSideInputMapping(), SimpleDoFnRunnerFactory.INSTANCE); } @@ -141,7 +142,8 @@ private static class 
CombineValuesDoFn inputCoder, Collections.emptyMap(), // Not needed here. new TupleTag<>(PropertyNames.OUTPUT), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); } private final GlobalCombineFnRunner combineFnRunner; @@ -206,7 +208,8 @@ private static class AddInputsDoFn inputCoder, Collections.emptyMap(), // Not needed here. new TupleTag<>(PropertyNames.OUTPUT), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); } private final GlobalCombineFnRunner combineFnRunner; @@ -265,7 +268,8 @@ private static class MergeAccumulatorsDoFn inputCoder, Collections.emptyMap(), // Not needed here. new TupleTag<>(PropertyNames.OUTPUT), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); } private final GlobalCombineFnRunner combineFnRunner; @@ -314,7 +318,8 @@ private static class ExtractOutputDoFn inputCoder, Collections.emptyMap(), // Not needed here. new TupleTag<>(PropertyNames.OUTPUT), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); } private final GlobalCombineFnRunner combineFnRunner; diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DoFnRunnerFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DoFnRunnerFactory.java index 004a49bbd969e..93ecccafe3f4a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DoFnRunnerFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DoFnRunnerFactory.java @@ -44,5 +44,6 @@ DoFnRunner createRunner( DataflowExecutionContext.DataflowStepContext stepContext, DataflowExecutionContext.DataflowStepContext userStepContext, OutputManager outputManager, - DoFnSchemaInformation doFnSchemaInformation); + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleDoFnRunnerFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleDoFnRunnerFactory.java index 2601ceaf0e4ae..0804a327c85f8 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleDoFnRunnerFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleDoFnRunnerFactory.java @@ -49,7 +49,8 @@ public DoFnRunner createRunner( DataflowExecutionContext.DataflowStepContext stepContext, DataflowExecutionContext.DataflowStepContext userStepContext, OutputManager outputManager, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { DoFnRunner fnRunner = DoFnRunners.simpleRunner( options, @@ -62,7 +63,8 @@ public DoFnRunner createRunner( inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); boolean hasStreamingSideInput = options.as(StreamingOptions.class).isStreaming() && !sideInputReader.isEmpty(); if (hasStreamingSideInput) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java index 
a9af3673ef346..712a017fd27d1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java @@ -54,6 +54,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.DoFnInfo; import org.apache.beam.sdk.util.WindowedValue; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting; @@ -93,6 +94,7 @@ public class SimpleParDoFn implements ParDoFn { private final boolean hasStreamingSideInput; private final OutputsPerElementTracker outputsPerElementTracker; private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; // Various DoFn helpers, null between bundles @Nullable private DoFnRunner fnRunner; @@ -113,6 +115,7 @@ public class SimpleParDoFn implements ParDoFn { DataflowExecutionContext.DataflowStepContext stepContext, DataflowOperationContext operationContext, DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping, DoFnRunnerFactory runnerFactory) { this.options = options; this.doFnInstanceManager = doFnInstanceManager; @@ -143,6 +146,7 @@ public class SimpleParDoFn implements ParDoFn { options.as(StreamingOptions.class).isStreaming() && !sideInputReader.isEmpty(); this.outputsPerElementTracker = createOutputsPerElementTracker(); this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } private OutputsPerElementTracker createOutputsPerElementTracker() { @@ -302,7 +306,8 @@ public void output(TupleTag tag, WindowedValue output) { stepContext, userStepContext, outputManager, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); fnRunner.startBundle(); } diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java index b90c49b2f481d..b840c6899f467 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SplittableProcessFnFactory.java @@ -96,7 +96,8 @@ private static class ProcessFnExtractor implements UserParDoFnFactory.DoFnExtrac doFnInfo.getWindowingStrategy().getWindowFn().windowCoder()), doFnInfo.getOutputCoders(), doFnInfo.getMainOutput(), - doFnInfo.getDoFnSchemaInformation()); + doFnInfo.getDoFnSchemaInformation(), + doFnInfo.getSideInputMapping()); } } @@ -121,7 +122,8 @@ public DoFnRunner>, OutputT> crea DataflowExecutionContext.DataflowStepContext stepContext, DataflowExecutionContext.DataflowStepContext userStepContext, OutputManager outputManager, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { ProcessFn processFn = (ProcessFn) fn; processFn.setStateInternalsFactory(key -> (StateInternals) stepContext.stateInternals()); @@ -170,7 +172,8 @@ public void outputWindowedValue( inputCoder, outputCoders, processFn.getInputWindowingStrategy(), - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); 
DoFnRunner>, OutputT> fnRunner = new DataflowProcessFnRunner<>(simpleRunner); boolean hasStreamingSideInput = diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactory.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactory.java index 910a558eba219..84d471217d9c2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactory.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactory.java @@ -120,6 +120,7 @@ public ParDoFn create( stepContext, operationContext, doFnInfo.getDoFnSchemaInformation(), + doFnInfo.getSideInputMapping(), runnerFactory)); } else if (doFnInfo.getDoFn() instanceof StreamingPCollectionViewWriterFn) { @@ -147,6 +148,7 @@ public ParDoFn create( stepContext, operationContext, doFnInfo.getDoFnSchemaInformation(), + doFnInfo.getSideInputMapping(), runnerFactory); } } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DefaultParDoFnFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DefaultParDoFnFactoryTest.java index e14488f154c66..a82878dc5191c 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DefaultParDoFnFactoryTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DefaultParDoFnFactoryTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; +import java.util.Collections; import org.apache.beam.runners.dataflow.util.CloudObject; import org.apache.beam.runners.dataflow.worker.counters.CounterSet; import org.apache.beam.runners.dataflow.worker.util.common.worker.OutputReceiver; @@ -91,7 +92,8 @@ public void testCreateSimpleParDoFn() throws Exception { null /* side input views */, null /* input coder */, new TupleTag<>("output") /* main output */, - DoFnSchemaInformation.create()))); + DoFnSchemaInformation.create(), + Collections.emptyMap()))); CloudObject cloudUserFn = CloudObject.forClassName("DoFn"); addString(cloudUserFn, "serialized_fn", serializedFn); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DoFnInstanceManagersTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DoFnInstanceManagersTest.java index 31b7ff606befd..904b462f90a5f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DoFnInstanceManagersTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/DoFnInstanceManagersTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; +import java.util.Collections; import org.apache.beam.runners.dataflow.util.PropertyNames; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; @@ -67,7 +68,8 @@ public void testInstanceReturnsInstance() throws Exception { null /* side input views */, null /* input coder */, new TupleTag<>(PropertyNames.OUTPUT) /* main output id */, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); DoFnInstanceManager mgr = 
DoFnInstanceManagers.singleInstance(info); assertThat(mgr.peek(), Matchers.>theInstance(info)); @@ -88,7 +90,8 @@ public void testInstanceIgnoresAbort() throws Exception { null /* side input views */, null /* input coder */, new TupleTag<>(PropertyNames.OUTPUT) /* main output id */, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); DoFnInstanceManager mgr = DoFnInstanceManagers.singleInstance(info); mgr.abort(mgr.get()); @@ -108,7 +111,8 @@ public void testCloningPoolReusesAfterComplete() throws Exception { null /* side input views */, null /* input coder */, new TupleTag<>(PropertyNames.OUTPUT) /* main output id */, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); DoFnInstanceManager mgr = DoFnInstanceManagers.cloningPool(info); DoFnInfo retrievedInfo = mgr.get(); @@ -130,7 +134,8 @@ public void testCloningPoolTearsDownAfterAbort() throws Exception { null /* side input views */, null /* input coder */, new TupleTag<>(PropertyNames.OUTPUT) /* main output id */, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); DoFnInstanceManager mgr = DoFnInstanceManagers.cloningPool(info); DoFnInfo retrievedInfo = mgr.get(); @@ -154,7 +159,8 @@ public void testCloningPoolMultipleOutstanding() throws Exception { null /* side input views */, null /* input coder */, new TupleTag<>(PropertyNames.OUTPUT) /* main output id */, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); DoFnInstanceManager mgr = DoFnInstanceManagers.cloningPool(info); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java index 7584220043ad8..a4178f48e18ad 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/IntrinsicMapTaskExecutorFactoryTest.java @@ -51,6 +51,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.function.Function; import javax.annotation.Nullable; @@ -484,7 +485,8 @@ static ParallelInstruction createParDoInstruction( null /* side input views */, null /* input coder */, new TupleTag<>(PropertyNames.OUTPUT) /* main output id */, - DoFnSchemaInformation.create()))); + DoFnSchemaInformation.create(), + Collections.emptyMap()))); CloudObject cloudUserFn = CloudObject.forClassName("DoFn"); addString(cloudUserFn, PropertyNames.SERIALIZED_FN, serializedFn); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java index 71b03a9507435..2933bf50baf94 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnTest.java @@ -205,7 +205,8 @@ public void testOutputReceivers() throws Exception { null /* side input views */, null /* input coder */, 
MAIN_OUTPUT, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); TestReceiver receiver = new TestReceiver(); TestReceiver receiver1 = new TestReceiver(); TestReceiver receiver2 = new TestReceiver(); @@ -230,6 +231,7 @@ public void testOutputReceivers() throws Exception { .getStepContext(operationContext), operationContext, DoFnSchemaInformation.create(), + Collections.emptyMap(), SimpleDoFnRunnerFactory.INSTANCE); userParDoFn.startBundle(receiver, receiver1, receiver2, receiver3); @@ -284,7 +286,8 @@ public void testUnexpectedNumberOfReceivers() throws Exception { null /* side input views */, null /* input coder */, MAIN_OUTPUT, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); TestReceiver receiver = new TestReceiver(); ParDoFn userParDoFn = @@ -298,6 +301,7 @@ public void testUnexpectedNumberOfReceivers() throws Exception { .getStepContext(operationContext), operationContext, DoFnSchemaInformation.create(), + Collections.emptyMap(), SimpleDoFnRunnerFactory.INSTANCE); try { @@ -333,7 +337,8 @@ public void testErrorPropagation() throws Exception { null /* side input views */, null /* input coder */, MAIN_OUTPUT, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); TestReceiver receiver = new TestReceiver(); ParDoFn userParDoFn = @@ -347,6 +352,7 @@ public void testErrorPropagation() throws Exception { .getStepContext(operationContext), operationContext, DoFnSchemaInformation.create(), + Collections.emptyMap(), SimpleDoFnRunnerFactory.INSTANCE); try { @@ -424,7 +430,8 @@ public void testUndeclaredSideOutputs() throws Exception { null /* side input views */, null /* input coder */, MAIN_OUTPUT, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); CounterSet counters = new CounterSet(); TestOperationContext operationContext = TestOperationContext.create(counters); ParDoFn userParDoFn = @@ -438,6 +445,7 @@ public void testUndeclaredSideOutputs() throws Exception { .getStepContext(operationContext), operationContext, DoFnSchemaInformation.create(), + Collections.emptyMap(), SimpleDoFnRunnerFactory.INSTANCE); userParDoFn.startBundle(new TestReceiver(), new TestReceiver()); @@ -484,7 +492,8 @@ public void processElement(ProcessContext c) throws Exception { null /* side input views */, null /* input coder */, MAIN_OUTPUT, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); ParDoFn userParDoFn = new SimpleParDoFn<>( @@ -498,6 +507,7 @@ public void processElement(ProcessContext c) throws Exception { .getStepContext(operationContext), operationContext, DoFnSchemaInformation.create(), + Collections.emptyMap(), SimpleDoFnRunnerFactory.INSTANCE); // This test ensures proper behavior of the state sampling even with lazy initialization. 
@@ -575,7 +585,8 @@ public void processElement(ProcessContext c) { null /* side input views */, null /* input coder */, MAIN_OUTPUT, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); ParDoFn parDoFn = new SimpleParDoFn<>( @@ -587,6 +598,7 @@ public void processElement(ProcessContext c) { stepContext, operationContext, DoFnSchemaInformation.create(), + Collections.emptyMap(), SimpleDoFnRunnerFactory.INSTANCE); parDoFn.startBundle(new TestReceiver()); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java index 96f46cde1be6d..356dddbedada6 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingDataflowWorkerTest.java @@ -319,7 +319,8 @@ private ParallelInstruction makeDoFnInstruction( null /* side input views */, null /* input coder */, new TupleTag<>(PropertyNames.OUTPUT) /* main output id */, - DoFnSchemaInformation.create())))); + DoFnSchemaInformation.create(), + Collections.emptyMap())))); return new ParallelInstruction() .setSystemName(DEFAULT_PARDO_SYSTEM_NAME) .setName(DEFAULT_PARDO_USER_NAME) diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java index 7ad76218f04b1..24d17ff4b2155 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java @@ -408,7 +408,8 @@ private StreamingSideInputDoFnRunner null, Collections.emptyMap(), WindowingStrategy.of(windowFn), - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); return new StreamingSideInputDoFnRunner<>(simpleDoFnRunner, sideInputFetcher); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java index c4ef8f5aa7bc4..5e688f3f4a699 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/UserParDoFnFactoryTest.java @@ -311,7 +311,8 @@ private CloudObject getCloudObject(DoFn fn, WindowingStrategy window null /* side input views */, null /* input coder */, new TupleTag<>(PropertyNames.OUTPUT) /* main output id */, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); object.set( PropertyNames.SERIALIZED_FN, StringUtils.byteArrayToJsonString(SerializableUtils.serializeToByteArray(info))); diff --git a/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/Utils.java b/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/Utils.java index 
5dd9ba0521333..7489bea643e6d 100644 --- a/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/Utils.java +++ b/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/Utils.java @@ -149,10 +149,10 @@ static List> getSideInputs(AppliedPTransform applied PTransform transform = appliedTransform.getTransform(); if (transform instanceof ParDo.MultiOutput) { ParDo.MultiOutput multiParDo = (ParDo.MultiOutput) transform; - return multiParDo.getSideInputs(); + return (List) multiParDo.getSideInputs().values().stream().collect(Collectors.toList()); } else if (transform instanceof ParDo.SingleOutput) { ParDo.SingleOutput singleParDo = (ParDo.SingleOutput) transform; - return singleParDo.getSideInputs(); + return (List) singleParDo.getSideInputs().values().stream().collect(Collectors.toList()); } return Collections.emptyList(); } diff --git a/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/AbstractParDoP.java b/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/AbstractParDoP.java index c7f2354bc93aa..5287149479ad7 100644 --- a/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/AbstractParDoP.java +++ b/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/AbstractParDoP.java @@ -95,6 +95,7 @@ abstract class AbstractParDoP implements Processor { private SideInputReader sideInputReader; private Outbox outbox; private long lastMetricsFlushTime = System.currentTimeMillis(); + private Map> sideInputMapping; AbstractParDoP( DoFn doFn, @@ -165,7 +166,8 @@ public void init(@Nonnull Outbox outbox, @Nonnull Context context) { inputValueCoder, outputValueCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); } protected abstract DoFnRunner getDoFnRunner( @@ -178,7 +180,8 @@ protected abstract DoFnRunner getDoFnRunner( Coder inputValueCoder, Map, Coder> outputValueCoders, WindowingStrategy windowingStrategy, - DoFnSchemaInformation doFnSchemaInformation); + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping); @Override public boolean isCooperative() { diff --git a/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/ParDoP.java b/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/ParDoP.java index aa54f8e7c924a..38f07c9cae0e5 100644 --- a/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/ParDoP.java +++ b/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/ParDoP.java @@ -86,7 +86,8 @@ protected DoFnRunner getDoFnRunner( Coder inputValueCoder, Map, Coder> outputValueCoders, WindowingStrategy windowingStrategy, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { return DoFnRunners.simpleRunner( pipelineOptions, doFn, @@ -98,7 +99,8 @@ protected DoFnRunner getDoFnRunner( inputValueCoder, outputValueCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); } /** diff --git a/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/StatefulParDoP.java b/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/StatefulParDoP.java index 0b9c5aa04ca19..e291117f1ad67 100644 --- a/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/StatefulParDoP.java +++ 
b/runners/jet-experimental/src/main/java/org/apache/beam/runners/jet/processors/StatefulParDoP.java @@ -106,7 +106,8 @@ private static void fireTimer( Coder> inputValueCoder, Map, Coder> outputValueCoders, WindowingStrategy windowingStrategy, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { timerInternals = new InMemoryTimerInternals(); keyedStepContext = new KeyedStepContext(timerInternals); return DoFnRunners.simpleRunner( @@ -120,7 +121,8 @@ private static void fireTimer( inputValueCoder, outputValueCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); } @Override diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java index b167718ed93a6..e5df2411ebd5c 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/DoFnOp.java @@ -104,6 +104,7 @@ public class DoFnOp implements Op { private transient List> pushbackValues; private transient StageBundleFactory stageBundleFactory; private DoFnSchemaInformation doFnSchemaInformation; + private Map> sideInputMapping; public DoFnOp( TupleTag mainOutputTag, @@ -122,7 +123,8 @@ public DoFnOp( boolean isPortable, RunnerApi.ExecutableStagePayload stagePayload, Map> idToTupleTagMap, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.mainOutputTag = mainOutputTag; this.doFn = doFn; this.sideInputs = sideInputs; @@ -140,6 +142,7 @@ public DoFnOp( this.stagePayload = stagePayload; this.idToTupleTagMap = new HashMap<>(idToTupleTagMap); this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } @Override @@ -207,7 +210,8 @@ public void open( inputCoder, sideOutputTags, outputCoders, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); } this.pushbackFnRunner = diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/GroupByKeyOp.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/GroupByKeyOp.java index 4748ebb0350a0..8c96d97e1f888 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/GroupByKeyOp.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/GroupByKeyOp.java @@ -176,7 +176,8 @@ public TimerInternals timerInternals() { null, Collections.emptyMap(), windowingStrategy, - DoFnSchemaInformation.create()); + DoFnSchemaInformation.create(), + Collections.emptyMap()); final SamzaExecutionContext executionContext = (SamzaExecutionContext) context.getApplicationContainerContext(); diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java index 9074f5164d647..3f032e1eac212 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/runtime/SamzaDoFnRunners.java @@ -45,6 +45,7 @@ import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TypeDescriptor; 
import org.apache.beam.sdk.values.WindowingStrategy; @@ -71,7 +72,8 @@ public static DoFnRunner create( Coder inputCoder, List> sideOutputTags, Map, Coder> outputCoders, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { final KeyedInternals keyedInternals; final TimerInternals timerInternals; final StateInternals stateInternals; @@ -104,7 +106,8 @@ public static DoFnRunner create( inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); final DoFnRunner doFnRunnerWithMetrics = pipelineOptions.getEnableMetrics() diff --git a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java index 5143bf2bb154d..49d32daf2de90 100644 --- a/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java +++ b/runners/samza/src/main/java/org/apache/beam/runners/samza/translation/ParDoBoundMultiTranslator.java @@ -106,7 +106,7 @@ private static void doTranslate( final MessageStream> inputStream = ctx.getMessageStream(input); final List>> sideInputStreams = - transform.getSideInputs().stream() + transform.getSideInputs().values().stream() .map(ctx::getViewStream) .collect(Collectors.toList()); final ArrayList, PValue>> outputs = @@ -128,13 +128,16 @@ private static void doTranslate( } final HashMap> idToPValueMap = new HashMap<>(); - for (PCollectionView view : transform.getSideInputs()) { + for (PCollectionView view : transform.getSideInputs().values()) { idToPValueMap.put(ctx.getViewId(view), view); } DoFnSchemaInformation doFnSchemaInformation; doFnSchemaInformation = ParDoTranslation.getSchemaInformation(ctx.getCurrentTransform()); + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(ctx.getCurrentTransform()); + final DoFnOp op = new DoFnOp<>( transform.getMainOutputTag(), @@ -142,7 +145,7 @@ private static void doTranslate( keyCoder, (Coder) input.getCoder(), outputCoders, - transform.getSideInputs(), + transform.getSideInputs().values(), transform.getAdditionalOutputTags().getAll(), input.getWindowingStrategy(), idToPValueMap, @@ -154,7 +157,8 @@ private static void doTranslate( false, null, Collections.emptyMap(), - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); final MessageStream> mergedStreams; if (sideInputStreams.isEmpty()) { @@ -240,6 +244,9 @@ private static void doTranslatePortable( final DoFnSchemaInformation doFnSchemaInformation; doFnSchemaInformation = ParDoTranslation.getSchemaInformation(transform.getTransform()); + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(transform.getTransform()); + final RunnerApi.PCollection input = pipeline.getComponents().getPcollectionsOrThrow(inputId); final PCollection.IsBounded isBounded = SamzaPipelineTranslatorUtils.isBounded(input); @@ -262,7 +269,8 @@ private static void doTranslatePortable( true, stagePayload, idToTupleTagMap, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); final MessageStream> mergedStreams; if (sideInputStreams.isEmpty()) { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java index 44e53dd9e6cb5..ee2a5812fe89b 100644 --- 
a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/MultiDoFnFunction.java @@ -42,6 +42,7 @@ import org.apache.beam.sdk.util.SerializableUtils; import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Function; @@ -75,6 +76,7 @@ public class MultiDoFnFunction private final WindowingStrategy windowingStrategy; private final boolean stateful; private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; /** * @param metricsAccum The Spark {@link AccumulatorV2} that backs the Beam metrics. @@ -100,7 +102,8 @@ public MultiDoFnFunction( Map, KV, SideInputBroadcast>> sideInputs, WindowingStrategy windowingStrategy, boolean stateful, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.metricsAccum = metricsAccum; this.stepName = stepName; this.doFn = SerializableUtils.clone(doFn); @@ -113,6 +116,7 @@ public MultiDoFnFunction( this.windowingStrategy = windowingStrategy; this.stateful = stateful; this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } @Override @@ -166,7 +170,8 @@ public TimerInternals timerInternals() { inputCoder, outputCoders, windowingStrategy, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); DoFnRunnerWithMetrics doFnRunnerWithMetrics = new DoFnRunnerWithMetrics<>(stepName, doFnRunner, metricsAccum); diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java index 23560e21335c8..a575ec43798cd 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TransformTranslator.java @@ -367,6 +367,9 @@ public void evaluate( doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); + Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + MultiDoFnFunction multiDoFnFunction = new MultiDoFnFunction<>( metricsAccum, @@ -377,10 +380,11 @@ public void evaluate( transform.getAdditionalOutputTags().getAll(), inputCoder, outputCoders, - TranslationUtils.getSideInputs(transform.getSideInputs(), context), + TranslationUtils.getSideInputs(transform.getSideInputs().values(), context), windowingStrategy, stateful, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); if (stateful) { // Based on the fact that the signature is stateful, DoFnSignatures ensures diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java index a4e6081488922..16a4ca9266c39 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/TranslationUtils.java @@ -21,7 +21,6 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Iterator; -import java.util.List; 
import java.util.Map; import javax.annotation.Nonnull; import org.apache.beam.runners.core.InMemoryStateInternals; @@ -228,7 +227,7 @@ public Boolean call(Tuple2, WindowedValue> input) { * @return a map of tagged {@link SideInputBroadcast}s and their {@link WindowingStrategy}. */ static Map, KV, SideInputBroadcast>> getSideInputs( - List> views, EvaluationContext context) { + Iterable> views, EvaluationContext context) { return getSideInputs(views, context.getSparkContext(), context.getPViews()); } @@ -241,7 +240,7 @@ public Boolean call(Tuple2, WindowedValue> input) { * @return a map of tagged {@link SideInputBroadcast}s and their {@link WindowingStrategy}. */ public static Map, KV, SideInputBroadcast>> getSideInputs( - List> views, JavaSparkContext context, SparkPCollectionView pviews) { + Iterable> views, JavaSparkContext context, SparkPCollectionView pviews) { if (views == null) { return ImmutableMap.of(); } else { diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java index 2480c6bfa1796..ee43f7076552a 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/translation/streaming/StreamingTransformTranslator.java @@ -79,6 +79,7 @@ import org.apache.beam.sdk.util.WindowedValue; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; @@ -396,6 +397,9 @@ public void evaluate( final DoFnSchemaInformation doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform()); + final Map> sideInputMapping = + ParDoTranslation.getSideInputMapping(context.getCurrentTransform()); + final String stepName = context.getCurrentTransform().getFullName(); JavaPairDStream, WindowedValue> all = dStream.transformToPair( @@ -405,7 +409,7 @@ public void evaluate( final Map, KV, SideInputBroadcast>> sideInputs = TranslationUtils.getSideInputs( - transform.getSideInputs(), + transform.getSideInputs().values(), JavaSparkContext.fromSparkContext(rdd.context()), pviews); @@ -422,7 +426,8 @@ public void evaluate( sideInputs, windowingStrategy, false, - doFnSchemaInformation)); + doFnSchemaInformation, + sideInputMapping)); }); Map, PValue> outputs = context.getOutputs(transform); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java index 2352e3abb5171..4aa8dbf3a0791 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFn.java @@ -625,6 +625,14 @@ public interface MultiOutputReceiver { @Target(ElementType.PARAMETER) public @interface Timestamp {} + /** Parameter annotation for the SideInput for a {@link ProcessElement} method. */ + @Documented + @Retention(RetentionPolicy.RUNTIME) + @Target(ElementType.PARAMETER) + public @interface SideInput { + /** The SideInput tag ID. */ + String value(); + } /** * Experimental - no backwards compatibility guarantees. The exact name or usage of this * feature may change. 
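The DoFn.java hunk above introduces the @SideInput parameter annotation, whose value names the tag id of a side input attached to the ParDo. A minimal sketch of how a pipeline author might use it, assuming a singleton String view registered under the tag id "prefix"; the class and names below are illustrative and not part of the patch.

import org.apache.beam.sdk.transforms.DoFn;

/** Hypothetical example DoFn; "prefix" is an assumed tag id, not taken from the patch. */
class PrefixingFn extends DoFn<String, String> {
  @ProcessElement
  public void process(
      @Element String word,
      // Injected from the PCollectionView attached under the tag id "prefix".
      @SideInput("prefix") String prefix,
      OutputReceiver<String> out) {
    out.output(prefix + word);
  }
}

The annotation value has to match the tag id under which the view is attached to the transform; the end of this patch sketches that wiring.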
diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java index 8926c05dcc6c6..9ceb37e1a6677 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/DoFnTester.java @@ -250,6 +250,11 @@ public InputT element(DoFn doFn) { return processContext.element(); } + @Override + public InputT sideInput(String sideInputTag) { + throw new UnsupportedOperationException("SideInputs are not supported by DoFnTester"); + } + @Override public InputT schemaElement(int index) { throw new UnsupportedOperationException("Schemas are not supported by DoFnTester"); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index 6e5efae20d563..dad8b4f54b47e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -24,8 +24,9 @@ import java.lang.reflect.Type; import java.util.Arrays; import java.util.Collections; -import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.StreamSupport; import javax.annotation.Nullable; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.PipelineRunner; @@ -64,7 +65,7 @@ import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; import org.apache.beam.sdk.values.TypeDescriptor; -import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; /** * {@link ParDo} is the core element-wise transform in Apache Beam, invoking a user-specified @@ -395,7 +396,7 @@ public class ParDo { */ public static SingleOutput of(DoFn fn) { validate(fn); - return new SingleOutput<>(fn, Collections.emptyList(), displayDataForFn(fn)); + return new SingleOutput<>(fn, Collections.emptyMap(), displayDataForFn(fn)); } private static DisplayData.ItemSpec> displayDataForFn(T fn) { @@ -562,7 +563,6 @@ private static void validate(DoFn fn) { fn.getClass().getName())); } } - /** * Extract information on how the DoFn uses schemas. In particular, if the schema of an element * parameter does not match the input PCollection's schema, convert. @@ -629,19 +629,18 @@ public static class SingleOutput private static final String MAIN_OUTPUT_TAG = "output"; - private final List> sideInputs; + private final Map> sideInputs; private final DoFn fn; private final DisplayData.ItemSpec> fnDisplayData; SingleOutput( DoFn fn, - List> sideInputs, + Map> sideInputs, DisplayData.ItemSpec> fnDisplayData) { this.fn = fn; this.fnDisplayData = fnDisplayData; this.sideInputs = sideInputs; } - /** * Returns a new {@link ParDo} {@link PTransform} that's like this {@link PTransform} but with * the specified additional side inputs. Does not modify this {@link PTransform}. @@ -660,15 +659,38 @@ public SingleOutput withSideInputs(PCollectionView... 
sideIn */ public SingleOutput withSideInputs( Iterable> sideInputs) { + Map> mappedInputs = + StreamSupport.stream(sideInputs.spliterator(), false) + .collect(Collectors.toMap(v -> v.getTagInternal().getId(), v -> v)); + return withSideInputs(mappedInputs); + } + + /** + * Returns a new {@link ParDo} {@link PTransform} that's like this {@link PTransform} but with + * the specified additional side inputs. Does not modify this {@link PTransform}. + * + *

See the discussion of Side Inputs above for more explanation. + */ + public SingleOutput withSideInputs( + Map> sideInputs) { return new SingleOutput<>( fn, - ImmutableList.>builder() - .addAll(this.sideInputs) - .addAll(sideInputs) + ImmutableMap.>builder() + .putAll(this.sideInputs) + .putAll(sideInputs) .build(), fnDisplayData); } + /** + * Returns a new {@link ParDo} {@link PTransform} that's like this {@link PTransform} but with + * the specified additional side inputs. Does not modify this {@link PTransform}. + */ + public SingleOutput withSideInput( + String tagId, PCollectionView pCollectionView) { + return withSideInputs(Collections.singletonMap(tagId, pCollectionView)); + } + /** * Returns a new multi-output {@link ParDo} {@link PTransform} that's like this {@link * PTransform} but with the specified output tags. Does not modify this {@link PTransform}. @@ -685,7 +707,6 @@ public PCollection expand(PCollection input) { SchemaRegistry schemaRegistry = input.getPipeline().getSchemaRegistry(); CoderRegistry registry = input.getPipeline().getCoderRegistry(); finishSpecifyingStateSpecs(fn, registry, input.getCoder()); - TupleTag mainOutput = new TupleTag<>(MAIN_OUTPUT_TAG); PCollection res = input.apply(withOutputTags(mainOutput, TupleTagList.empty())).get(mainOutput); @@ -732,7 +753,7 @@ public DoFn getFn() { return fn; } - public List> getSideInputs() { + public Map> getSideInputs() { return sideInputs; } @@ -743,7 +764,7 @@ public List> getSideInputs() { */ @Override public Map, PValue> getAdditionalInputs() { - return PCollectionViews.toAdditionalInputs(sideInputs); + return PCollectionViews.toAdditionalInputs(sideInputs.values()); } @Override @@ -763,7 +784,7 @@ public String toString() { */ public static class MultiOutput extends PTransform, PCollectionTuple> { - private final List> sideInputs; + private final Map> sideInputs; private final TupleTag mainOutputTag; private final TupleTagList additionalOutputTags; private final DisplayData.ItemSpec> fnDisplayData; @@ -771,7 +792,7 @@ public static class MultiOutput MultiOutput( DoFn fn, - List> sideInputs, + Map> sideInputs, TupleTag mainOutputTag, TupleTagList additionalOutputTags, ItemSpec> fnDisplayData) { @@ -793,6 +814,14 @@ public MultiOutput withSideInputs(PCollectionView... sideInp return withSideInputs(Arrays.asList(sideInputs)); } + public MultiOutput withSideInputs( + Iterable> sideInputs) { + Map> mappedInputs = + StreamSupport.stream(sideInputs.spliterator(), false) + .collect(Collectors.toMap(v -> v.getTagInternal().getId(), v -> v)); + return withSideInputs(mappedInputs); + } + /** * Returns a new multi-output {@link ParDo} {@link PTransform} that's like this {@link * PTransform} but with the specified additional side inputs. Does not modify this {@link @@ -800,19 +829,28 @@ public MultiOutput withSideInputs(PCollectionView... sideInp * *

See the discussion of Side Inputs above for more explanation. */ - public MultiOutput withSideInputs( - Iterable> sideInputs) { + public MultiOutput withSideInputs(Map> sideInputs) { return new MultiOutput<>( fn, - ImmutableList.>builder() - .addAll(this.sideInputs) - .addAll(sideInputs) + ImmutableMap.>builder() + .putAll(this.sideInputs) + .putAll(sideInputs) .build(), mainOutputTag, additionalOutputTags, fnDisplayData); } + /** + * Returns a new multi-output {@link ParDo} {@link PTransform} that's like this {@link + * PTransform} but with the specified additional side inputs. Does not modify this {@link + * PTransform}. + */ + public MultiOutput withSideInput( + String tagId, PCollectionView pCollectionView) { + return withSideInputs(Collections.singletonMap(tagId, pCollectionView)); + } + @Override public PCollectionTuple expand(PCollection input) { // SplittableDoFn should be forbidden on the runner-side. @@ -883,7 +921,7 @@ public TupleTagList getAdditionalOutputTags() { return additionalOutputTags; } - public List> getSideInputs() { + public Map> getSideInputs() { return sideInputs; } @@ -894,7 +932,7 @@ public List> getSideInputs() { */ @Override public Map, PValue> getAdditionalInputs() { - return PCollectionViews.toAdditionalInputs(sideInputs); + return PCollectionViews.toAdditionalInputs(sideInputs.values()); } @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java index 7e86972d2b308..2f54c27b09788 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/ByteBuddyDoFnInvokerFactory.java @@ -45,6 +45,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ProcessContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StartBundleContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TaggedOutputReceiverParameter; @@ -112,6 +113,7 @@ public class ByteBuddyDoFnInvokerFactory implements DoFnInvokerFactory { public static final String RESTRICTION_TRACKER_PARAMETER_METHOD = "restrictionTracker"; public static final String STATE_PARAMETER_METHOD = "state"; public static final String TIMER_PARAMETER_METHOD = "timer"; + public static final String SIDE_INPUT_PARAMETER_METHOD = "sideInput"; /** * Returns a {@link ByteBuddyDoFnInvokerFactory} shared with all other invocations, so that its @@ -787,6 +789,16 @@ public StackManipulation dispatch(TimerParameter p) { public StackManipulation dispatch(DoFnSignature.Parameter.PipelineOptionsParameter p) { return simpleExtraContextParameter(PIPELINE_OPTIONS_PARAMETER_METHOD); } + + @Override + public StackManipulation dispatch(SideInputParameter p) { + return new StackManipulation.Compound( + new TextConstant(p.sideInputId()), + MethodInvocation.invoke( + getExtraContextFactoryMethodDescription( + SIDE_INPUT_PARAMETER_METHOD, String.class)), + TypeCasting.to(new 
TypeDescription.ForLoadedType(p.elementT().getRawType()))); + } }); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java index d8504eef220d7..fe31c6431d000 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnInvoker.java @@ -135,6 +135,9 @@ interface ArgumentProvider { /** Provide a reference to the input element. */ InputT element(DoFn doFn); + /** Provide a reference to the input sideInput with the specified tag. */ + Object sideInput(String tagId); + /** * Provide a reference to the selected schema field corresponding to the input argument * specified by index. @@ -190,6 +193,14 @@ public InputT element(DoFn doFn) { FakeArgumentProvider.class.getSimpleName())); } + @Override + public InputT sideInput(String tagId) { + throw new UnsupportedOperationException( + String.format( + "Should never call non-overridden methods of %s", + FakeArgumentProvider.class.getSimpleName())); + } + @Override public InputT schemaElement(int index) { throw new UnsupportedOperationException( diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index 69641aede15ef..bc81e38d09ddf 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -242,6 +242,8 @@ public ResultT match(Cases cases) { return cases.dispatch((TaggedOutputReceiverParameter) this); } else if (this instanceof TimeDomainParameter) { return cases.dispatch((TimeDomainParameter) this); + } else if (this instanceof SideInputParameter) { + return cases.dispatch((SideInputParameter) this); } else { throw new IllegalStateException( String.format( @@ -284,6 +286,8 @@ public interface Cases { ResultT dispatch(PipelineOptionsParameter p); + ResultT dispatch(SideInputParameter p); + /** A base class for a visitor with a default method for cases it is not interested in. */ abstract class WithDefault implements Cases { @@ -368,6 +372,11 @@ public ResultT dispatch(TimerParameter p) { public ResultT dispatch(PipelineOptionsParameter p) { return dispatchDefault(p); } + + @Override + public ResultT dispatch(SideInputParameter p) { + return dispatchDefault(p); + } } } @@ -413,6 +422,14 @@ public static TimestampParameter timestampParameter() { return TIMESTAMP_PARAMETER; } + public static SideInputParameter sideInputParameter( + TypeDescriptor elementT, String sideInputId) { + return new AutoValue_DoFnSignature_Parameter_SideInputParameter.Builder() + .setElementT(elementT) + .setSideInputId(sideInputId) + .build(); + } + public static TimeDomainParameter timeDomainParameter() { return TIME_DOMAIN_PARAMETER; } @@ -556,6 +573,28 @@ public abstract static class TimeDomainParameter extends Parameter { TimeDomainParameter() {} } + /** Descriptor for a {@link Parameter} of type {@link DoFn.SideInput}. */ + @AutoValue + public abstract static class SideInputParameter extends Parameter { + SideInputParameter() {} + + public abstract TypeDescriptor elementT(); + + public abstract String sideInputId(); + + /** Builder class. 
*/ + @AutoValue.Builder + public abstract static class Builder { + public abstract SideInputParameter.Builder setElementT(TypeDescriptor elementT); + + public abstract SideInputParameter.Builder setSideInputId(String sideInput); + + public abstract SideInputParameter build(); + } + + public abstract SideInputParameter.Builder toBuilder(); + } + /** * Descriptor for a {@link Parameter} of type {@link DoFn.OutputReceiver}. * diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java index fa54e841f4dab..b3fde4fb39bd3 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignatures.java @@ -49,6 +49,7 @@ import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.MultiOutputReceiver; import org.apache.beam.sdk.transforms.DoFn.OutputReceiver; +import org.apache.beam.sdk.transforms.DoFn.SideInput; import org.apache.beam.sdk.transforms.DoFn.StateId; import org.apache.beam.sdk.transforms.DoFn.TimerId; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.FieldAccessDeclaration; @@ -96,7 +97,8 @@ private DoFnSignatures() {} Parameter.PaneInfoParameter.class, Parameter.PipelineOptionsParameter.class, Parameter.TimerParameter.class, - Parameter.StateParameter.class); + Parameter.StateParameter.class, + Parameter.SideInputParameter.class); private static final ImmutableList> ALLOWED_SPLITTABLE_PROCESS_ELEMENT_PARAMETERS = @@ -107,7 +109,8 @@ private DoFnSignatures() {} Parameter.OutputReceiverParameter.class, Parameter.TaggedOutputReceiverParameter.class, Parameter.ProcessContextParameter.class, - Parameter.RestrictionTrackerParameter.class); + Parameter.RestrictionTrackerParameter.class, + Parameter.SideInputParameter.class); private static final ImmutableList> ALLOWED_ON_TIMER_PARAMETERS = ImmutableList.of( @@ -255,7 +258,6 @@ public Map getStateParameters() { public Map getTimerParameters() { return Collections.unmodifiableMap(timerParameters); } - /** Extra parameters in their entirety. Unmodifiable. */ public List getExtraParameters() { return Collections.unmodifiableList(extraParameters); @@ -904,6 +906,11 @@ private static Parameter analyzeExtraParameter( return Parameter.timestampParameter(); } else if (rawType.equals(TimeDomain.class)) { return Parameter.timeDomainParameter(); + } else if (hasSideInputAnnotation(param.getAnnotations())) { + String sideInputId = getSideInputId(param.getAnnotations()); + paramErrors.checkArgument( + sideInputId != null, "%s missing %s annotation", SideInput.class.getSimpleName()); + return Parameter.sideInputParameter(paramT, sideInputId); } else if (rawType.equals(PaneInfo.class)) { return Parameter.paneInfoParameter(); } else if (rawType.equals(DoFn.ProcessContext.class)) { @@ -1055,6 +1062,12 @@ private static String getFieldAccessId(List annotations) { return access != null ? access.value() : null; } + @Nullable + private static String getSideInputId(List annotations) { + DoFn.SideInput sideInputId = findFirstOfType(annotations, DoFn.SideInput.class); + return sideInputId != null ? 
sideInputId.value() : null; + } + @Nullable static T findFirstOfType(List annotations, Class clazz) { Optional annotation = @@ -1070,6 +1083,10 @@ private static boolean hasTimestampAnnotation(List annotations) { return annotations.stream().anyMatch(a -> a.annotationType().equals(DoFn.Timestamp.class)); } + private static boolean hasSideInputAnnotation(List annotations) { + return annotations.stream().anyMatch(a -> a.annotationType().equals(DoFn.SideInput.class)); + } + @Nullable private static TypeDescriptor getTrackerType(TypeDescriptor fnClass, Method method) { Type[] params = method.getGenericParameterTypes(); diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnInfo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnInfo.java index 32998b4d3b2df..9097277e283af 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnInfo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnInfo.java @@ -41,6 +41,7 @@ public class DoFnInfo implements Serializable { Map, Coder> outputCoders; private final TupleTag mainOutput; private final DoFnSchemaInformation doFnSchemaInformation; + private final Map> sideInputMapping; /** * Creates a {@link DoFnInfo} for the given {@link DoFn}. @@ -54,7 +55,8 @@ public static DoFnInfo forFn( Iterable> sideInputViews, Coder inputCoder, TupleTag mainOutput, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { return new DoFnInfo<>( doFn, windowingStrategy, @@ -62,7 +64,8 @@ public static DoFnInfo forFn( inputCoder, Collections.emptyMap(), mainOutput, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); } /** Creates a {@link DoFnInfo} for the given {@link DoFn}. */ @@ -73,7 +76,8 @@ public static DoFnInfo forFn( Coder inputCoder, Map, Coder> outputCoders, TupleTag mainOutput, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { return new DoFnInfo<>( doFn, windowingStrategy, @@ -81,7 +85,8 @@ public static DoFnInfo forFn( inputCoder, outputCoders, mainOutput, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); } public DoFnInfo withFn(DoFn newFn) { @@ -92,7 +97,8 @@ public DoFnInfo withFn(DoFn newFn) { inputCoder, outputCoders, mainOutput, - doFnSchemaInformation); + doFnSchemaInformation, + sideInputMapping); } private DoFnInfo( @@ -102,7 +108,8 @@ private DoFnInfo( Coder inputCoder, Map, Coder> outputCoders, TupleTag mainOutput, - DoFnSchemaInformation doFnSchemaInformation) { + DoFnSchemaInformation doFnSchemaInformation, + Map> sideInputMapping) { this.doFn = doFn; this.windowingStrategy = windowingStrategy; this.sideInputViews = sideInputViews; @@ -110,6 +117,7 @@ private DoFnInfo( this.outputCoders = outputCoders; this.mainOutput = mainOutput; this.doFnSchemaInformation = doFnSchemaInformation; + this.sideInputMapping = sideInputMapping; } /** Returns the embedded function. 
*/ @@ -140,4 +148,8 @@ public TupleTag getMainOutput() { public DoFnSchemaInformation getDoFnSchemaInformation() { return doFnSchemaInformation; } + + public Map> getSideInputMapping() { + return sideInputMapping; + } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnWithExecutionInformation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnWithExecutionInformation.java index a812ecad988ff..0306a15bbf45e 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnWithExecutionInformation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/DoFnWithExecutionInformation.java @@ -19,21 +19,29 @@ import com.google.auto.value.AutoValue; import java.io.Serializable; +import java.util.Map; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.TupleTag; /** The data that the Java SDK harness needs to execute a DoFn. */ @AutoValue public abstract class DoFnWithExecutionInformation implements Serializable { public static DoFnWithExecutionInformation of( - DoFn fn, TupleTag tag, DoFnSchemaInformation doFnSchemaInformation) { - return new AutoValue_DoFnWithExecutionInformation(fn, tag, doFnSchemaInformation); + DoFn fn, + TupleTag tag, + Map> sideInputMapping, + DoFnSchemaInformation doFnSchemaInformation) { + return new AutoValue_DoFnWithExecutionInformation( + fn, tag, sideInputMapping, doFnSchemaInformation); } public abstract DoFn getDoFn(); public abstract TupleTag getMainOutputTag(); + public abstract Map> getSideInputMapping(); + public abstract DoFnSchemaInformation getSchemaInformation(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionView.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionView.java index 74363c0103386..038cf42c3db9a 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionView.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionView.java @@ -58,7 +58,6 @@ public interface PCollectionView extends PValue, Serializable { @Nullable @Internal PCollection getPCollection(); - /** * For internal use only. * diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java index 29ff0d4567ccc..366ad3fb0d5b1 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java @@ -98,7 +98,6 @@ public static PCollectionView> listView( windowingStrategy.getWindowFn().getDefaultWindowMappingFn(), windowingStrategy); } - /** * Returns a {@code PCollectionView>} capable of processing elements windowed using the * provided {@link WindowingStrategy}. 
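The practical consequence for runner authors is that DoFnInfo.forFn and DoFnWithExecutionInformation.of now thread the tag-to-view mapping through alongside the schema information. A rough sketch of the call-site shape a runner translation ends up with; fn, windowingStrategy, inputCoder, mainOutputTag, doFnSchemaInformation and the InputT/OutputT type parameters are placeholders for whatever the runner already has in scope:

    // Pull the tag-to-view mapping off the current transform, as the Spark streaming
    // translator above does, and pass it through DoFnInfo.
    Map<String, PCollectionView<?>> sideInputMapping =
        ParDoTranslation.getSideInputMapping(context.getCurrentTransform());

    DoFnInfo<InputT, OutputT> info =
        DoFnInfo.forFn(
            fn,                         // user DoFn
            windowingStrategy,          // main-input windowing strategy
            sideInputMapping.values(),  // views the runner still materializes itself
            inputCoder,
            mainOutputTag,
            doFnSchemaInformation,
            sideInputMapping);          // new argument: lets the invoker resolve @SideInput tags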
diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 724bc3f01491f..04587e74960ab 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -53,7 +53,6 @@ import java.util.Set; import org.apache.beam.sdk.coders.AtomicCoder; import org.apache.beam.sdk.coders.Coder; -import org.apache.beam.sdk.coders.CoderException; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.ListCoder; import org.apache.beam.sdk.coders.SetCoder; @@ -738,6 +737,93 @@ public void testParDoWithSideInputs() { pipeline.run(); } + @Test + @Category({ValidatesRunner.class, UsesSideInputs.class}) + public void testSideInputAnnotation() { + + final PCollectionView> sideInput1 = + pipeline + .apply("CreateSideInput1", Create.of(2, 1, 0)) + .apply("ViewSideInput1", View.asList()); + + // SideInput tag id + final String sideInputTag1 = "tag1"; + + DoFn> fn = + new DoFn>() { + @ProcessElement + public void processElement( + OutputReceiver> r, @SideInput(sideInputTag1) List tag1) { + + List sideSorted = Lists.newArrayList(tag1); + Collections.sort(sideSorted); + r.output(sideSorted); + } + }; + + PCollection> output = + pipeline + .apply("Create main input", Create.of(2)) + .apply(ParDo.of(fn).withSideInput(sideInputTag1, sideInput1)); + + PAssert.that(output).containsInAnyOrder(Lists.newArrayList(0, 1, 2)); + pipeline.run(); + } + + @Test + @Category({ValidatesRunner.class, UsesSideInputs.class}) + public void testSideInputAnnotationWithMultipleSideInputs() { + + final PCollectionView> sideInput1 = + pipeline + .apply("CreateSideInput1", Create.of(2, 0)) + .apply("ViewSideInput1", View.asList()); + + final PCollectionView sideInput2 = + pipeline + .apply("CreateSideInput2", Create.of(5)) + .apply("ViewSideInput2", View.asSingleton()); + + final PCollectionView> sideInput3 = + pipeline + .apply("CreateSideInput3", Create.of(1, 3)) + .apply("ViewSideInput3", View.asList()); + + // SideInput tag id + final String sideInputTag1 = "tag1"; + final String sideInputTag2 = "tag2"; + final String sideInputTag3 = "tag3"; + + DoFn> fn = + new DoFn>() { + @ProcessElement + public void processElement( + OutputReceiver> r, + @SideInput(sideInputTag1) List tag1, + @SideInput(sideInputTag2) Integer tag2, + @SideInput(sideInputTag3) List tag3) { + + List sideSorted = Lists.newArrayList(tag1); + sideSorted.add(tag2); + sideSorted.addAll(tag3); + Collections.sort(sideSorted); + r.output(sideSorted); + } + }; + + PCollection> output = + pipeline + .apply("Create main input", Create.of(2)) + .apply( + ParDo.of(fn) + .withSideInput(sideInputTag1, sideInput1) + .withSideInput(sideInputTag2, sideInput2) + .withSideInput(sideInputTag3, sideInput3)); + + PAssert.that(output).containsInAnyOrder(Lists.newArrayList(0, 1, 2, 3, 5)); + pipeline.run(); + } + @Test @Category({ValidatesRunner.class, UsesSideInputs.class}) public void testParDoWithSideInputsIsCumulative() { @@ -1967,7 +2053,6 @@ public void testBagStateSideInput() { final PCollectionView> listView = pipeline.apply("Create list for side input", Create.of(2, 1, 0)).apply(View.asList()); - final String stateId = "foo"; DoFn, List> fn = new DoFn, List>() { @@ -3337,11 +3422,10 @@ public static TestDummyCoder of() { } @Override - public void encode(TestDummy value, OutputStream outStream) - throws CoderException, IOException {} + public void 
encode(TestDummy value, OutputStream outStream) throws IOException {} @Override - public TestDummy decode(InputStream inStream) throws CoderException, IOException { + public TestDummy decode(InputStream inStream) throws IOException { return new TestDummy(); } @@ -3445,12 +3529,12 @@ public static MyIntegerCoder of() { } @Override - public void encode(MyInteger value, OutputStream outStream) throws CoderException, IOException { + public void encode(MyInteger value, OutputStream outStream) throws IOException { delegate.encode(value.getValue(), outStream); } @Override - public MyInteger decode(InputStream inStream) throws CoderException, IOException { + public MyInteger decode(InputStream inStream) throws IOException { return new MyInteger(delegate.decode(inStream)); } } diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java index 8c80cfc42497f..e0427f28224fe 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/reflect/DoFnSignaturesTest.java @@ -56,6 +56,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.PipelineOptionsParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.ProcessContextParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TaggedOutputReceiverParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimeDomainParameter; @@ -110,10 +111,12 @@ public void process( BoundedWindow window, PaneInfo paneInfo, OutputReceiver receiver, - PipelineOptions options) {} + PipelineOptions options, + @SideInput("tag1") String input1, + @SideInput("tag2") Integer input2) {} }.getClass()); - assertThat(sig.processElement().extraParameters().size(), equalTo(6)); + assertThat(sig.processElement().extraParameters().size(), equalTo(8)); assertThat(sig.processElement().extraParameters().get(0), instanceOf(ElementParameter.class)); assertThat(sig.processElement().extraParameters().get(1), instanceOf(TimestampParameter.class)); assertThat(sig.processElement().extraParameters().get(2), instanceOf(WindowParameter.class)); @@ -122,6 +125,8 @@ public void process( sig.processElement().extraParameters().get(4), instanceOf(OutputReceiverParameter.class)); assertThat( sig.processElement().extraParameters().get(5), instanceOf(PipelineOptionsParameter.class)); + assertThat(sig.processElement().extraParameters().get(6), instanceOf(SideInputParameter.class)); + assertThat(sig.processElement().extraParameters().get(7), instanceOf(SideInputParameter.class)); } @Test diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java index cc55d92a5c19d..67d7d0c35bd5f 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/FnApiDoFnRunner.java @@ -396,6 +396,11 @@ public InputT element(DoFn doFn) { return element(); } + @Override + public Object sideInput(String tagId) { + throw new 
UnsupportedOperationException("hello world"); + } + @Override public Object schemaElement(int index) { SerializableFunction converter = doFnSchemaInformation.getElementConverters().get(index); @@ -580,6 +585,11 @@ public InputT element(DoFn doFn) { throw new UnsupportedOperationException("Element parameters are not supported."); } + @Override + public InputT sideInput(String tagId) { + throw new UnsupportedOperationException("SideInput parameters are not supported."); + } + @Override public Object schemaElement(int index) { throw new UnsupportedOperationException("Element parameters are not supported.");