Skip to content

Commit

Permalink
Merge pull request #9275: [BEAM-6858] Support side inputs injected into a DoFn
Browse files Browse the repository at this point in the history
  • Loading branch information
salmanVD authored and reuvenlax committed Aug 24, 2019
1 parent f085cb5 commit 64262a6
Show file tree
Hide file tree
Showing 75 changed files with 769 additions and 223 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.datatorrent.api.Operator;
import com.datatorrent.api.Operator.OutputPort;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -41,6 +42,7 @@
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -76,11 +78,14 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, TranslationC

Map<TupleTag<?>, PValue> outputs = context.getOutputs();
PCollection<InputT> input = context.getInput();
List<PCollectionView<?>> sideInputs = transform.getSideInputs();
Iterable<PCollectionView<?>> sideInputs = transform.getSideInputs().values();

DoFnSchemaInformation doFnSchemaInformation;
doFnSchemaInformation = ParDoTranslation.getSchemaInformation(context.getCurrentTransform());

Map<String, PCollectionView<?>> sideInputMapping =
ParDoTranslation.getSideInputMapping(context.getCurrentTransform());

Map<TupleTag<?>, Coder<?>> outputCoders =
outputs.entrySet().stream()
.filter(e -> e.getValue() instanceof PCollection)
Expand All @@ -97,6 +102,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, TranslationC
input.getCoder(),
outputCoders,
doFnSchemaInformation,
sideInputMapping,
context.getStateBackend());

Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size());
Expand Down Expand Up @@ -124,7 +130,7 @@ public void translate(ParDo.MultiOutput<InputT, OutputT> transform, TranslationC
}
context.addOperator(operator, ports);
context.addStream(context.getInput(), operator.input);
if (!sideInputs.isEmpty()) {
if (!Iterables.isEmpty(sideInputs)) {
addSideInputs(operator.sideInput1, sideInputs, context);
}
}
Expand All @@ -139,7 +145,7 @@ public void translate(

Map<TupleTag<?>, PValue> outputs = context.getOutputs();
PCollection<InputT> input = context.getInput();
List<PCollectionView<?>> sideInputs = transform.getSideInputs();
Iterable<PCollectionView<?>> sideInputs = transform.getSideInputs();

Map<TupleTag<?>, Coder<?>> outputCoders =
outputs.entrySet().stream()
Expand All @@ -160,6 +166,7 @@ public void translate(
input.getCoder(),
outputCoders,
DoFnSchemaInformation.create(),
Collections.emptyMap(),
context.getStateBackend());

Map<PCollection<?>, OutputPort<?>> ports = Maps.newHashMapWithExpectedSize(outputs.size());
Expand Down Expand Up @@ -188,37 +195,37 @@ public void translate(

context.addOperator(operator, ports);
context.addStream(context.getInput(), operator.input);
if (!sideInputs.isEmpty()) {
if (!Iterables.isEmpty(sideInputs)) {
addSideInputs(operator.sideInput1, sideInputs, context);
}
}
}

static void addSideInputs(
Operator.InputPort<?> sideInputPort,
List<PCollectionView<?>> sideInputs,
Iterable<PCollectionView<?>> sideInputs,
TranslationContext context) {
Operator.InputPort<?>[] sideInputPorts = {sideInputPort};
if (sideInputs.size() > sideInputPorts.length) {
if (Iterables.size(sideInputs) > sideInputPorts.length) {
PCollection<?> unionCollection = unionSideInputs(sideInputs, context);
context.addStream(unionCollection, sideInputPorts[0]);
} else {
// the number of ports for side inputs is fixed and each port can only take one input.
for (int i = 0; i < sideInputs.size(); i++) {
context.addStream(context.getViewInput(sideInputs.get(i)), sideInputPorts[i]);
for (int i = 0; i < Iterables.size(sideInputs); i++) {
context.addStream(context.getViewInput(Iterables.get(sideInputs, i)), sideInputPorts[i]);
}
}
}

private static PCollection<?> unionSideInputs(
List<PCollectionView<?>> sideInputs, TranslationContext context) {
checkArgument(sideInputs.size() > 1, "requires multiple side inputs");
Iterable<PCollectionView<?>> sideInputs, TranslationContext context) {
checkArgument(Iterables.size(sideInputs) > 1, "requires multiple side inputs");
// flatten and assign union tag
List<PCollection<Object>> sourceCollections = new ArrayList<>();
Map<PCollection<?>, Integer> unionTags = new HashMap<>();
PCollection<Object> firstSideInput = context.getViewInput(sideInputs.get(0));
for (int i = 0; i < sideInputs.size(); i++) {
PCollectionView<?> sideInput = sideInputs.get(i);
PCollection<Object> firstSideInput = context.getViewInput(Iterables.get(sideInputs, 0));
for (int i = 0; i < Iterables.size(sideInputs); i++) {
PCollectionView<?> sideInput = Iterables.get(sideInputs, i);
PCollection<?> sideInputCollection = context.getViewInput(sideInput);
if (!sideInputCollection
.getWindowingStrategy()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.joda.time.Duration;
import org.joda.time.Instant;
Expand Down Expand Up @@ -115,7 +116,7 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator
private final WindowingStrategy<?, ?> windowingStrategy;

@Bind(JavaSerializer.class)
private final List<PCollectionView<?>> sideInputs;
private final Iterable<PCollectionView<?>> sideInputs;

@Bind(JavaSerializer.class)
private final Coder<WindowedValue<InputT>> windowedInputCoder;
Expand All @@ -129,6 +130,9 @@ public class ApexParDoOperator<InputT, OutputT> extends BaseOperator
@Bind(JavaSerializer.class)
private final DoFnSchemaInformation doFnSchemaInformation;

@Bind(JavaSerializer.class)
private final Map<String, PCollectionView<?>> sideInputMapping;

private StateInternalsProxy<?> currentKeyStateInternals;
private final ApexTimerInternals<Object> currentKeyTimerInternals;

Expand All @@ -152,10 +156,11 @@ public ApexParDoOperator(
TupleTag<OutputT> mainOutputTag,
List<TupleTag<?>> additionalOutputTags,
WindowingStrategy<?, ?> windowingStrategy,
List<PCollectionView<?>> sideInputs,
Iterable<PCollectionView<?>> sideInputs,
Coder<InputT> inputCoder,
Map<TupleTag<?>, Coder<?>> outputCoders,
DoFnSchemaInformation doFnSchemaInformation,
Map<String, PCollectionView<?>> sideInputMapping,
ApexStateBackend stateBackend) {
this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions);
this.doFn = doFn;
Expand Down Expand Up @@ -186,6 +191,7 @@ public ApexParDoOperator(
TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder());
this.currentKeyTimerInternals = new ApexTimerInternals<>(timerCoder);
this.doFnSchemaInformation = doFnSchemaInformation;
this.sideInputMapping = sideInputMapping;

if (doFn instanceof ProcessFn) {
// we know that it is keyed on byte[]
Expand Down Expand Up @@ -219,6 +225,7 @@ private ApexParDoOperator() {
this.outputCoders = Collections.emptyMap();
this.currentKeyTimerInternals = null;
this.doFnSchemaInformation = null;
this.sideInputMapping = null;
}

public final transient DefaultInputPort<ApexStreamTuple<WindowedValue<InputT>>> input =
Expand Down Expand Up @@ -260,7 +267,7 @@ public void process(ApexStreamTuple<WindowedValue<Iterable<?>>> t) {
LOG.debug("\nsideInput {} {}\n", sideInputIndex, t.getValue());
}

PCollectionView<?> sideInput = sideInputs.get(sideInputIndex);
PCollectionView<?> sideInput = Iterables.get(sideInputs, sideInputIndex);
sideInputHandler.addSideInputValue(sideInput, t.getValue());

List<WindowedValue<InputT>> newPushedBack = new ArrayList<>();
Expand Down Expand Up @@ -408,7 +415,7 @@ private void processWatermark(ApexStreamTuple.WatermarkTuple<?> mark) {
currentInputWatermark);
}
}
if (sideInputs.isEmpty()) {
if (Iterables.isEmpty(sideInputs)) {
outputWatermark(mark);
return;
}
Expand Down Expand Up @@ -439,8 +446,9 @@ public void setup(OperatorContext context) {
ApexStreamTuple.Logging.isDebugEnabled(
pipelineOptions.get().as(ApexPipelineOptions.class), this);
SideInputReader sideInputReader = NullSideInputReader.of(sideInputs);
if (!sideInputs.isEmpty()) {
sideInputHandler = new SideInputHandler(sideInputs, sideInputStateInternals);
if (!Iterables.isEmpty(sideInputs)) {
sideInputHandler =
new SideInputHandler(Lists.newArrayList(sideInputs), sideInputStateInternals);
sideInputReader = sideInputHandler;
}

Expand Down Expand Up @@ -476,7 +484,8 @@ public TimerInternals timerInternals() {
inputCoder,
outputCoders,
windowingStrategy,
doFnSchemaInformation);
doFnSchemaInformation,
sideInputMapping);

doFnInvoker = DoFnInvokers.invokerFor(doFn);
doFnInvoker.invokeSetup();
Expand All @@ -501,7 +510,8 @@ public TimerInternals timerInternals() {
}

pushbackDoFnRunner =
SimplePushbackSideInputDoFnRunner.create(doFnRunner, sideInputs, sideInputHandler);
SimplePushbackSideInputDoFnRunner.create(
doFnRunner, Lists.newArrayList(sideInputs), sideInputHandler);

if (doFn instanceof ProcessFn) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ public void testSerialization() throws Exception {
VarIntCoder.of(),
Collections.emptyMap(),
DoFnSchemaInformation.create(),
Collections.emptyMap(),
new ApexStateInternals.ApexStateBackend());
operator.setup(null);
operator.beginWindow(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public static ParDoPayload translateParDo(
.map(TupleTag::getId)
.collect(Collectors.toSet());
Set<String> sideInputs =
parDo.getSideInputs().stream()
parDo.getSideInputs().values().stream()
.map(s -> s.getTagInternal().getId())
.collect(Collectors.toSet());
Set<String> timerInputs = signature.timerDeclarations().keySet();
Expand Down Expand Up @@ -209,7 +209,11 @@ public static ParDoPayload translateParDo(
@Override
public SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
return ParDoTranslation.translateDoFn(
parDo.getFn(), parDo.getMainOutputTag(), doFnSchemaInformation, newComponents);
parDo.getFn(),
parDo.getMainOutputTag(),
parDo.getSideInputs(),
doFnSchemaInformation,
newComponents);
}

@Override
Expand All @@ -221,7 +225,7 @@ public List<RunnerApi.Parameter> translateParameters() {
@Override
public Map<String, SideInput> translateSideInputs(SdkComponents components) {
Map<String, SideInput> sideInputs = new HashMap<>();
for (PCollectionView<?> sideInput : parDo.getSideInputs()) {
for (PCollectionView<?> sideInput : parDo.getSideInputs().values()) {
sideInputs.put(
sideInput.getTagInternal().getId(), translateView(sideInput, components));
}
Expand Down Expand Up @@ -315,6 +319,28 @@ public static TupleTag<?> getMainOutputTag(ParDoPayload payload)
return doFnWithExecutionInformationFromProto(payload.getDoFn()).getMainOutputTag();
}

/**
 * Looks up the side-input mapping for an applied transform.
 *
 * <p>The transform's {@code ParDoPayload} is obtained through {@code getParDoPayload} and then
 * handed to the payload-based overload of this method.
 *
 * @param application the applied transform to inspect
 * @return the mapping produced by the payload-based overload
 * @throws RuntimeException wrapping any {@link IOException} raised while obtaining the payload
 */
public static Map<String, PCollectionView<?>> getSideInputMapping(
    AppliedPTransform<?, ?, ?> application) {
  final ParDoPayload parDoPayload;
  try {
    parDoPayload = getParDoPayload(application);
  } catch (IOException ioe) {
    // Callers get an unchecked failure; the original exception is kept as the cause.
    throw new RuntimeException(ioe);
  }
  return getSideInputMapping(parDoPayload);
}

/**
 * Looks up the side-input mapping for a transform in its {@code RunnerApi} proto form.
 *
 * <p>The transform's {@code ParDoPayload} is obtained through {@code getParDoPayload} and then
 * handed to the payload-based overload of this method.
 *
 * @param pTransform the proto transform to inspect
 * @return the mapping produced by the payload-based overload
 * @throws RuntimeException wrapping any {@link IOException} raised while obtaining the payload
 */
public static Map<String, PCollectionView<?>> getSideInputMapping(
    RunnerApi.PTransform pTransform) {
  final ParDoPayload parDoPayload;
  try {
    parDoPayload = getParDoPayload(pTransform);
  } catch (IOException ioe) {
    // Callers get an unchecked failure; the original exception is kept as the cause.
    throw new RuntimeException(ioe);
  }
  return getSideInputMapping(parDoPayload);
}

/**
 * Returns the side-input mapping carried by the DoFn spec of the given payload.
 *
 * <p>The payload's DoFn spec is decoded with {@code doFnWithExecutionInformationFromProto} and
 * the mapping is read from the decoded result.
 *
 * @param payload the ParDo payload whose DoFn spec is decoded
 * @return the value of {@code getSideInputMapping()} on the decoded object
 */
public static Map<String, PCollectionView<?>> getSideInputMapping(ParDoPayload payload) {
return doFnWithExecutionInformationFromProto(payload.getDoFn()).getSideInputMapping();
}

public static TupleTag<?> getMainOutputTag(AppliedPTransform<?, ?, ?> application)
throws IOException {
PTransform<?, ?> transform = application.getTransform();
Expand Down Expand Up @@ -359,7 +385,8 @@ public static List<PCollectionView<?>> getSideInputs(AppliedPTransform<?, ?, ?>
throws IOException {
PTransform<?, ?> transform = application.getTransform();
if (transform instanceof ParDo.MultiOutput) {
return ((ParDo.MultiOutput<?, ?>) transform).getSideInputs();
return ((ParDo.MultiOutput<?, ?>) transform)
.getSideInputs().values().stream().collect(Collectors.toList());
}

SdkComponents sdkComponents = SdkComponents.create(application.getPipeline().getOptions());
Expand Down Expand Up @@ -551,6 +578,7 @@ private static RunnerApi.TimeDomain.Enum translateTimeDomain(TimeDomain timeDoma
public static SdkFunctionSpec translateDoFn(
DoFn<?, ?> fn,
TupleTag<?> tag,
Map<String, PCollectionView<?>> sideInputMapping,
DoFnSchemaInformation doFnSchemaInformation,
SdkComponents components) {
return SdkFunctionSpec.newBuilder()
Expand All @@ -561,7 +589,8 @@ public static SdkFunctionSpec translateDoFn(
.setPayload(
ByteString.copyFrom(
SerializableUtils.serializeToByteArray(
DoFnWithExecutionInformation.of(fn, tag, doFnSchemaInformation))))
DoFnWithExecutionInformation.of(
fn, tag, sideInputMapping, doFnSchemaInformation))))
.build())
.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import com.google.auto.service.AutoService;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ThreadLocalRandom;
Expand Down Expand Up @@ -368,7 +369,11 @@ public FunctionSpec translate(
public SdkFunctionSpec translateDoFn(SdkComponents newComponents) {
// Schemas not yet supported on splittable DoFn.
return ParDoTranslation.translateDoFn(
fn, pke.getMainOutputTag(), DoFnSchemaInformation.create(), newComponents);
fn,
pke.getMainOutputTag(),
Collections.emptyMap(),
DoFnSchemaInformation.create(),
newComponents);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,11 @@ public InputT element(DoFn<InputT, OutputT> doFn) {
return element;
}

/**
 * Side-input lookup is not supported by this implementation.
 *
 * @param tagId ignored; the call fails before it is read
 * @throws UnsupportedOperationException always
 */
@Override
public Object sideInput(String tagId) {
throw new UnsupportedOperationException();
}

@Override
public Object schemaElement(int index) {
throw new UnsupportedOperationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ public void testToProto() throws Exception {

assertThat(ParDoTranslation.getDoFn(payload), equalTo(parDo.getFn()));
assertThat(ParDoTranslation.getMainOutputTag(payload), equalTo(parDo.getMainOutputTag()));
for (PCollectionView<?> view : parDo.getSideInputs()) {
for (PCollectionView<?> view : parDo.getSideInputs().values()) {
payload.getSideInputsOrThrow(view.getTagInternal().getId());
}
}
Expand All @@ -148,7 +148,7 @@ public void toTransformProto() throws Exception {

// Decode
ParDoPayload parDoPayload = ParDoPayload.parseFrom(protoTransform.getSpec().getPayload());
for (PCollectionView<?> view : parDo.getSideInputs()) {
for (PCollectionView<?> view : parDo.getSideInputs().values()) {
SideInput sideInput = parDoPayload.getSideInputsOrThrow(view.getTagInternal().getId());
PCollectionView<?> restoredView =
PCollectionViewTranslation.viewFromProto(
Expand Down
Loading

0 comments on commit 64262a6

Please sign in to comment.