Use plain Impulse rather than a full Read for 0, 1 element Creates. (apache#26332)

Most of the complexity here is in preserving the Read-based expansion for update compatibility.
robertwb authored and cushon committed May 24, 2024
1 parent 826272d commit b794479
Showing 7 changed files with 125 additions and 15 deletions.
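In essence, this commit changes Create.Values#expand to special-case empty and single-element Creates: they now expand to a plain Impulse followed by a small map step, rather than a full Read(CreateSource). Below is a minimal, self-contained sketch of the single-element path (hypothetical class and variable names, condensed from the Create.Values#expand hunk in file 6, with a concrete StringUtf8Coder standing in for the PCollection's coder):

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.values.PCollection;

public class ImpulseCreateSketch {
  public static void main(String[] args) throws CoderException {
    Pipeline p = Pipeline.create(PipelineOptionsFactory.create());

    // Encode the single element once, at graph-construction time.
    final byte[] encodedElement =
        CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "only-element");

    // Impulse emits a single empty byte[]; the map step ignores it and
    // decodes the captured bytes back into the element.
    PCollection<String> created =
        p.apply(Impulse.create())
            .apply(
                MapElements.via(
                    new SimpleFunction<byte[], String>() {
                      @Override
                      public String apply(byte[] ignored) {
                        try {
                          return CoderUtils.decodeFromByteArray(
                              StringUtf8Coder.of(), encodedElement);
                        } catch (CoderException e) {
                          throw new RuntimeException(e);
                        }
                      }
                    }))
            .setCoder(StringUtf8Coder.of());

    p.run().waitUntilFinish();
  }
}

The zero-element path has the same shape, except the map step is a FlatMapElements that returns Collections.emptyList().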
File 1 of 7:
@@ -52,7 +52,10 @@ public void testEmptyPipeline() {

@Test
public void testCompositePipeline() {
- p.apply(Create.timestamped(TimestampedValue.of(KV.of(1, 1), new Instant(1))))
+ p.apply(
+     Create.timestamped(
+         TimestampedValue.of(KV.of(1, 1), new Instant(1)),
+         TimestampedValue.of(KV.of(2, 2), new Instant(2))))
.apply(Window.into(FixedWindows.of(Duration.millis(10))))
.apply(Sum.integersPerKey());
assertEquals(
File 2 of 7:
@@ -129,6 +129,7 @@
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.Combine.GroupedValues;
+ import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.GroupIntoBatches;
import org.apache.beam.sdk.transforms.Impulse;
@@ -495,6 +496,23 @@ protected DataflowRunner(DataflowPipelineOptions options) {
this.ptransformViewsWithNonDeterministicKeyCoders = new HashSet<>();
}

+ private static class AlwaysCreateViaRead<T>
+     implements PTransformOverrideFactory<PBegin, PCollection<T>, Create.Values<T>> {
+   @Override
+   public PTransformOverrideFactory.PTransformReplacement<PBegin, PCollection<T>>
+       getReplacementTransform(
+           AppliedPTransform<PBegin, PCollection<T>, Create.Values<T>> appliedTransform) {
+     return PTransformOverrideFactory.PTransformReplacement.of(
+         appliedTransform.getPipeline().begin(), appliedTransform.getTransform().alwaysUseRead());
+   }
+
+   @Override
+   public final Map<PCollection<?>, ReplacementOutput> mapOutputs(
+       Map<TupleTag<?>, PCollection<?>> outputs, PCollection<T> newOutput) {
+     return ReplacementOutputs.singleton(outputs, newOutput);
+   }
+ }

private List<PTransformOverride> getOverrides(boolean streaming) {
ImmutableList.Builder<PTransformOverride> overridesBuilder = ImmutableList.builder();

@@ -509,6 +527,13 @@ private List<PTransformOverride> getOverrides(boolean streaming) {
PTransformOverride.of(
PTransformMatchers.emptyFlatten(), EmptyFlattenAsCreateFactory.instance()));

+ if (streaming) {
+   // For update compatibility, always use a Read for Create in streaming mode.
+   overridesBuilder.add(
+       PTransformOverride.of(
+           PTransformMatchers.classEqualTo(Create.Values.class), new AlwaysCreateViaRead()));
+ }

// By default Dataflow runner replaces single-output ParDo with a ParDoSingle override.
// However, we want a different expansion for single-output splittable ParDo.
overridesBuilder
@@ -661,6 +686,29 @@ private List<PTransformOverride> getOverrides(boolean streaming) {
PTransformOverride.of(
PTransformMatchers.classEqualTo(ParDo.SingleOutput.class),
new PrimitiveParDoSingleFactory()));

+ if (streaming) {
+   // For update compatibility, always use a Read for Create in streaming mode.
+   overridesBuilder
+       .add(
+           PTransformOverride.of(
+               PTransformMatchers.classEqualTo(Create.Values.class), new AlwaysCreateViaRead()))
+       // Create is implemented in terms of BoundedRead.
+       .add(
+           PTransformOverride.of(
+               PTransformMatchers.classEqualTo(Read.Bounded.class),
+               new StreamingBoundedReadOverrideFactory()))
+       // Streaming Bounded Read is implemented in terms of Streaming Unbounded Read.
+       .add(
+           PTransformOverride.of(
+               PTransformMatchers.classEqualTo(Read.Unbounded.class),
+               new StreamingUnboundedReadOverrideFactory()))
+       .add(
+           PTransformOverride.of(
+               PTransformMatchers.classEqualTo(ParDo.SingleOutput.class),
+               new PrimitiveParDoSingleFactory()));
+ }

return overridesBuilder.build();
}
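Note: the streaming-only overrides above exist because Dataflow's in-place pipeline update matches a replacement job's steps against the running job's graph, so changing Create's expansion from Read(CreateSource) to Impulse would break updates of already-running streaming pipelines; this is the update compatibility the commit description refers to. AlwaysCreateViaRead swaps each matched Create.Values for its Read-based form via the new alwaysUseRead() method, roughly (hypothetical sketch; normally the runner performs this graph replacement):

import org.apache.beam.sdk.transforms.Create;

public class AlwaysUseReadSketch {
  public static void main(String[] args) {
    // What the override effectively substitutes for each matched Create.Values
    // in streaming mode.
    Create.Values<Integer> original = Create.of(1);
    Create.Values<Integer> forStreaming = original.alwaysUseRead();
  }
}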

File 3 of 7:
@@ -723,7 +723,7 @@ public void testTaggedNamesOverridden() throws Exception {

PCollectionTuple outputs =
pipeline
- .apply(Create.of(3))
+ .apply(Create.of(3, 4))
.apply(
ParDo.of(
new DoFn<Integer, Integer>() {
@@ -775,7 +775,7 @@ public void testBatchStatefulParDoTranslation() throws Exception {
TupleTag<Integer> mainOutputTag = new TupleTag<Integer>() {};

pipeline
- .apply(Create.of(KV.of(1, 1)))
+ .apply(Create.of(KV.of(1, 1), KV.of(2, 3)))
.apply(
ParDo.of(
new DoFn<KV<Integer, Integer>, Integer>() {
@@ -917,7 +917,7 @@ public void processElement(ProcessContext c) {
// No need to actually check the pipeline as the ValidatesRunner tests
// ensure translation is correct. This is just a quick check to see that translation
// does not crash.
- assertEquals(24, steps.size());
+ assertEquals(25, steps.size());
}

/** Smoke test to fail fast if translation of a splittable ParDo in streaming breaks. */
@@ -1057,7 +1057,7 @@ public void testToSingletonTranslationWithIsmSideInput() throws Exception {
assertAllStepOutputsHaveUniqueIds(job);

List<Step> steps = job.getSteps();
- assertEquals(9, steps.size());
+ assertEquals(10, steps.size());

@SuppressWarnings("unchecked")
List<Map<String, Object>> toIsmRecordOutputs =
File 4 of 7:
@@ -85,7 +85,10 @@ public void testCompositePipeline() throws IOException {

Pipeline p = Pipeline.create(options);

- p.apply(Create.timestamped(TimestampedValue.of(KV.of(1, 1), new Instant(1))))
+ p.apply(
+     Create.timestamped(
+         TimestampedValue.of(KV.of(1, 1), new Instant(1)),
+         TimestampedValue.of(KV.of(2, 2), new Instant(2))))
.apply(Window.into(FixedWindows.of(Duration.millis(10))))
.apply(Sum.integersPerKey());

File 5 of 7:
@@ -83,7 +83,8 @@ public void debugBatchPipeline() {
.apply(TextIO.write().to("!!PLACEHOLDER-OUTPUT-DIR!!").withNumShards(3).withSuffix(".txt"));

final String expectedPipeline =
"sparkContext.<readFrom(org.apache.beam.sdk.transforms.Create$Values$CreateSource)>()\n"
"sparkContext.<impulse>()\n"
+ "_.mapPartitions(new org.apache.beam.sdk.transforms.FlatMapElements$3())\n"
+ "_.mapPartitions("
+ "new org.apache.beam.runners.spark.examples.WordCount$ExtractWordsFn())\n"
+ "_.mapPartitions(new org.apache.beam.sdk.transforms.Count$PerElement$1())\n"
File 6 of 7:
@@ -24,6 +24,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
+ import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
@@ -122,7 +123,7 @@ public class Create<T> {
* Otherwise, use {@link Create.Values#withCoder} to set the coder explicitly.
*/
public static <T> Values<T> of(Iterable<T> elems) {
- return new Values<>(elems, Optional.absent(), Optional.absent());
+ return new Values<>(elems, Optional.absent(), Optional.absent(), false);
}

/**
@@ -154,7 +155,7 @@ public static <T> Values<T> of(@Nullable T elem, @Nullable T... elems) {
*/
public static Values<Row> empty(Schema schema) {
return new Values<Row>(
-     new ArrayList<>(), Optional.of(SchemaCoder.of(schema)), Optional.absent());
+     new ArrayList<>(), Optional.of(SchemaCoder.of(schema)), Optional.absent(), false);
}

/**
@@ -167,7 +168,7 @@ public static Values<Row> empty(Schema schema) {
* the {@code Coder} is provided via the {@code coder} argument.
*/
public static <T> Values<T> empty(Coder<T> coder) {
- return new Values<>(new ArrayList<>(), Optional.of(coder), Optional.absent());
+ return new Values<>(new ArrayList<>(), Optional.of(coder), Optional.absent(), false);
}

/**
@@ -181,7 +182,7 @@ public static <T> Values<T> empty(Coder<T> coder) {
* must be registered for the class described in the {@code TypeDescriptor<T>}.
*/
public static <T> Values<T> empty(TypeDescriptor<T> type) {
- return new Values<>(new ArrayList<>(), Optional.absent(), Optional.of(type));
+ return new Values<>(new ArrayList<>(), Optional.absent(), Optional.of(type), false);
}

/**
@@ -284,7 +285,7 @@ public static class Values<T> extends PTransform<PBegin, PCollection<T>> {
* <p>Note that for {@link Create.Values} with no elements, the {@link VoidCoder} is used.
*/
public Values<T> withCoder(Coder<T> coder) {
- return new Values<>(elems, Optional.of(coder), typeDescriptor);
+ return new Values<>(elems, Optional.of(coder), typeDescriptor, alwaysUseRead);
}

/**
@@ -321,7 +322,11 @@ public Values<T> withRowSchema(Schema schema) {
* <p>Note that for {@link Create.Values} with no elements, the {@link VoidCoder} is used.
*/
public Values<T> withType(TypeDescriptor<T> type) {
- return new Values<>(elems, coder, Optional.of(type));
+ return new Values<>(elems, coder, Optional.of(type), alwaysUseRead);
}

+ public Values<T> alwaysUseRead() {
+   return new AlwaysUseRead<>(elems, coder, typeDescriptor);
+ }

public Iterable<T> getElements() {
@@ -362,6 +367,41 @@ public PCollection<T> expand(PBegin input) {
e);
}
try {
+ if (!alwaysUseRead) {
+   int numElements = Iterables.size(elems);
+   if (numElements == 0) {
+     return input
+         .apply(Impulse.create())
+         .apply(
+             FlatMapElements.via(
+                 new SimpleFunction<byte[], Iterable<T>>() {
+                   @Override
+                   public Iterable<T> apply(byte[] input) {
+                     return Collections.emptyList();
+                   }
+                 }))
+         .setCoder(coder);
+   } else if (numElements == 1) {
+     final byte[] encodedElement =
+         CoderUtils.encodeToByteArray(coder, Iterables.getOnlyElement(elems));
+     final Coder<T> capturedCoder = coder;
+     return input
+         .apply(Impulse.create())
+         .apply(
+             MapElements.via(
+                 new SimpleFunction<byte[], T>() {
+                   @Override
+                   public T apply(byte[] input) {
+                     try {
+                       return CoderUtils.decodeFromByteArray(capturedCoder, encodedElement);
+                     } catch (CoderException exn) {
+                       throw new RuntimeException(exn);
+                     }
+                   }
+                 }))
+         .setCoder(coder);
+   }
+ }
CreateSource<T> source = CreateSource.fromIterable(elems, coder);
return input.getPipeline().apply(Read.from(source));
} catch (IOException e) {
@@ -381,17 +421,24 @@ public PCollection<T> expand(PBegin input) {
/** The value type. */
private final transient Optional<TypeDescriptor<T>> typeDescriptor;

+ /** Whether to unconditionally implement this via reading a CreateSource. */
+ private final transient boolean alwaysUseRead;

/**
* Constructs a {@code Create.Values} transform that produces a {@link PCollection} containing
* the specified elements.
*
* <p>The arguments should not be modified after this is called.
*/
private Values(
-     Iterable<T> elems, Optional<Coder<T>> coder, Optional<TypeDescriptor<T>> typeDescriptor) {
+     Iterable<T> elems,
+     Optional<Coder<T>> coder,
+     Optional<TypeDescriptor<T>> typeDescriptor,
+     boolean alwaysUseRead) {
this.elems = elems;
this.coder = coder;
this.typeDescriptor = typeDescriptor;
+   this.alwaysUseRead = alwaysUseRead;
}

@VisibleForTesting
@@ -514,6 +561,14 @@ protected boolean advanceImpl() throws IOException {
return true;
}
}

+ /** A subclass to avoid getting re-matched. */
+ private static class AlwaysUseRead<T> extends Values<T> {
+   private AlwaysUseRead(
+       Iterable<T> elems, Optional<Coder<T>> coder, Optional<TypeDescriptor<T>> typeDescriptor) {
+     super(elems, coder, typeDescriptor, true);
+   }
+ }
}

/////////////////////////////////////////////////////////////////////////////
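Note: the new expansion hinges on the element's coder round-tripping through bytes, encoded once with CoderUtils.encodeToByteArray at graph-construction time and decoded with CoderUtils.decodeFromByteArray at execution time. A standalone sketch of that round trip (hypothetical class name; only the Beam Java core SDK is assumed):

import org.apache.beam.sdk.coders.CoderException;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.util.CoderUtils;

public class CoderRoundTripSketch {
  public static void main(String[] args) throws CoderException {
    // Encode once; the resulting byte[] can be captured by a serialized function.
    byte[] bytes = CoderUtils.encodeToByteArray(StringUtf8Coder.of(), "hello");
    // Decode later, possibly on a worker; yields an equal element.
    String roundTripped = CoderUtils.decodeFromByteArray(StringUtf8Coder.of(), bytes);
    System.out.println(roundTripped); // prints "hello"
  }
}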
File 7 of 7:
@@ -397,7 +397,7 @@ public void testPAssertEqualsSingletonFalseDefaultReasonString() throws Exception {

String message = thrown.getMessage();

assertThat(message, containsString("Create.Values/Read(CreateSource)"));
assertThat(message, containsString("Create.Values/"));
assertThat(message, containsString("Expected: <44>"));
assertThat(message, containsString("but: was <42>"));
}
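Note: the assertion is loosened here because the step name under Create.Values now depends on which expansion was chosen: Read(CreateSource) when alwaysUseRead applies, or an Impulse-rooted subgraph for empty and single-element Creates.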
