diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslationTest.java index 265793f94d39f..bf47ff545c5aa 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CreatePCollectionViewTranslationTest.java @@ -30,6 +30,8 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.PCollectionViews.TypeDescriptorSupplier; +import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; import org.hamcrest.Matchers; import org.junit.Test; @@ -50,12 +52,16 @@ public class CreatePCollectionViewTranslationTest { CreatePCollectionView.of( PCollectionViews.singletonView( testPCollection, + (TypeDescriptorSupplier) () -> TypeDescriptors.strings(), testPCollection.getWindowingStrategy(), false, null, StringUtf8Coder.of())), CreatePCollectionView.of( - PCollectionViews.listView(testPCollection, testPCollection.getWindowingStrategy()))); + PCollectionViews.listView( + testPCollection, + (TypeDescriptorSupplier) () -> TypeDescriptors.strings(), + testPCollection.getWindowingStrategy()))); } @Parameter(0) diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java index 1b711c705ad61..765a339d94131 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PCollectionViewTranslationTest.java @@ -22,6 +22,7 @@ import org.apache.beam.sdk.transforms.Materialization; import org.apache.beam.sdk.transforms.ViewFn; import org.apache.beam.sdk.transforms.windowing.GlobalWindows; +import org.apache.beam.sdk.values.TypeDescriptor; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -62,6 +63,11 @@ public Object apply(Object o) { throw new UnsupportedOperationException(); } + @Override + public TypeDescriptor getTypeDescriptor() { + return new TypeDescriptor() {}; + } + @Override public boolean equals(Object obj) { return obj instanceof TestViewFn; diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java index 52e34b3dc35ad..3260d8138bdb4 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/PTransformMatchersTest.java @@ -66,9 +66,11 @@ import org.apache.beam.sdk.values.PCollection.IsBounded; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.PCollectionViews.IterableViewFn; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; +import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; @@ -405,7 +407,8 @@ public void createViewWithViewFnDifferentViewFn() { PCollectionView> view = input.apply(View.asIterable()); // Purposely create a subclass to get a different class then what was expected. - ViewFn viewFn = new PCollectionViews.IterableViewFn() {}; + IterableViewFn viewFn = + new PCollectionViews.IterableViewFn(() -> TypeDescriptors.integers()) {}; CreatePCollectionView createView = CreatePCollectionView.of(view); PTransformMatcher matcher = PTransformMatchers.createViewWithViewFn(viewFn.getClass()); diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java index 5fa354dbac481..48a784750de1d 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/FlinkStreamingPortablePipelineTranslator.java @@ -85,6 +85,8 @@ import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeDescriptors; import org.apache.beam.sdk.values.ValueWithRecordId; import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException; @@ -805,7 +807,11 @@ private void translateTestStream( new LinkedHashMap<>(); // for PCollectionView compatibility, not used to transform materialization ViewFn>, ?> viewFn = - (ViewFn) new PCollectionViews.MultimapViewFn>, Void>(); + (ViewFn) + new PCollectionViews.MultimapViewFn<>( + (PCollectionViews.TypeDescriptorSupplier>>) + () -> TypeDescriptors.iterables(new TypeDescriptor>() {}), + (PCollectionViews.TypeDescriptorSupplier) TypeDescriptors::voids); for (RunnerApi.ExecutableStagePayload.SideInputId sideInputId : stagePayload.getSideInputsList()) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowPortabilityPCollectionView.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowPortabilityPCollectionView.java index 9a3740e249431..ed95c9cdce565 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowPortabilityPCollectionView.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/DataflowPortabilityPCollectionView.java @@ -38,6 +38,7 @@ import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptor; import org.apache.beam.sdk.values.WindowingStrategy; /** @@ -105,6 +106,11 @@ public Materialization> getMaterialization() { public MultimapView apply(MultimapView o) { return o; } + + @Override + public TypeDescriptor> getTypeDescriptor() { + throw new UnsupportedOperationException(); + } }; @Override diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java index aacb3674d23e1..2bd8365d89363 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/Combine.java @@ -64,6 +64,7 @@ import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.PCollectionViews.TypeDescriptorSupplier; import org.apache.beam.sdk.values.PValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.sdk.values.TupleTagList; @@ -1305,9 +1306,12 @@ public PCollectionView expand(PCollection input) { input.apply(Combine.globally(fn).withoutDefaults().withFanout(fanout)); PCollection> materializationInput = combined.apply(new VoidKeyToMultimapMaterialization<>()); + Coder outputCoder = combined.getCoder(); PCollectionView view = PCollectionViews.singletonView( materializationInput, + (TypeDescriptorSupplier) + () -> outputCoder != null ? outputCoder.getEncodedTypeDescriptor() : null, input.getWindowingStrategy(), insertDefault, insertDefault ? fn.defaultValue() : null, diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java index dad8b4f54b47e..fe85b11f83d12 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ParDo.java @@ -53,6 +53,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnSignature.MethodWithExtraParameters; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.OnTimerMethod; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.WindowFn; @@ -437,6 +438,29 @@ private static void validateStateApplicableForInput(DoFn fn, PCollection> sideInputs, DoFn fn) { + DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass()); + DoFnSignature.ProcessElementMethod processElementMethod = signature.processElement(); + for (SideInputParameter sideInput : processElementMethod.getSideInputParameters()) { + PCollectionView view = sideInputs.get(sideInput.sideInputId()); + checkArgument( + view != null, + "the ProcessElement method expects a side input identified with the tag %s, but no such side input was" + + " supplied. Use withSideInput(String, PCollectionView) to supply this side input.", + sideInput.sideInputId()); + TypeDescriptor viewType = view.getViewFn().getTypeDescriptor(); + + // Currently check that the types exactly match, even if the types are convertible. + checkArgument( + viewType.equals(sideInput.elementT()), + "Side Input with tag %s and type %s cannot be bound to ProcessElement parameter with type %s", + sideInput.sideInputId(), + viewType, + sideInput.elementT()); + } + } + private static FieldAccessDescriptor getFieldAccessDescriptorFromParameter( @Nullable String fieldAccessString, Schema inputSchema, @@ -865,6 +889,8 @@ public PCollectionTuple expand(PCollection input) { validateStateApplicableForInput(fn, input); } + validateSideInputTypes(sideInputs, fn); + // TODO: We should validate OutputReceiver only happens if the output PCollection // as schema. However coder/schema inference may not have happened yet at this point. // Need to figure out where to validate this. diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java index df1debb15dfc9..75583cea2c1fd 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/View.java @@ -32,6 +32,7 @@ import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; import org.apache.beam.sdk.values.PCollectionViews; +import org.apache.beam.sdk.values.PCollectionViews.TypeDescriptorSupplier; /** * Transforms for creating {@link PCollectionView PCollectionViews} from {@link PCollection @@ -233,9 +234,12 @@ public PCollectionView> expand(PCollection input) { PCollection> materializationInput = input.apply(new VoidKeyToMultimapMaterialization<>()); + Coder inputCoder = input.getCoder(); PCollectionView> view = PCollectionViews.listView( - materializationInput, materializationInput.getWindowingStrategy()); + materializationInput, + (TypeDescriptorSupplier) inputCoder::getEncodedTypeDescriptor, + materializationInput.getWindowingStrategy()); materializationInput.apply(CreatePCollectionView.of(view)); return view; } @@ -263,9 +267,12 @@ public PCollectionView> expand(PCollection input) { PCollection> materializationInput = input.apply(new VoidKeyToMultimapMaterialization<>()); + Coder inputCoder = input.getCoder(); PCollectionView> view = PCollectionViews.iterableView( - materializationInput, materializationInput.getWindowingStrategy()); + materializationInput, + (TypeDescriptorSupplier) inputCoder::getEncodedTypeDescriptor, + materializationInput.getWindowingStrategy()); materializationInput.apply(CreatePCollectionView.of(view)); return view; } @@ -402,11 +409,17 @@ public PCollectionView>> expand(PCollection> input) throw new IllegalStateException("Unable to create a side-input view from input", e); } + KvCoder kvCoder = (KvCoder) input.getCoder(); + Coder keyCoder = kvCoder.getKeyCoder(); + Coder valueCoder = kvCoder.getValueCoder(); PCollection>> materializationInput = input.apply(new VoidKeyToMultimapMaterialization<>()); PCollectionView>> view = PCollectionViews.multimapView( - materializationInput, materializationInput.getWindowingStrategy()); + materializationInput, + (TypeDescriptorSupplier) keyCoder::getEncodedTypeDescriptor, + (TypeDescriptorSupplier) valueCoder::getEncodedTypeDescriptor, + materializationInput.getWindowingStrategy()); materializationInput.apply(CreatePCollectionView.of(view)); return view; } @@ -438,11 +451,18 @@ public PCollectionView> expand(PCollection> input) { throw new IllegalStateException("Unable to create a side-input view from input", e); } + KvCoder kvCoder = (KvCoder) input.getCoder(); + Coder keyCoder = kvCoder.getKeyCoder(); + Coder valueCoder = kvCoder.getValueCoder(); + PCollection>> materializationInput = input.apply(new VoidKeyToMultimapMaterialization<>()); PCollectionView> view = PCollectionViews.mapView( - materializationInput, materializationInput.getWindowingStrategy()); + materializationInput, + (TypeDescriptorSupplier) keyCoder::getEncodedTypeDescriptor, + (TypeDescriptorSupplier) valueCoder::getEncodedTypeDescriptor, + materializationInput.getWindowingStrategy()); materializationInput.apply(CreatePCollectionView.of(view)); return view; } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ViewFn.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ViewFn.java index 17a9e9c2e247c..41a3a4f7cc45f 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ViewFn.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/ViewFn.java @@ -21,6 +21,7 @@ import org.apache.beam.sdk.annotations.Internal; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TypeDescriptor; /** * For internal use only; no backwards-compatibility guarantees. @@ -45,4 +46,7 @@ public abstract class ViewFn implements Serializable { /** A function to adapt a primitive view type to a desired view type. */ public abstract ViewT apply(PrimitiveViewT primitiveViewT); + + /** Return the {@link TypeDescriptor} describing the output of this fn. */ + public abstract TypeDescriptor getTypeDescriptor(); } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java index bc81e38d09ddf..5737ac909f491 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/transforms/reflect/DoFnSignature.java @@ -40,6 +40,7 @@ import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.OutputReceiverParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.RestrictionTrackerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SchemaElementParameter; +import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.SideInputParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.StateParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.TimerParameter; import org.apache.beam.sdk.transforms.reflect.DoFnSignature.Parameter.WindowParameter; @@ -760,6 +761,14 @@ public List getSchemaElementParameters() { .collect(Collectors.toList()); } + @Nullable + public List getSideInputParameters() { + return extraParameters().stream() + .filter(Predicates.instanceOf(SideInputParameter.class)::apply) + .map(SideInputParameter.class::cast) + .collect(Collectors.toList()); + } + /** The {@link OutputReceiverParameter} for a main output, or null if there is none. */ @Nullable public OutputReceiverParameter getMainOutputReceiver() { diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java index 366ad3fb0d5b1..0f73699bfd882 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionViews.java @@ -18,6 +18,7 @@ package org.apache.beam.sdk.values; import java.io.IOException; +import java.io.Serializable; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; @@ -25,6 +26,7 @@ import java.util.Map; import java.util.NoSuchElementException; import java.util.Objects; +import java.util.function.Supplier; import javax.annotation.Nullable; import org.apache.beam.sdk.annotations.Experimental; import org.apache.beam.sdk.annotations.Experimental.Kind; @@ -52,6 +54,7 @@ */ @Internal public class PCollectionViews { + public interface TypeDescriptorSupplier extends Supplier>, Serializable {} /** * Returns a {@code PCollectionView} capable of processing elements windowed using the provided @@ -62,13 +65,14 @@ public class PCollectionViews { */ public static PCollectionView singletonView( PCollection> pCollection, + TypeDescriptorSupplier typeDescriptorSupplier, WindowingStrategy windowingStrategy, boolean hasDefault, @Nullable T defaultValue, Coder defaultValueCoder) { return new SimplePCollectionView<>( pCollection, - new SingletonViewFn<>(hasDefault, defaultValue, defaultValueCoder), + new SingletonViewFn(hasDefault, defaultValue, defaultValueCoder, typeDescriptorSupplier), windowingStrategy.getWindowFn().getDefaultWindowMappingFn(), windowingStrategy); } @@ -78,10 +82,12 @@ public static PCollectionView singletonView( * the provided {@link WindowingStrategy}. */ public static PCollectionView> iterableView( - PCollection> pCollection, WindowingStrategy windowingStrategy) { + PCollection> pCollection, + TypeDescriptorSupplier typeDescriptorSupplier, + WindowingStrategy windowingStrategy) { return new SimplePCollectionView<>( pCollection, - new IterableViewFn(), + new IterableViewFn(typeDescriptorSupplier), windowingStrategy.getWindowFn().getDefaultWindowMappingFn(), windowingStrategy); } @@ -91,10 +97,12 @@ public static PCollectionView> iterable * provided {@link WindowingStrategy}. */ public static PCollectionView> listView( - PCollection> pCollection, WindowingStrategy windowingStrategy) { + PCollection> pCollection, + TypeDescriptorSupplier typeDescriptorSupplier, + WindowingStrategy windowingStrategy) { return new SimplePCollectionView<>( pCollection, - new ListViewFn(), + new ListViewFn<>(typeDescriptorSupplier), windowingStrategy.getWindowFn().getDefaultWindowMappingFn(), windowingStrategy); } @@ -103,10 +111,13 @@ public static PCollectionView> listView( * provided {@link WindowingStrategy}. */ public static PCollectionView> mapView( - PCollection>> pCollection, WindowingStrategy windowingStrategy) { + PCollection>> pCollection, + TypeDescriptorSupplier keyTypeDescriptorSupplier, + TypeDescriptorSupplier valueTypeDescriptorSupplier, + WindowingStrategy windowingStrategy) { return new SimplePCollectionView<>( pCollection, - new MapViewFn(), + new MapViewFn<>(keyTypeDescriptorSupplier, valueTypeDescriptorSupplier), windowingStrategy.getWindowFn().getDefaultWindowMappingFn(), windowingStrategy); } @@ -116,10 +127,13 @@ public static PCollectionView> mapView * using the provided {@link WindowingStrategy}. */ public static PCollectionView>> multimapView( - PCollection>> pCollection, WindowingStrategy windowingStrategy) { + PCollection>> pCollection, + TypeDescriptorSupplier keyTypeDescriptorSupplier, + TypeDescriptorSupplier valueTypeDescriptorSupplier, + WindowingStrategy windowingStrategy) { return new SimplePCollectionView<>( pCollection, - new MultimapViewFn(), + new MultimapViewFn<>(keyTypeDescriptorSupplier, valueTypeDescriptorSupplier), windowingStrategy.getWindowFn().getDefaultWindowMappingFn(), windowingStrategy); } @@ -149,11 +163,17 @@ public static class SingletonViewFn extends ViewFn, T> @Nullable private transient T defaultValue; @Nullable private Coder valueCoder; private boolean hasDefault; + private TypeDescriptorSupplier typeDescriptorSupplier; - private SingletonViewFn(boolean hasDefault, T defaultValue, Coder valueCoder) { + private SingletonViewFn( + boolean hasDefault, + T defaultValue, + Coder valueCoder, + TypeDescriptorSupplier typeDescriptorSupplier) { this.hasDefault = hasDefault; this.defaultValue = defaultValue; this.valueCoder = valueCoder; + this.typeDescriptorSupplier = typeDescriptorSupplier; if (hasDefault) { try { this.encodedDefaultValue = CoderUtils.encodeToByteArray(valueCoder, defaultValue); @@ -212,6 +232,11 @@ public T apply(MultimapView primitiveViewT) { "PCollection with more than one element accessed as a singleton view."); } } + + @Override + public TypeDescriptor getTypeDescriptor() { + return typeDescriptorSupplier.get(); + } } /** @@ -223,6 +248,11 @@ public T apply(MultimapView primitiveViewT) { */ @Experimental(Kind.CORE_RUNNERS_ONLY) public static class IterableViewFn extends ViewFn, Iterable> { + private TypeDescriptorSupplier typeDescriptorSupplier; + + public IterableViewFn(TypeDescriptorSupplier typeDescriptorSupplier) { + this.typeDescriptorSupplier = typeDescriptorSupplier; + } @Override public Materialization> getMaterialization() { @@ -233,6 +263,11 @@ public Materialization> getMaterialization() { public Iterable apply(MultimapView primitiveViewT) { return Iterables.unmodifiableIterable(primitiveViewT.get(null)); } + + @Override + public TypeDescriptor> getTypeDescriptor() { + return TypeDescriptors.iterables(typeDescriptorSupplier.get()); + } } /** @@ -244,6 +279,12 @@ public Iterable apply(MultimapView primitiveViewT) { */ @Experimental(Kind.CORE_RUNNERS_ONLY) public static class ListViewFn extends ViewFn, List> { + private TypeDescriptorSupplier typeDescriptorSupplier; + + public ListViewFn(TypeDescriptorSupplier typeDescriptorSupplier) { + this.typeDescriptorSupplier = typeDescriptorSupplier; + } + @Override public Materialization> getMaterialization() { return Materializations.multimap(); @@ -258,6 +299,11 @@ public List apply(MultimapView primitiveViewT) { return Collections.unmodifiableList(list); } + @Override + public TypeDescriptor> getTypeDescriptor() { + return TypeDescriptors.lists(typeDescriptorSupplier.get()); + } + @Override public boolean equals(Object other) { return other instanceof ListViewFn; @@ -280,6 +326,16 @@ public int hashCode() { @Experimental(Kind.CORE_RUNNERS_ONLY) public static class MultimapViewFn extends ViewFn>, Map>> { + private TypeDescriptorSupplier keyTypeDescriptorSupplier; + private TypeDescriptorSupplier valueTypeDescriptorSupplier; + + public MultimapViewFn( + TypeDescriptorSupplier keyTypeDescriptorSupplier, + TypeDescriptorSupplier valueTypeDescriptorSupplier) { + this.keyTypeDescriptorSupplier = keyTypeDescriptorSupplier; + this.valueTypeDescriptorSupplier = valueTypeDescriptorSupplier; + } + @Override public Materialization>> getMaterialization() { return Materializations.multimap(); @@ -298,6 +354,13 @@ public Map> apply(MultimapView> primitiveViewT) { Map> resultMap = (Map) multimap.asMap(); return Collections.unmodifiableMap(resultMap); } + + @Override + public TypeDescriptor>> getTypeDescriptor() { + return TypeDescriptors.maps( + keyTypeDescriptorSupplier.get(), + TypeDescriptors.iterables(valueTypeDescriptorSupplier.get())); + } } /** @@ -309,6 +372,15 @@ public Map> apply(MultimapView> primitiveViewT) { */ @Experimental(Kind.CORE_RUNNERS_ONLY) public static class MapViewFn extends ViewFn>, Map> { + private TypeDescriptorSupplier keyTypeDescriptorSupplier; + private TypeDescriptorSupplier valueTypeDescriptorSupplier; + + public MapViewFn( + TypeDescriptorSupplier keyTypeDescriptorSupplier, + TypeDescriptorSupplier valueTypeDescriptorSupplier) { + this.keyTypeDescriptorSupplier = keyTypeDescriptorSupplier; + this.valueTypeDescriptorSupplier = valueTypeDescriptorSupplier; + } @Override public Materialization>> getMaterialization() { @@ -328,6 +400,12 @@ public Map apply(MultimapView> primitiveViewT) { } return Collections.unmodifiableMap(map); } + + @Override + public TypeDescriptor> getTypeDescriptor() { + return TypeDescriptors.maps( + keyTypeDescriptorSupplier.get(), valueTypeDescriptorSupplier.get()); + } } /** diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java index 04587e74960ab..b6df00fd03444 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/transforms/ParDoTest.java @@ -737,6 +737,154 @@ public void testParDoWithSideInputs() { pipeline.run(); } + @Test + @Category({NeedsRunner.class, UsesSideInputs.class}) + public void testSideInputAnnotationFailedValidationMissing() { + // SideInput tag id + final String sideInputTag1 = "tag1"; + + DoFn> fn = + new DoFn>() { + @ProcessElement + public void processElement(@SideInput(sideInputTag1) String tag1) {} + }; + + thrown.expect(IllegalArgumentException.class); + PCollection> output = + pipeline.apply("Create main input", Create.of(2)).apply(ParDo.of(fn)); + pipeline.run(); + } + + @Test + @Category({NeedsRunner.class, UsesSideInputs.class}) + public void testSideInputAnnotationFailedValidationSingletonType() { + + final PCollectionView sideInput1 = + pipeline + .apply("CreateSideInput1", Create.of(2)) + .apply("ViewSideInput1", View.asSingleton()); + + // SideInput tag id + final String sideInputTag1 = "tag1"; + + DoFn> fn = + new DoFn>() { + @ProcessElement + public void processElement(@SideInput(sideInputTag1) String tag1) {} + }; + + thrown.expect(IllegalArgumentException.class); + PCollection> output = + pipeline + .apply("Create main input", Create.of(2)) + .apply(ParDo.of(fn).withSideInput(sideInputTag1, sideInput1)); + pipeline.run(); + } + + @Test + @Category({NeedsRunner.class, UsesSideInputs.class}) + public void testSideInputAnnotationFailedValidationListType() { + + final PCollectionView> sideInput1 = + pipeline + .apply("CreateSideInput1", Create.of(2, 1, 0)) + .apply("ViewSideInput1", View.asList()); + + // SideInput tag id + final String sideInputTag1 = "tag1"; + + DoFn> fn = + new DoFn>() { + @ProcessElement + public void processElement(@SideInput(sideInputTag1) List tag1) {} + }; + + thrown.expect(IllegalArgumentException.class); + PCollection> output = + pipeline + .apply("Create main input", Create.of(2)) + .apply(ParDo.of(fn).withSideInput(sideInputTag1, sideInput1)); + pipeline.run(); + } + + @Test + @Category({NeedsRunner.class, UsesSideInputs.class}) + public void testSideInputAnnotationFailedValidationIterableType() { + + final PCollectionView> sideInput1 = + pipeline + .apply("CreateSideInput1", Create.of(2, 1, 0)) + .apply("ViewSideInput1", View.asIterable()); + + // SideInput tag id + final String sideInputTag1 = "tag1"; + + DoFn> fn = + new DoFn>() { + @ProcessElement + public void processElement(@SideInput(sideInputTag1) List tag1) {} + }; + + thrown.expect(IllegalArgumentException.class); + PCollection> output = + pipeline + .apply("Create main input", Create.of(2)) + .apply(ParDo.of(fn).withSideInput(sideInputTag1, sideInput1)); + pipeline.run(); + } + + @Test + @Category({NeedsRunner.class, UsesSideInputs.class}) + public void testSideInputAnnotationFailedValidationMapType() { + + final PCollectionView> sideInput1 = + pipeline + .apply("CreateSideInput1", Create.of(KV.of(1, 2), KV.of(2, 3), KV.of(3, 4))) + .apply("ViewSideInput1", View.asMap()); + + // SideInput tag id + final String sideInputTag1 = "tag1"; + + DoFn> fn = + new DoFn>() { + @ProcessElement + public void processElement(@SideInput(sideInputTag1) Map tag1) {} + }; + + thrown.expect(IllegalArgumentException.class); + PCollection> output = + pipeline + .apply("Create main input", Create.of(2)) + .apply(ParDo.of(fn).withSideInput(sideInputTag1, sideInput1)); + pipeline.run(); + } + + @Test + @Category({NeedsRunner.class, UsesSideInputs.class}) + public void testSideInputAnnotationFailedValidationMultiMapType() { + + final PCollectionView>> sideInput1 = + pipeline + .apply("CreateSideInput1", Create.of(KV.of(1, 2), KV.of(1, 3), KV.of(3, 4))) + .apply("ViewSideInput1", View.asMultimap()); + + // SideInput tag id + final String sideInputTag1 = "tag1"; + + DoFn> fn = + new DoFn>() { + @ProcessElement + public void processElement(@SideInput(sideInputTag1) Map tag1) {} + }; + + thrown.expect(IllegalArgumentException.class); + PCollection> output = + pipeline + .apply("Create main input", Create.of(2)) + .apply(ParDo.of(fn).withSideInput(sideInputTag1, sideInput1)); + pipeline.run(); + } + @Test @Category({ValidatesRunner.class, UsesSideInputs.class}) public void testSideInputAnnotation() { @@ -774,53 +922,98 @@ public void processElement( @Category({ValidatesRunner.class, UsesSideInputs.class}) public void testSideInputAnnotationWithMultipleSideInputs() { + final List side1Data = ImmutableList.of(2, 0); final PCollectionView> sideInput1 = pipeline - .apply("CreateSideInput1", Create.of(2, 0)) + .apply("CreateSideInput1", Create.of(side1Data)) .apply("ViewSideInput1", View.asList()); + final Integer side2Data = 5; final PCollectionView sideInput2 = pipeline - .apply("CreateSideInput2", Create.of(5)) + .apply("CreateSideInput2", Create.of(side2Data)) .apply("ViewSideInput2", View.asSingleton()); - final PCollectionView> sideInput3 = + final List side3Data = ImmutableList.of(1, 3); + final PCollectionView> sideInput3 = + pipeline + .apply("CreateSideInput3", Create.of(side3Data)) + .apply("ViewSideInput3", View.asIterable()); + + final List> side4Data = + ImmutableList.of(KV.of(1, 2), KV.of(2, 3), KV.of(3, 4)); + final PCollectionView> sideInput4 = + pipeline + .apply("CreateSideInput4", Create.of(side4Data)) + .apply("ViewSideInput4", View.asMap()); + + final List> side5Data = + ImmutableList.of(KV.of(1, 2), KV.of(1, 3), KV.of(3, 4)); + final PCollectionView>> sideInput5 = pipeline - .apply("CreateSideInput3", Create.of(1, 3)) - .apply("ViewSideInput3", View.asList()); + .apply("CreateSideInput5", Create.of(side5Data)) + .apply("ViewSideInput5", View.asMultimap()); // SideInput tag id final String sideInputTag1 = "tag1"; final String sideInputTag2 = "tag2"; final String sideInputTag3 = "tag3"; + final String sideInputTag4 = "tag4"; + final String sideInputTag5 = "tag5"; - DoFn> fn = - new DoFn>() { + final TupleTag outputTag1 = new TupleTag<>(); + final TupleTag outputTag2 = new TupleTag<>(); + final TupleTag outputTag3 = new TupleTag<>(); + final TupleTag> outputTag4 = new TupleTag<>(); + final TupleTag> outputTag5 = new TupleTag<>(); + + DoFn fn = + new DoFn() { @ProcessElement public void processElement( - OutputReceiver> r, - @SideInput(sideInputTag1) List tag1, - @SideInput(sideInputTag2) Integer tag2, - @SideInput(sideInputTag3) List tag3) { - - List sideSorted = Lists.newArrayList(tag1); - sideSorted.add(tag2); - sideSorted.addAll(tag3); - Collections.sort(sideSorted); - r.output(sideSorted); + MultiOutputReceiver r, + @SideInput(sideInputTag1) List side1, + @SideInput(sideInputTag2) Integer side2, + @SideInput(sideInputTag3) Iterable side3, + @SideInput(sideInputTag4) Map side4, + @SideInput(sideInputTag5) Map> side5) { + side1.forEach(i -> r.get(outputTag1).output(i)); + r.get(outputTag2).output(side2); + side3.forEach(i -> r.get(outputTag3).output(i)); + side4.forEach((k, v) -> r.get(outputTag4).output(KV.of(k, v))); + side5.forEach((k, v) -> v.forEach(v2 -> r.get(outputTag5).output(KV.of(k, v2)))); } }; - PCollection> output = + PCollectionTuple output = pipeline .apply("Create main input", Create.of(2)) .apply( ParDo.of(fn) .withSideInput(sideInputTag1, sideInput1) .withSideInput(sideInputTag2, sideInput2) - .withSideInput(sideInputTag3, sideInput3)); + .withSideInput(sideInputTag3, sideInput3) + .withSideInput(sideInputTag4, sideInput4) + .withSideInput(sideInputTag5, sideInput5) + .withOutputTags( + outputTag1, + TupleTagList.of(outputTag2) + .and(outputTag3) + .and(outputTag4) + .and(outputTag5))); + + output.get(outputTag1).setCoder(VarIntCoder.of()); + output.get(outputTag2).setCoder(VarIntCoder.of()); + output.get(outputTag3).setCoder(VarIntCoder.of()); + output.get(outputTag4).setCoder(KvCoder.of(VarIntCoder.of(), VarIntCoder.of())); + output.get(outputTag5).setCoder(KvCoder.of(VarIntCoder.of(), VarIntCoder.of())); + + PAssert.that(output.get(outputTag1)).containsInAnyOrder(side1Data); + PAssert.that(output.get(outputTag2)).containsInAnyOrder(side2Data); + PAssert.that(output.get(outputTag3)).containsInAnyOrder(side3Data); + PAssert.that(output.get(outputTag4)).containsInAnyOrder(side4Data); + PAssert.that(output.get(outputTag5)).containsInAnyOrder(side5Data); - PAssert.that(output).containsInAnyOrder(Lists.newArrayList(0, 1, 2, 3, 5)); pipeline.run(); }