Skip to content

Commit

Permalink
Fix OutputSampler's coder. (apache#25805)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohdesamuel authored Mar 14, 2023
1 parent afce68d commit 04c2de6
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -195,8 +195,14 @@ public FnDataReceiver<WindowedValue<?>> getMultiplexingConsumer(String pCollecti
String coderId =
processBundleDescriptor.getPcollectionsOrThrow(pCollectionId).getCoderId();
Coder<?> coder;
OutputSampler<?> sampler = null;
try {
Coder<?> maybeWindowedValueInputCoder = rehydratedComponents.getCoder(coderId);

if (dataSampler != null) {
sampler = dataSampler.sampleOutput(pCollectionId, maybeWindowedValueInputCoder);
}

// TODO: Stop passing windowed value coders within PCollections.
if (maybeWindowedValueInputCoder instanceof WindowedValue.WindowedValueCoder) {
coder = ((WindowedValueCoder) maybeWindowedValueInputCoder).getValueCoder();
Expand All @@ -215,16 +221,16 @@ public FnDataReceiver<WindowedValue<?>> getMultiplexingConsumer(String pCollecti
ConsumerAndMetadata consumerAndMetadata = consumerAndMetadatas.get(0);
if (consumerAndMetadata.getConsumer() instanceof HandlesSplits) {
return new SplittingMetricTrackingFnDataReceiver(
pcId, coder, consumerAndMetadata, dataSampler);
pcId, coder, consumerAndMetadata, sampler);
}
return new MetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata, dataSampler);
return new MetricTrackingFnDataReceiver(pcId, coder, consumerAndMetadata, sampler);
} else {
/* TODO(SDF), Consider supporting splitting each consumer individually. This would never
come up in the existing SDF expansion, but might be useful to support fused SDF nodes.
This would require dedicated delivery of the split results to each of the consumers
separately. */
return new MultiplexingMetricTrackingFnDataReceiver(
pcId, coder, consumerAndMetadatas, dataSampler);
pcId, coder, consumerAndMetadatas, sampler);
}
});
}
Expand All @@ -248,7 +254,7 @@ public MetricTrackingFnDataReceiver(
String pCollectionId,
Coder<T> coder,
ConsumerAndMetadata consumerAndMetadata,
@Nullable DataSampler dataSampler) {
@Nullable OutputSampler<T> outputSampler) {
this.delegate = consumerAndMetadata.getConsumer();
this.executionState = consumerAndMetadata.getExecutionState();

Expand Down Expand Up @@ -284,11 +290,7 @@ public MetricTrackingFnDataReceiver(
bundleProgressReporterRegistrar.register(sampledByteSizeUnderlyingDistribution);

this.coder = coder;
if (dataSampler == null) {
this.outputSampler = null;
} else {
this.outputSampler = dataSampler.sampleOutput(pCollectionId, coder);
}
this.outputSampler = outputSampler;
}

@Override
Expand All @@ -300,7 +302,7 @@ public void accept(WindowedValue<T> input) throws Exception {
this.sampledByteSizeDistribution.tryUpdate(input.getValue(), this.coder);

if (outputSampler != null) {
outputSampler.sample(input.getValue());
outputSampler.sample(input);
}

// Use the ExecutionStateTracker and enter an appropriate state to track the
Expand Down Expand Up @@ -329,13 +331,13 @@ private class MultiplexingMetricTrackingFnDataReceiver<T>
private final BundleCounter elementCountCounter;
private final SampleByteSizeDistribution<T> sampledByteSizeDistribution;
private final Coder<T> coder;
private final @Nullable OutputSampler<T> outputSampler;
private @Nullable OutputSampler<T> outputSampler = null;

public MultiplexingMetricTrackingFnDataReceiver(
String pCollectionId,
Coder<T> coder,
List<ConsumerAndMetadata> consumerAndMetadatas,
@Nullable DataSampler dataSampler) {
@Nullable OutputSampler<T> outputSampler) {
this.consumerAndMetadatas = consumerAndMetadatas;

HashMap<String, String> labels = new HashMap<>();
Expand Down Expand Up @@ -370,11 +372,7 @@ public MultiplexingMetricTrackingFnDataReceiver(
bundleProgressReporterRegistrar.register(sampledByteSizeUnderlyingDistribution);

this.coder = coder;
if (dataSampler == null) {
this.outputSampler = null;
} else {
this.outputSampler = dataSampler.sampleOutput(pCollectionId, coder);
}
this.outputSampler = outputSampler;
}

@Override
Expand All @@ -386,7 +384,7 @@ public void accept(WindowedValue<T> input) throws Exception {
this.sampledByteSizeDistribution.tryUpdate(input.getValue(), coder);

if (outputSampler != null) {
outputSampler.sample(input.getValue());
outputSampler.sample(input);
}

// Use the ExecutionStateTracker and enter an appropriate state to track the
Expand Down Expand Up @@ -422,8 +420,8 @@ public SplittingMetricTrackingFnDataReceiver(
String pCollection,
Coder<T> coder,
ConsumerAndMetadata consumerAndMetadata,
@Nullable DataSampler dataSampler) {
super(pCollection, coder, consumerAndMetadata, dataSampler);
@Nullable OutputSampler<T> outputSampler) {
super(pCollection, coder, consumerAndMetadata, outputSampler);
this.delegate = (HandlesSplits) consumerAndMetadata.getConsumer();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.Nullable;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.util.ByteStringOutputStream;
import org.apache.beam.sdk.util.WindowedValue;

/**
* This class holds samples for a single PCollection until queried by the parent DataSampler. This
Expand All @@ -35,7 +37,7 @@
public class OutputSampler<T> {

// Temporarily holds elements until the SDK receives a sample data request.
private List<T> buffer;
private List<WindowedValue<T>> buffer;

// Maximum number of elements in buffer.
private final int maxElements;
Expand All @@ -49,13 +51,27 @@ public class OutputSampler<T> {
// Index into the buffer of where to overwrite samples.
private int resampleIndex = 0;

private final Coder<T> coder;
@Nullable private final Coder<T> valueCoder;

public OutputSampler(Coder<T> coder, int maxElements, int sampleEveryN) {
this.coder = coder;
@Nullable private final Coder<WindowedValue<T>> windowedValueCoder;

public OutputSampler(Coder<?> coder, int maxElements, int sampleEveryN) {
this.maxElements = maxElements;
this.sampleEveryN = sampleEveryN;
this.buffer = new ArrayList<>(this.maxElements);

// The samples taken and encoded should match exactly to the specification from the
// ProcessBundleDescriptor. The coder given can either be a WindowedValueCoder, in which the
// element itself is sampled. Or, it's non a WindowedValueCoder and the value inside the
// windowed value must be sampled. This is because WindowedValue is the element type used in
// all receivers, which doesn't necessarily match the PBD encoding.
if (coder instanceof WindowedValue.WindowedValueCoder) {
this.valueCoder = null;
this.windowedValueCoder = (Coder<WindowedValue<T>>) coder;
} else {
this.valueCoder = (Coder<T>) coder;
this.windowedValueCoder = null;
}
}

/**
Expand All @@ -67,7 +83,7 @@ public OutputSampler(Coder<T> coder, int maxElements, int sampleEveryN) {
*
* @param element the element to sample.
*/
public void sample(T element) {
public void sample(WindowedValue<T> element) {
// Only sample the first 10 elements then after every `sampleEveryN`th element.
long samples = numSamples.get() + 1;

Expand Down Expand Up @@ -104,7 +120,7 @@ public List<BeamFnApi.SampledElement> samples() throws IOException {

// Serializing can take a lot of CPU time for larger or complex elements. Copy the array here
// so as to not slow down the main processing hot path.
List<T> bufferToSend;
List<WindowedValue<T>> bufferToSend;
int sampleIndex = 0;
synchronized (this) {
bufferToSend = buffer;
Expand All @@ -116,10 +132,13 @@ public List<BeamFnApi.SampledElement> samples() throws IOException {
ByteStringOutputStream stream = new ByteStringOutputStream();
for (int i = 0; i < bufferToSend.size(); i++) {
int index = (sampleIndex + i) % bufferToSend.size();
// This is deprecated, but until this is fully removed, this specifically needs the nested
// context. This is because the SDK will need to decode the sampled elements with the
// ToStringFn.
coder.encode(bufferToSend.get(index), stream, Coder.Context.NESTED);

if (valueCoder != null) {
this.valueCoder.encode(bufferToSend.get(index).getValue(), stream, Coder.Context.NESTED);
} else if (windowedValueCoder != null) {
this.windowedValueCoder.encode(bufferToSend.get(index), stream, Coder.Context.NESTED);
}

ret.add(
BeamFnApi.SampledElement.newBuilder().setElement(stream.toByteStringAndReset()).build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.Test;
Expand Down Expand Up @@ -65,6 +66,10 @@ byte[] encodeByteArray(byte[] b) throws IOException {
return stream.toByteArray();
}

<T> WindowedValue<T> globalWindowedValue(T el) {
return WindowedValue.valueInGlobalWindow(el);
}

BeamFnApi.InstructionResponse getAllSamples(DataSampler dataSampler) {
BeamFnApi.InstructionRequest request =
BeamFnApi.InstructionRequest.newBuilder()
Expand Down Expand Up @@ -122,7 +127,7 @@ public void testSingleOutput() throws Exception {
DataSampler sampler = new DataSampler();

VarIntCoder coder = VarIntCoder.of();
sampler.sampleOutput("pcollection-id", coder).sample(1);
sampler.sampleOutput("pcollection-id", coder).sample(globalWindowedValue(1));

BeamFnApi.InstructionResponse samples = getAllSamples(sampler);
assertHasSamples(samples, "pcollection-id", Collections.singleton(encodeInt(1)));
Expand All @@ -140,7 +145,7 @@ public void testNestedContext() throws Exception {
String rawString = "hello";
byte[] byteArray = rawString.getBytes(StandardCharsets.US_ASCII);
ByteArrayCoder coder = ByteArrayCoder.of();
sampler.sampleOutput("pcollection-id", coder).sample(byteArray);
sampler.sampleOutput("pcollection-id", coder).sample(globalWindowedValue(byteArray));

BeamFnApi.InstructionResponse samples = getAllSamples(sampler);
assertHasSamples(samples, "pcollection-id", Collections.singleton(encodeByteArray(byteArray)));
Expand All @@ -156,8 +161,8 @@ public void testMultipleOutputs() throws Exception {
DataSampler sampler = new DataSampler();

VarIntCoder coder = VarIntCoder.of();
sampler.sampleOutput("pcollection-id-1", coder).sample(1);
sampler.sampleOutput("pcollection-id-2", coder).sample(2);
sampler.sampleOutput("pcollection-id-1", coder).sample(globalWindowedValue(1));
sampler.sampleOutput("pcollection-id-2", coder).sample(globalWindowedValue(2));

BeamFnApi.InstructionResponse samples = getAllSamples(sampler);
assertHasSamples(samples, "pcollection-id-1", Collections.singleton(encodeInt(1)));
Expand All @@ -174,21 +179,21 @@ public void testMultipleSamePCollections() throws Exception {
DataSampler sampler = new DataSampler();

VarIntCoder coder = VarIntCoder.of();
sampler.sampleOutput("pcollection-id", coder).sample(1);
sampler.sampleOutput("pcollection-id", coder).sample(2);
sampler.sampleOutput("pcollection-id", coder).sample(globalWindowedValue(1));
sampler.sampleOutput("pcollection-id", coder).sample(globalWindowedValue(2));

BeamFnApi.InstructionResponse samples = getAllSamples(sampler);
assertHasSamples(samples, "pcollection-id", ImmutableList.of(encodeInt(1), encodeInt(2)));
}

void generateStringSamples(DataSampler sampler) {
StringUtf8Coder coder = StringUtf8Coder.of();
sampler.sampleOutput("a", coder).sample("a1");
sampler.sampleOutput("a", coder).sample("a2");
sampler.sampleOutput("b", coder).sample("b1");
sampler.sampleOutput("b", coder).sample("b2");
sampler.sampleOutput("c", coder).sample("c1");
sampler.sampleOutput("c", coder).sample("c2");
sampler.sampleOutput("a", coder).sample(globalWindowedValue("a1"));
sampler.sampleOutput("a", coder).sample(globalWindowedValue("a2"));
sampler.sampleOutput("b", coder).sample(globalWindowedValue("b1"));
sampler.sampleOutput("b", coder).sample(globalWindowedValue("b2"));
sampler.sampleOutput("c", coder).sample(globalWindowedValue("c1"));
sampler.sampleOutput("c", coder).sample(globalWindowedValue("c2"));
}

/**
Expand Down Expand Up @@ -250,7 +255,7 @@ public void testConcurrentNewSampler() throws Exception {
}

for (int j = 0; j < 100; j++) {
sampler.sampleOutput("pcollection-" + j, coder).sample(0);
sampler.sampleOutput("pcollection-" + j, coder).sample(globalWindowedValue(0));
}

doneSignal.countDown();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
import java.util.concurrent.CountDownLatch;
import org.apache.beam.model.fnexecution.v1.BeamFnApi;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
Expand All @@ -45,6 +48,17 @@ public BeamFnApi.SampledElement encodeInt(Integer i) throws IOException {
.build();
}

public BeamFnApi.SampledElement encodeGlobalWindowedInt(Integer i) throws IOException {
WindowedValue.WindowedValueCoder<Integer> coder =
WindowedValue.FullWindowedValueCoder.of(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE);

ByteArrayOutputStream stream = new ByteArrayOutputStream();
coder.encode(WindowedValue.valueInGlobalWindow(i), stream);
return BeamFnApi.SampledElement.newBuilder()
.setElement(ByteString.copyFrom(stream.toByteArray()))
.build();
}

/**
* Test that the first N are always sampled.
*
Expand All @@ -57,7 +71,7 @@ public void testSamplesFirstN() throws Exception {

// Purposely go over maxSamples and sampleEveryN. This helps to increase confidence.
for (int i = 0; i < 15; ++i) {
outputSampler.sample(i);
outputSampler.sample(WindowedValue.valueInGlobalWindow(i));
}

// The expected list is only 0..9 inclusive.
Expand All @@ -70,6 +84,33 @@ public void testSamplesFirstN() throws Exception {
assertThat(samples, containsInAnyOrder(expected.toArray()));
}

@Test
public void testWindowedValueSample() throws Exception {
WindowedValue.WindowedValueCoder<Integer> coder =
WindowedValue.FullWindowedValueCoder.of(VarIntCoder.of(), GlobalWindow.Coder.INSTANCE);

OutputSampler<Integer> outputSampler = new OutputSampler<>(coder, 10, 10);
outputSampler.sample(WindowedValue.valueInGlobalWindow(0));

// The expected list is only 0..9 inclusive.
List<BeamFnApi.SampledElement> expected = ImmutableList.of(encodeGlobalWindowedInt(0));
List<BeamFnApi.SampledElement> samples = outputSampler.samples();
assertThat(samples, containsInAnyOrder(expected.toArray()));
}

@Test
public void testNonWindowedValueSample() throws Exception {
VarIntCoder coder = VarIntCoder.of();

OutputSampler<Integer> outputSampler = new OutputSampler<>(coder, 10, 10);
outputSampler.sample(WindowedValue.valueInGlobalWindow(0));

// The expected list is only 0..9 inclusive.
List<BeamFnApi.SampledElement> expected = ImmutableList.of(encodeInt(0));
List<BeamFnApi.SampledElement> samples = outputSampler.samples();
assertThat(samples, containsInAnyOrder(expected.toArray()));
}

/**
* Test that the previous values are overwritten and only the most recent `maxSamples` are kept.
*
Expand All @@ -81,7 +122,7 @@ public void testActsLikeCircularBuffer() throws Exception {
OutputSampler<Integer> outputSampler = new OutputSampler<>(coder, 5, 20);

for (int i = 0; i < 100; ++i) {
outputSampler.sample(i);
outputSampler.sample(WindowedValue.valueInGlobalWindow(i));
}

// The first 10 are always sampled, but with maxSamples = 5, the first ten are downsampled to
Expand Down Expand Up @@ -124,7 +165,7 @@ public void testConcurrentSamples() throws Exception {
}

for (int i = 0; i < 1000000; i++) {
outputSampler.sample(i);
outputSampler.sample(WindowedValue.valueInGlobalWindow(i));
}

doneSignal.countDown();
Expand All @@ -141,7 +182,7 @@ public void testConcurrentSamples() throws Exception {
}

for (int i = -1000000; i < 0; i++) {
outputSampler.sample(i);
outputSampler.sample(WindowedValue.valueInGlobalWindow(i));
}

doneSignal.countDown();
Expand Down

0 comments on commit 04c2de6

Please sign in to comment.