diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
index 9f0aea11e6a7c..0a23961ffa79b 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/InProcessPipelineRunner.java
@@ -19,6 +19,7 @@
 import com.google.cloud.dataflow.sdk.Pipeline.PipelineExecutionException;
 import com.google.cloud.dataflow.sdk.PipelineResult;
 import com.google.cloud.dataflow.sdk.annotations.Experimental;
+import com.google.cloud.dataflow.sdk.io.Write;
 import com.google.cloud.dataflow.sdk.options.PipelineOptions;
 import com.google.cloud.dataflow.sdk.runners.AggregatorPipelineExtractor;
 import com.google.cloud.dataflow.sdk.runners.AggregatorRetrievalException;
@@ -50,9 +51,7 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
-
 import org.joda.time.Instant;
-
 import java.util.Collection;
 import java.util.HashMap;
 import java.util.Map;
@@ -78,6 +77,7 @@ public class InProcessPipelineRunner
           .put(Create.Values.class, new InProcessCreateOverrideFactory())
           .put(GroupByKey.class, new InProcessGroupByKeyOverrideFactory())
           .put(CreatePCollectionView.class, new InProcessViewOverrideFactory())
+          .put(Write.Bound.class, new WriteWithShardingFactory())
           .build();
 
   /**
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WriteWithShardingFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WriteWithShardingFactory.java
new file mode 100644
index 0000000000000..6de64603f826c
--- /dev/null
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/runners/inprocess/WriteWithShardingFactory.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static com.google.common.base.Preconditions.checkArgument;
+
+import com.google.cloud.dataflow.sdk.io.Write;
+import com.google.cloud.dataflow.sdk.transforms.Count;
+import com.google.cloud.dataflow.sdk.transforms.DoFn;
+import com.google.cloud.dataflow.sdk.transforms.Flatten;
+import com.google.cloud.dataflow.sdk.transforms.GroupByKey;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.transforms.ParDo;
+import com.google.cloud.dataflow.sdk.transforms.Values;
+import com.google.cloud.dataflow.sdk.transforms.windowing.DefaultTrigger;
+import com.google.cloud.dataflow.sdk.transforms.windowing.GlobalWindows;
+import com.google.cloud.dataflow.sdk.transforms.windowing.Window;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.PDone;
+import com.google.cloud.dataflow.sdk.values.PInput;
+import com.google.cloud.dataflow.sdk.values.POutput;
+import com.google.common.annotations.VisibleForTesting;
+import org.joda.time.Duration;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * A {@link PTransformOverrideFactory} that overrides {@link Write} {@link PTransform PTransforms}
+ * with an unspecified number of shards with a write with a specified number of shards. The number
+ * of shards is the log base 10 of the number of input records, with up to 2 additional shards.
+ */
+class WriteWithShardingFactory implements PTransformOverrideFactory {
+  static final int MAX_RANDOM_EXTRA_SHARDS = 3;
+
+  @Override
+  public <InputT extends PInput, OutputT extends POutput> PTransform<InputT, OutputT> override(
+      PTransform<InputT, OutputT> transform) {
+    if (transform instanceof Write.Bound) {
+      Write.Bound<InputT> that = (Write.Bound<InputT>) transform;
+      if (that.getNumShards() == 0) {
+        return (PTransform<InputT, OutputT>) new DynamicallyReshardedWrite<InputT>(that);
+      }
+    }
+    return transform;
+  }
+
+  private static class DynamicallyReshardedWrite<T> extends PTransform<PCollection<T>, PDone> {
+    private final transient Write.Bound<T> original;
+
+    private DynamicallyReshardedWrite(Write.Bound<T> original) {
+      this.original = original;
+    }
+
+    @Override
+    public PDone apply(PCollection<T> input) {
+      PCollection<T> records = input.apply("RewindowInputs",
+          Window.<T>into(new GlobalWindows()).triggering(DefaultTrigger.of())
+              .withAllowedLateness(Duration.ZERO)
+              .discardingFiredPanes());
+      final PCollectionView<Long> numRecords =
+          records.apply(Count.<T>globally().asSingletonView());
+      PCollection<T> resharded =
+          records
+              .apply(
+                  "ApplySharding",
+                  ParDo.withSideInputs(numRecords)
+                      .of(
+                          new KeyBasedOnCountFn<T>(
+                              numRecords,
+                              ThreadLocalRandom.current().nextInt(MAX_RANDOM_EXTRA_SHARDS))))
+              .apply("GroupIntoShards", GroupByKey.<Integer, T>create())
+              .apply("DropShardingKeys", Values.<Iterable<T>>create())
+              .apply("FlattenShardIterables", Flatten.<T>iterables());
+      // This is an inverted application to apply the expansion of the original Write PTransform
+      // without adding a new Write Transform Node, which would be overwritten the same way,
+      // leading to an infinite recursion. We cannot modify the number of shards here, because
+      // that is determined at runtime.
+      return original.apply(resharded);
+    }
+  }
+
+  @VisibleForTesting
+  static class KeyBasedOnCountFn<T> extends DoFn<T, KV<Integer, T>> {
+    @VisibleForTesting
+    static final int MIN_SHARDS_FOR_LOG = 3;
+
+    private final PCollectionView<Long> numRecords;
+    private final int randomExtraShards;
+    private int currentShard;
+    private int maxShards;
+
+    KeyBasedOnCountFn(PCollectionView<Long> numRecords, int extraShards) {
+      this.numRecords = numRecords;
+      this.randomExtraShards = extraShards;
+    }
+
+    @Override
+    public void processElement(ProcessContext c) throws Exception {
+      if (maxShards == 0L) {
+        maxShards = calculateShards(c.sideInput(numRecords));
+        currentShard = ThreadLocalRandom.current().nextInt(maxShards);
+      }
+      int shard = currentShard;
+      currentShard = (currentShard + 1) % maxShards;
+      c.output(KV.of(shard, c.element()));
+    }
+
+    private int calculateShards(long totalRecords) {
+      checkArgument(
+          totalRecords > 0,
+          "KeyBasedOnCountFn cannot be invoked on an element if there are no elements");
+      if (totalRecords < MIN_SHARDS_FOR_LOG + randomExtraShards) {
+        return (int) totalRecords;
+      }
+      // 100mil records before >7 output files
+      int floorLogRecs = Double.valueOf(Math.log10(totalRecords)).intValue();
+      int shards = Math.max(floorLogRecs, MIN_SHARDS_FOR_LOG) + randomExtraShards;
+      return shards;
+    }
+  }
+}
diff --git a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java
index 77d0b830cb924..443c5e240f106 100644
--- a/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java
+++ b/sdk/src/main/java/com/google/cloud/dataflow/sdk/util/FileIOChannelFactory.java
@@ -107,7 +107,8 @@ public WritableByteChannel create(String spec, String mimeType)
     File file = new File(spec);
     if (file.getAbsoluteFile().getParentFile() != null
         && !file.getAbsoluteFile().getParentFile().exists()
-        && !file.getAbsoluteFile().getParentFile().mkdirs()) {
+        && !file.getAbsoluteFile().getParentFile().mkdirs()
+        && !file.getAbsoluteFile().getParentFile().exists()) {
       throw new IOException("Unable to create parent directories for '" + spec + "'");
     }
     return Channels.newChannel(
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
index 53295b0e56c2e..b3a43f67e5303 100644
--- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/TextIOTest.java
@@ -269,11 +269,19 @@ public static void assertOutputFiles(
       String shardNameTemplate)
       throws Exception {
     List<File> expectedFiles = new ArrayList<>();
-    for (int i = 0; i < numShards; i++) {
-      expectedFiles.add(
-          new File(
-              rootLocation.getRoot(),
-              IOChannelUtils.constructName(outputName, shardNameTemplate, "", i, numShards)));
+    if (numShards == 0) {
+      String pattern =
+          IOChannelUtils.resolve(rootLocation.getRoot().getAbsolutePath(), outputName + "*");
+      for (String expected : IOChannelUtils.getFactory(pattern).match(pattern)) {
+        expectedFiles.add(new File(expected));
+      }
+    } else {
+      for (int i = 0; i < numShards; i++) {
+        expectedFiles.add(
+            new File(
+                rootLocation.getRoot(),
+                IOChannelUtils.constructName(outputName, shardNameTemplate, "", i, numShards)));
+      }
     }
 
     List<String> actual = new ArrayList<>();
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/WriteTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/WriteTest.java
index 58025aa10038c..9023ad0387954 100644
--- a/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/WriteTest.java
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/io/WriteTest.java
@@ -19,6 +19,7 @@
 import static org.hamcrest.Matchers.anyOf;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
 import static org.hamcrest.Matchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertThat;
@@ -137,7 +138,7 @@ public void testWrite() {
   public void testEmptyWrite() {
     runWrite(Collections.<String>emptyList(), IDENTITY_MAP);
     // Note we did not request a sharded write, so runWrite will not validate the number of shards.
-    assertEquals(1, numShards.intValue());
+    assertThat(numShards.intValue(), greaterThan(0));
   }
 
   /**
diff --git a/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WriteWithShardingFactoryTest.java b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WriteWithShardingFactoryTest.java
new file mode 100644
index 0000000000000..2e84ca31d3d54
--- /dev/null
+++ b/sdk/src/test/java/com/google/cloud/dataflow/sdk/runners/inprocess/WriteWithShardingFactoryTest.java
@@ -0,0 +1,288 @@
+/*
+ * Copyright (C) 2016 Google Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package com.google.cloud.dataflow.sdk.runners.inprocess;
+
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.anyOf;
+import static org.hamcrest.Matchers.containsInAnyOrder;
+import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.greaterThan;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.lessThan;
+import static org.hamcrest.Matchers.not;
+import static org.junit.Assert.assertThat;
+
+import com.google.cloud.dataflow.sdk.Pipeline;
+import com.google.cloud.dataflow.sdk.coders.VarLongCoder;
+import com.google.cloud.dataflow.sdk.io.Sink;
+import com.google.cloud.dataflow.sdk.io.TextIO;
+import com.google.cloud.dataflow.sdk.io.Write;
+import com.google.cloud.dataflow.sdk.options.PipelineOptions;
+import com.google.cloud.dataflow.sdk.runners.inprocess.WriteWithShardingFactory.KeyBasedOnCountFn;
+import com.google.cloud.dataflow.sdk.testing.TestPipeline;
+import com.google.cloud.dataflow.sdk.transforms.Create;
+import com.google.cloud.dataflow.sdk.transforms.DoFnTester;
+import com.google.cloud.dataflow.sdk.transforms.PTransform;
+import com.google.cloud.dataflow.sdk.util.IOChannelUtils;
+import com.google.cloud.dataflow.sdk.util.PCollectionViews;
+import com.google.cloud.dataflow.sdk.util.WindowedValue;
+import com.google.cloud.dataflow.sdk.util.WindowingStrategy;
+import com.google.cloud.dataflow.sdk.values.KV;
+import com.google.cloud.dataflow.sdk.values.PCollection;
+import com.google.cloud.dataflow.sdk.values.PCollectionView;
+import com.google.cloud.dataflow.sdk.values.PDone;
+import com.google.common.base.Function;
+import com.google.common.collect.Iterables;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import java.io.File;
+import java.io.FileReader;
+import java.io.Reader;
+import java.nio.CharBuffer;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.UUID;
+import java.util.concurrent.ThreadLocalRandom;
+
+/**
+ * Tests for {@link WriteWithShardingFactory}.
+ */
+public class WriteWithShardingFactoryTest {
+  public static final int INPUT_SIZE = 10000;
+  @Rule public TemporaryFolder tmp = new TemporaryFolder();
+  private WriteWithShardingFactory factory = new WriteWithShardingFactory();
+
+  @Test
+  public void dynamicallyReshardedWrite() throws Exception {
+    List<String> strs = new ArrayList<>(INPUT_SIZE);
+    for (int i = 0; i < INPUT_SIZE; i++) {
+      strs.add(UUID.randomUUID().toString());
+    }
+    Collections.shuffle(strs);
+
+    String fileName = "resharded_write";
+    String outputPath = tmp.getRoot().getAbsolutePath();
+    String targetLocation = IOChannelUtils.resolve(outputPath, fileName);
+    PipelineOptions testOptions = TestPipeline.testingPipelineOptions();
+    testOptions.setRunner(InProcessPipelineRunner.class);
+    Pipeline p = Pipeline.create(testOptions);
+    // TextIO is implemented in terms of the Write PTransform. When sharding is not specified,
+    // resharding should be automatically applied.
+    p.apply(Create.of(strs)).apply(TextIO.Write.to(targetLocation));
+
+    p.run();
+
+    Collection<String> files = IOChannelUtils.getFactory(outputPath).match(targetLocation + "*");
+    List<String> actuals = new ArrayList<>(strs.size());
+    for (String file : files) {
+      CharBuffer buf = CharBuffer.allocate((int) new File(file).length());
+      try (Reader reader = new FileReader(file)) {
+        reader.read(buf);
+        buf.flip();
+      }
+
+      String[] readStrs = buf.toString().split("\n");
+      for (String read : readStrs) {
+        if (read.length() > 0) {
+          actuals.add(read);
+        }
+      }
+    }
+
+    assertThat(actuals, containsInAnyOrder(strs.toArray()));
+    assertThat(
+        files,
+        hasSize(
+            allOf(
+                greaterThan(1),
+                lessThan(
+                    (int)
+                        (Math.log10(INPUT_SIZE)
+                            + WriteWithShardingFactory.MAX_RANDOM_EXTRA_SHARDS)))));
+  }
+
+  @Test
+  public void withShardingSpecifiesOriginalTransform() {
+    PTransform<PCollection<Object>, PDone> original = Write.to(new TestSink()).withNumShards(3);
+
+    assertThat(factory.override(original), equalTo(original));
+  }
+
+  @Test
+  public void withNonWriteReturnsOriginalTransform() {
+    PTransform<PCollection<Object>, PDone> original =
+        new PTransform<PCollection<Object>, PDone>() {
+          @Override
+          public PDone apply(PCollection<Object> input) {
+            return PDone.in(input.getPipeline());
+          }
+        };
+
+    assertThat(factory.override(original), equalTo(original));
+  }
+
+  @Test
+  public void withNoShardingSpecifiedReturnsNewTransform() {
+    PTransform<PCollection<Object>, PDone> original = Write.to(new TestSink());
+    assertThat(factory.override(original), not(equalTo(original)));
+  }
+
+  @Test
+  public void keyBasedOnCountFnWithOneElement() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    fnTester.setSideInput(elementCountView,
+        Collections.<WindowedValue<?>>singleton(WindowedValue.valueInGlobalWindow(1L)));
+
+    List<KV<Integer, String>> outputs = fnTester.processBatch("foo", "bar", "bazbar");
+    assertThat(
+        outputs, containsInAnyOrder(KV.of(0, "foo"), KV.of(0, "bar"), KV.of(0, "bazbar")));
+  }
+
+  @Test
+  public void keyBasedOnCountFnWithTwoElements() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    fnTester.setSideInput(elementCountView,
+        Collections.<WindowedValue<?>>singleton(WindowedValue.valueInGlobalWindow(2L)));
+
+    List<KV<Integer, String>> outputs = fnTester.processBatch("foo", "bar");
+    assertThat(
+        outputs,
+        anyOf(
+            containsInAnyOrder(KV.of(0, "foo"), KV.of(1, "bar")),
+            containsInAnyOrder(KV.of(1, "foo"), KV.of(0, "bar"))));
+  }
+
+  @Test
+  public void keyBasedOnCountFnFewElementsThreeShards() throws Exception {
+    PCollectionView<Long> elementCountView =
+        PCollectionViews.singletonView(
+            TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of());
+    KeyBasedOnCountFn<String> fn = new KeyBasedOnCountFn<>(elementCountView, 0);
+    DoFnTester<String, KV<Integer, String>> fnTester = DoFnTester.of(fn);
+
+    fnTester.setSideInput(elementCountView,
+        Collections.<WindowedValue<?>>singleton(WindowedValue.valueInGlobalWindow(100L)));
+
+    List<KV<Integer, String>> outputs =
+        fnTester.processBatch("foo", "bar", "baz", "foobar", "foobaz", "barbaz");
+    assertThat(
+        Iterables.transform(
+            outputs,
+            new Function<KV<Integer, String>, Integer>() {
+              @Override
+              public Integer apply(KV<Integer, String> input) {
input) { + return input.getKey(); + } + }), + containsInAnyOrder(0, 0, 1, 1, 2, 2)); + } + + @Test + public void keyBasedOnCountFnManyElements() throws Exception { + PCollectionView elementCountView = + PCollectionViews.singletonView( + TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); + KeyBasedOnCountFn fn = new KeyBasedOnCountFn<>(elementCountView, 0); + DoFnTester> fnTester = DoFnTester.of(fn); + + double count = Math.pow(10, 10); + fnTester.setSideInput(elementCountView, + Collections.>singleton(WindowedValue.valueInGlobalWindow((long) count))); + + List strings = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong())); + } + List> kvs = fnTester.processBatch(strings); + long maxKey = -1L; + for (KV kv : kvs) { + maxKey = Math.max(maxKey, kv.getKey()); + } + assertThat(maxKey, equalTo(9L)); + } + + @Test + public void keyBasedOnCountFnFewElementsExtraShards() throws Exception { + PCollectionView elementCountView = + PCollectionViews.singletonView( + TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); + KeyBasedOnCountFn fn = new KeyBasedOnCountFn<>(elementCountView, 10); + DoFnTester> fnTester = DoFnTester.of(fn); + + long countValue = (long) KeyBasedOnCountFn.MIN_SHARDS_FOR_LOG + 3; + fnTester.setSideInput(elementCountView, + Collections.>singleton(WindowedValue.valueInGlobalWindow(countValue))); + + List strings = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong())); + } + List> kvs = fnTester.processBatch(strings); + long maxKey = -1L; + for (KV kv : kvs) { + maxKey = Math.max(maxKey, kv.getKey()); + } + // 0 to n-1 shard ids. + assertThat(maxKey, equalTo(countValue - 1)); + } + + @Test + public void keyBasedOnCountFnManyElementsExtraShards() throws Exception { + PCollectionView elementCountView = + PCollectionViews.singletonView( + TestPipeline.create(), WindowingStrategy.globalDefault(), true, 0L, VarLongCoder.of()); + KeyBasedOnCountFn fn = new KeyBasedOnCountFn<>(elementCountView, 3); + DoFnTester> fnTester = DoFnTester.of(fn); + + double count = Math.pow(10, 10); + fnTester.setSideInput(elementCountView, + Collections.>singleton(WindowedValue.valueInGlobalWindow((long) count))); + + List strings = new ArrayList<>(); + for (int i = 0; i < 100; i++) { + strings.add(Long.toHexString(ThreadLocalRandom.current().nextLong())); + } + List> kvs = fnTester.processBatch(strings); + long maxKey = -1L; + for (KV kv : kvs) { + maxKey = Math.max(maxKey, kv.getKey()); + } + assertThat(maxKey, equalTo(12L)); + } + + private static class TestSink extends Sink { + @Override + public void validate(PipelineOptions options) {} + + @Override + public WriteOperation createWriteOperation(PipelineOptions options) { + throw new IllegalArgumentException("Should not be used"); + } + } +}