Skip to content

Commit

Permalink
Expose GroupIntoBatches parameters for WriteFiles and TextIO (#31617)
Browse files Browse the repository at this point in the history
  • Loading branch information
Amar3tto committed Jun 25, 2024
1 parent 970a3f5 commit 1037ede
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 8 deletions.
72 changes: 72 additions & 0 deletions sdks/java/core/src/main/java/org/apache/beam/sdk/io/TextIO.java
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,15 @@ public abstract static class TypedWrite<UserT, DestinationT>
/** A function that converts UserT to a String, for writing to the file. */
abstract @Nullable SerializableFunction<UserT, String> getFormatFunction();

/** Batch size for input records. */
abstract @Nullable Integer getBatchSize();

/** Batch size in bytes for input records. */
abstract @Nullable Integer getBatchSizeBytes();

/** Batch max buffering duration for input records. */
abstract @Nullable Duration getBatchMaxBufferingDuration();

/** Whether to write windowed output files. */
abstract boolean getWindowedWrites();

Expand Down Expand Up @@ -757,6 +766,13 @@ abstract Builder<UserT, DestinationT> setFormatFunction(
abstract Builder<UserT, DestinationT> setNumShards(
@Nullable ValueProvider<Integer> numShards);

abstract Builder<UserT, DestinationT> setBatchSize(@Nullable Integer batchSize);

abstract Builder<UserT, DestinationT> setBatchSizeBytes(@Nullable Integer batchSizeBytes);

abstract Builder<UserT, DestinationT> setBatchMaxBufferingDuration(
@Nullable Duration batchMaxBufferingDuration);

abstract Builder<UserT, DestinationT> setWindowedWrites(boolean windowedWrites);

abstract Builder<UserT, DestinationT> setAutoSharding(boolean windowedWrites);
Expand Down Expand Up @@ -868,6 +884,38 @@ public TypedWrite<UserT, DestinationT> withFormatFunction(
return toBuilder().setFormatFunction(formatFunction).build();
}

/**
* Returns a new {@link TypedWrite} that will batch the input records using specified batch
* size. The default value is {@link WriteFiles#FILE_TRIGGERING_RECORD_COUNT}.
*
* <p>This option is used only for writing unbounded data with auto-sharding.
*/
public TypedWrite<UserT, DestinationT> withBatchSize(@Nullable Integer batchSize) {
return toBuilder().setBatchSize(batchSize).build();
}

/**
* Returns a new {@link TypedWrite} that will batch the input records using specified batch size
* in bytes. The default value is {@link WriteFiles#FILE_TRIGGERING_BYTE_COUNT}.
*
* <p>This option is used only for writing unbounded data with auto-sharding.
*/
public TypedWrite<UserT, DestinationT> withBatchSizeBytes(@Nullable Integer batchSizeBytes) {
return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
}

/**
* Returns a new {@link TypedWrite} that will batch the input records using specified max
* buffering duration. The default value is {@link
* WriteFiles#FILE_TRIGGERING_RECORD_BUFFERING_DURATION}.
*
* <p>This option is used only for writing unbounded data with auto-sharding.
*/
public TypedWrite<UserT, DestinationT> withBatchMaxBufferingDuration(
@Nullable Duration batchMaxBufferingDuration) {
return toBuilder().setBatchMaxBufferingDuration(batchMaxBufferingDuration).build();
}

/** Set the base directory used to generate temporary files. */
public TypedWrite<UserT, DestinationT> withTempDirectory(
ValueProvider<ResourceId> tempDirectory) {
Expand Down Expand Up @@ -1119,6 +1167,15 @@ public WriteFilesResult<DestinationT> expand(PCollection<UserT> input) {
if (getSkipIfEmpty()) {
write = write.withSkipIfEmpty();
}
if (getBatchSize() != null) {
write = write.withBatchSize(getBatchSize());
}
if (getBatchSizeBytes() != null) {
write = write.withBatchSizeBytes(getBatchSizeBytes());
}
if (getBatchMaxBufferingDuration() != null) {
write = write.withBatchMaxBufferingDuration(getBatchMaxBufferingDuration());
}
return input.apply("WriteFiles", write);
}

Expand Down Expand Up @@ -1291,6 +1348,21 @@ public Write withNoSpilling() {
return new Write(inner.withNoSpilling());
}

/** See {@link TypedWrite#withBatchSize(Integer)}. */
public Write withBatchSize(@Nullable Integer batchSize) {
return new Write(inner.withBatchSize(batchSize));
}

/** See {@link TypedWrite#withBatchSizeBytes(Integer)}. */
public Write withBatchSizeBytes(@Nullable Integer batchSizeBytes) {
return new Write(inner.withBatchSizeBytes(batchSizeBytes));
}

/** See {@link TypedWrite#withBatchMaxBufferingDuration(Duration)}. */
public Write withBatchMaxBufferingDuration(@Nullable Duration batchMaxBufferingDuration) {
return new Write(inner.withBatchMaxBufferingDuration(batchMaxBufferingDuration));
}

/**
* Specify that output filenames are wanted.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@ public abstract class WriteFiles<UserT, DestinationT, OutputT>

// The record count and buffering duration to trigger flushing records to a tmp file. Mainly used
// for writing unbounded data to avoid generating too many small files.
private static final int FILE_TRIGGERING_RECORD_COUNT = 100000;
private static final int FILE_TRIGGERING_BYTE_COUNT = 64 * 1024 * 1024; // 64MiB as of now
private static final Duration FILE_TRIGGERING_RECORD_BUFFERING_DURATION =
public static final int FILE_TRIGGERING_RECORD_COUNT = 100000;
public static final int FILE_TRIGGERING_BYTE_COUNT = 64 * 1024 * 1024; // 64MiB as of now
public static final Duration FILE_TRIGGERING_RECORD_BUFFERING_DURATION =
Duration.standardSeconds(5);

static final int UNKNOWN_SHARDNUM = -1;
Expand Down Expand Up @@ -196,6 +196,12 @@ public static <UserT, DestinationT, OutputT> WriteFiles<UserT, DestinationT, Out

abstract boolean getSkipIfEmpty();

abstract @Nullable Integer getBatchSize();

abstract @Nullable Integer getBatchSizeBytes();

abstract @Nullable Duration getBatchMaxBufferingDuration();

abstract List<PCollectionView<?>> getSideInputs();

public abstract @Nullable ShardingFunction<UserT, DestinationT> getShardingFunction();
Expand Down Expand Up @@ -226,6 +232,14 @@ abstract Builder<UserT, DestinationT, OutputT> setMaxNumWritersPerBundle(

abstract Builder<UserT, DestinationT, OutputT> setSkipIfEmpty(boolean skipIfEmpty);

abstract Builder<UserT, DestinationT, OutputT> setBatchSize(@Nullable Integer batchSize);

abstract Builder<UserT, DestinationT, OutputT> setBatchSizeBytes(
@Nullable Integer batchSizeBytes);

abstract Builder<UserT, DestinationT, OutputT> setBatchMaxBufferingDuration(
@Nullable Duration batchMaxBufferingDuration);

abstract Builder<UserT, DestinationT, OutputT> setSideInputs(
List<PCollectionView<?>> sideInputs);

Expand Down Expand Up @@ -286,6 +300,38 @@ public WriteFiles<UserT, DestinationT, OutputT> withSkipIfEmpty(boolean skipIfEm
return toBuilder().setSkipIfEmpty(skipIfEmpty).build();
}

/**
* Returns a new {@link WriteFiles} that will batch the input records using specified batch size.
* The default value is {@link #FILE_TRIGGERING_RECORD_COUNT}.
*
* <p>This option is used only for writing unbounded data with auto-sharding.
*/
public WriteFiles<UserT, DestinationT, OutputT> withBatchSize(@Nullable Integer batchSize) {
return toBuilder().setBatchSize(batchSize).build();
}

/**
* Returns a new {@link WriteFiles} that will batch the input records using specified batch size
* in bytes. The default value is {@link #FILE_TRIGGERING_BYTE_COUNT}.
*
* <p>This option is used only for writing unbounded data with auto-sharding.
*/
public WriteFiles<UserT, DestinationT, OutputT> withBatchSizeBytes(
@Nullable Integer batchSizeBytes) {
return toBuilder().setBatchSizeBytes(batchSizeBytes).build();
}

/**
* Returns a new {@link WriteFiles} that will batch the input records using specified max
* buffering duration. The default value is {@link #FILE_TRIGGERING_RECORD_BUFFERING_DURATION}.
*
* <p>This option is used only for writing unbounded data with auto-sharding.
*/
public WriteFiles<UserT, DestinationT, OutputT> withBatchMaxBufferingDuration(
@Nullable Duration batchMaxBufferingDuration) {
return toBuilder().setBatchMaxBufferingDuration(batchMaxBufferingDuration).build();
}

public WriteFiles<UserT, DestinationT, OutputT> withSideInputs(
List<PCollectionView<?>> sideInputs) {
return toBuilder().setSideInputs(sideInputs).build();
Expand Down Expand Up @@ -445,7 +491,12 @@ public WriteFilesResult<DestinationT> expand(PCollection<UserT> input) {
tempFileResults =
input.apply(
"WriteAutoShardedBundlesToTempFiles",
new WriteAutoShardedBundlesToTempFiles(destinationCoder, fileResultCoder));
new WriteAutoShardedBundlesToTempFiles(
destinationCoder,
fileResultCoder,
getBatchSize(),
getBatchSizeBytes(),
getBatchMaxBufferingDuration()));
}
}

Expand Down Expand Up @@ -887,11 +938,24 @@ private class WriteAutoShardedBundlesToTempFiles
extends PTransform<PCollection<UserT>, PCollection<List<FileResult<DestinationT>>>> {
private final Coder<DestinationT> destinationCoder;
private final Coder<FileResult<DestinationT>> fileResultCoder;
private final int batchSize;
private final int batchSizeBytes;
private final Duration maxBufferingDuration;

private WriteAutoShardedBundlesToTempFiles(
Coder<DestinationT> destinationCoder, Coder<FileResult<DestinationT>> fileResultCoder) {
Coder<DestinationT> destinationCoder,
Coder<FileResult<DestinationT>> fileResultCoder,
Integer batchSize,
Integer batchSizeBytes,
Duration maxBufferingDuration) {
this.destinationCoder = destinationCoder;
this.fileResultCoder = fileResultCoder;
this.batchSize = batchSize != null ? batchSize : FILE_TRIGGERING_RECORD_COUNT;
this.batchSizeBytes = batchSizeBytes != null ? batchSizeBytes : FILE_TRIGGERING_BYTE_COUNT;
this.maxBufferingDuration =
maxBufferingDuration != null
? maxBufferingDuration
: FILE_TRIGGERING_RECORD_BUFFERING_DURATION;
}

@Override
Expand Down Expand Up @@ -919,9 +983,9 @@ public PCollection<List<FileResult<DestinationT>>> expand(PCollection<UserT> inp
.setCoder(KvCoder.of(VarIntCoder.of(), input.getCoder()))
.apply(
"ShardAndBatch",
GroupIntoBatches.<Integer, UserT>ofSize(FILE_TRIGGERING_RECORD_COUNT)
.withByteSize(FILE_TRIGGERING_BYTE_COUNT)
.withMaxBufferingDuration(FILE_TRIGGERING_RECORD_BUFFERING_DURATION)
GroupIntoBatches.<Integer, UserT>ofSize(batchSize)
.withByteSize(batchSizeBytes)
.withMaxBufferingDuration(maxBufferingDuration)
.withShardedKey())
.setCoder(
KvCoder.of(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import org.apache.beam.sdk.testing.NeedsRunner;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.UsesUnboundedPCollections;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.Values;
Expand Down Expand Up @@ -97,6 +98,10 @@
public class TextIOWriteTest {
private static final String MY_HEADER = "myHeader";
private static final String MY_FOOTER = "myFooter";
private static final int CUSTOM_FILE_TRIGGERING_RECORD_COUNT = 50000;
private static final int CUSTOM_FILE_TRIGGERING_BYTE_COUNT = 32 * 1024 * 1024; // 32MiB
private static final Duration CUSTOM_FILE_TRIGGERING_RECORD_BUFFERING_DURATION =
Duration.standardSeconds(4);

@Rule public transient TemporaryFolder tempFolder = new TemporaryFolder();

Expand Down Expand Up @@ -698,6 +703,42 @@ public void testRuntimeOptionsNotCalledInApply() throws Exception {
p.apply(Create.of("")).apply(TextIO.write().to(options.getOutput()));
}

@Test
@Category({NeedsRunner.class, UsesUnboundedPCollections.class})
public void testWriteUnboundedWithCustomBatchParameters() throws Exception {
Coder<String> coder = StringUtf8Coder.of();
String outputName = "file.txt";
Path baseDir = Files.createTempDirectory(tempFolder.getRoot().toPath(), "testwrite");
ResourceId baseFilename =
FileBasedSink.convertToFileResourceIfPossible(baseDir.resolve(outputName).toString());

PCollection<String> input =
p.apply(Create.of(Arrays.asList(LINES2_ARRAY)).withCoder(coder))
.setIsBoundedInternal(PCollection.IsBounded.UNBOUNDED)
.apply(Window.into(FixedWindows.of(Duration.standardSeconds(10))));

TextIO.Write write =
TextIO.write()
.to(baseFilename)
.withWindowedWrites()
.withShardNameTemplate(DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE)
.withBatchSize(CUSTOM_FILE_TRIGGERING_RECORD_COUNT)
.withBatchSizeBytes(CUSTOM_FILE_TRIGGERING_BYTE_COUNT)
.withBatchMaxBufferingDuration(CUSTOM_FILE_TRIGGERING_RECORD_BUFFERING_DURATION);

input.apply(write);
p.run();

assertOutputFiles(
LINES2_ARRAY,
null,
null,
3,
baseFilename,
DefaultFilenamePolicy.DEFAULT_UNWINDOWED_SHARD_TEMPLATE,
false);
}

@Test
@Category(NeedsRunner.class)
public void testWindowedWritesWithOnceTrigger() throws Throwable {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,11 @@ public class WriteFilesTest {
@Rule public final TestPipeline p = TestPipeline.create();
@Rule public ExpectedException thrown = ExpectedException.none();

private static final int CUSTOM_FILE_TRIGGERING_RECORD_COUNT = 50000;
private static final int CUSTOM_FILE_TRIGGERING_BYTE_COUNT = 32 * 1024 * 1024; // 32MiB
private static final Duration CUSTOM_FILE_TRIGGERING_RECORD_BUFFERING_DURATION =
Duration.standardSeconds(4);

@SuppressWarnings("unchecked") // covariant cast
private static final PTransform<PCollection<String>, PCollection<String>> IDENTITY_MAP =
(PTransform)
Expand Down Expand Up @@ -344,6 +349,23 @@ public void testWithShardingUnbounded() throws IOException {
true);
}

@Test
@Category({NeedsRunner.class, UsesUnboundedPCollections.class})
public void testWriteUnboundedWithCustomBatchParameters() throws IOException {
runShardedWrite(
Arrays.asList("one", "two", "three", "four", "five", "six"),
Window.into(FixedWindows.of(Duration.standardSeconds(10))),
getBaseOutputFilename(),
WriteFiles.to(makeSimpleSink())
.withWindowedWrites()
.withAutoSharding()
.withBatchSize(CUSTOM_FILE_TRIGGERING_RECORD_COUNT)
.withBatchSizeBytes(CUSTOM_FILE_TRIGGERING_BYTE_COUNT)
.withBatchMaxBufferingDuration(CUSTOM_FILE_TRIGGERING_RECORD_BUFFERING_DURATION),
null,
true);
}

@Test
@Category({
NeedsRunner.class,
Expand Down

0 comments on commit 1037ede

Please sign in to comment.