[BEAM-5392] GroupByKey optimized for non-merging windows
mareksimunek committed Feb 8, 2019
1 parent f1bb53a commit 4742939
Showing 6 changed files with 427 additions and 26 deletions.
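
In short, the optimization folds the window into the shuffle key so that Spark's repartitionAndSortWithinPartitions places all values of one key-window pair next to each other, after which they can be grouped in a single streaming pass instead of being buffered per key. A minimal standalone sketch of that shuffle pattern (illustrative class and variable names, assuming spark-core on the classpath; the string "user1@w0" stands in for the byte-encoded WindowedKey introduced below):

import java.util.Arrays;
import java.util.List;
import org.apache.spark.HashPartitioner;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaSparkContext;
import scala.Tuple2;

public class CompositeKeySortSketch {
  public static void main(String[] args) {
    SparkConf conf = new SparkConf().setMaster("local[2]").setAppName("sketch");
    try (JavaSparkContext sc = new JavaSparkContext(conf)) {
      JavaPairRDD<String, Integer> keyed =
          sc.parallelizePairs(
              Arrays.asList(
                  new Tuple2<>("user1@w0", 2),
                  new Tuple2<>("user2@w0", 3),
                  new Tuple2<>("user1@w0", 1)));
      // One shuffle: hash-partition by the composite key and sort each
      // partition, so all values of a (key, window) pair become adjacent.
      List<Tuple2<String, Integer>> sorted =
          keyed.repartitionAndSortWithinPartitions(new HashPartitioner(2)).collect();
      sorted.forEach(System.out::println); // e.g. (user1@w0,1) next to (user1@w0,2)
    }
  }
}

The real implementation keys by raw bytes, with WindowedKey supplying the ordering, and groups lazily in GroupByKeyIterator rather than via collect.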
@@ -23,6 +23,7 @@
import org.apache.beam.runners.spark.io.MicrobatchSource;
import org.apache.beam.runners.spark.stateful.SparkGroupAlsoByWindowViaWindowSet.StateAndTimers;
import org.apache.beam.runners.spark.translation.GroupCombineFunctions;
+ import org.apache.beam.runners.spark.translation.GroupNonMergingWindowsFunctions.WindowedKey;
import org.apache.beam.runners.spark.util.ByteArray;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.values.KV;
@@ -61,6 +62,7 @@ public void registerClasses(Kryo kryo) {
kryo.register(HashBasedTable.class);
kryo.register(KV.class);
kryo.register(PaneInfo.class);
+ kryo.register(WindowedKey.class);

try {
kryo.register(Class.forName("com.google.common.collect.HashBasedTable$Factory"));
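WindowedKey instances cross Spark's shuffle, so they are registered here; a small sketch (illustrative class name, standard Spark settings, not part of this commit) of the configuration under which a missing registration fails fast:

import org.apache.spark.SparkConf;

public class KryoConfSketch {
  public static void main(String[] args) {
    // With these settings, serializing an unregistered class throws at
    // runtime, which is why WindowedKey is registered in the registrator above.
    SparkConf conf =
        new SparkConf()
            .setMaster("local[1]")
            .setAppName("kryo-conf-sketch")
            .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
            .set("spark.kryo.registrationRequired", "true");
    System.out.println(conf.toDebugString());
  }
}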
@@ -25,13 +25,15 @@
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
+ import javax.annotation.Nullable;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.runners.AppliedPTransform;
+ import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
@@ -137,30 +139,39 @@ public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
}

/**
- * Cache PCollection if {@link #isCacheDisabled()} flag is false and PCollection is used more then
- * once in Pipeline.
+ * Cache PCollection if {@link #isCacheDisabled()} flag is false, the transform is not a
+ * GroupByKey, and the PCollection is used more than once in the Pipeline.
*
- * @param pvalue
+ * <p>A PCollection is not cached in the GroupByKey transformation, because Spark automatically
+ * persists some intermediate data in shuffle operations, even without users calling persist.
+ *
+ * @param pvalue output of the transform
+ * @param transform the transform that produced the pvalue; may be null
* @return if PCollection will be cached
*/
- public boolean shouldCache(PValue pvalue) {
- if (isCacheDisabled()) {
+ public boolean shouldCache(PValue pvalue, PTransform<?, ? extends PValue> transform) {
+ if (isCacheDisabled() || transform instanceof GroupByKey) {
return false;
}
return pvalue instanceof PCollection && cacheCandidates.getOrDefault(pvalue, 0L) > 1;
}

public void putDataset(PTransform<?, ? extends PValue> transform, Dataset dataset) {
- putDataset(getOutput(transform), dataset);
+ putDataset(transform, getOutput(transform), dataset);
}

public void putDataset(PValue pvalue, Dataset dataset) {
+ putDataset(null, pvalue, dataset);
+ }
+
+ private void putDataset(
+ @Nullable PTransform<?, ? extends PValue> transform, PValue pvalue, Dataset dataset) {
try {
dataset.setName(pvalue.getName());
} catch (IllegalStateException e) {
// name not set, ignore
}
- if (shouldCache(pvalue)) {
+ if (shouldCache(pvalue, transform)) {
// we only cache PCollections
Coder<?> coder = ((PCollection<?>) pvalue).getCoder();
Coder<? extends BoundedWindow> wCoder =
@@ -0,0 +1,242 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.translation;

import java.io.Serializable;
import java.util.Arrays;
import java.util.Iterator;
import java.util.Objects;
import org.apache.beam.runners.spark.coders.CoderHelpers;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.AbstractIterator;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterators;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.PeekingIterator;
import org.apache.beam.vendor.guava.v20_0.com.google.common.primitives.UnsignedBytes;
import org.apache.spark.HashPartitioner;
import org.apache.spark.api.java.JavaRDD;
import org.joda.time.Instant;
import scala.Tuple2;

/** Functions for translating GroupByKey with non-merging windows to Spark. */
public class GroupNonMergingWindowsFunctions {

static <K, V, W extends BoundedWindow>
JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupByKeyAndWindow(
JavaRDD<WindowedValue<KV<K, V>>> rdd,
Coder<K> keyCoder,
Coder<V> valueCoder,
WindowingStrategy<?, W> windowingStrategy) {
final Coder<W> windowCoder = windowingStrategy.getWindowFn().windowCoder();
final WindowedValue.FullWindowedValueCoder<byte[]> windowedValueCoder =
WindowedValue.getFullCoder(ByteArrayCoder.of(), windowCoder);
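// Key, window and value are encoded to bytes up front, so the partitioning,
// sorting and equality checks below operate on stable byte arrays and need
// no comparators for the user's own types.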
return rdd.flatMapToPair(
(WindowedValue<KV<K, V>> windowedValue) -> {
final byte[] keyBytes =
CoderHelpers.toByteArray(windowedValue.getValue().getKey(), keyCoder);
final byte[] valueBytes =
CoderHelpers.toByteArray(windowedValue.getValue().getValue(), valueCoder);
return Iterators.transform(
windowedValue.explodeWindows().iterator(),
item -> {
Objects.requireNonNull(item, "Exploded window cannot be null.");
@SuppressWarnings("unchecked")
final W window = (W) Iterables.getOnlyElement(item.getWindows());
final byte[] windowBytes = CoderHelpers.toByteArray(window, windowCoder);
final byte[] windowValueBytes =
CoderHelpers.toByteArray(
WindowedValue.of(
valueBytes, item.getTimestamp(), window, item.getPane()),
windowedValueCoder);
final WindowedKey windowedKey = new WindowedKey(keyBytes, windowBytes);
return new Tuple2<>(windowedKey, windowValueBytes);
});
})
.repartitionAndSortWithinPartitions(new HashPartitioner(rdd.getNumPartitions()))
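// Sorting each partition by WindowedKey (key bytes first, then window bytes)
// makes all values of one key-window pair adjacent, so the iterator below can
// group them in one streaming pass without holding a whole group in memory.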
.mapPartitions(
it ->
new GroupByKeyIterator<>(
it, keyCoder, valueCoder, windowingStrategy, windowedValueCoder))
.filter(Objects::nonNull); // filter last null element from GroupByKeyIterator
}

/**
* Transforms a stream of sorted key-value pairs into a stream of per-key value iterators. This
* iterator can be iterated only once!
*
* @param <K> type of key iterator emits
* @param <V> type of value iterator emits
*/
static class GroupByKeyIterator<K, V, W extends BoundedWindow>
implements Iterator<WindowedValue<KV<K, Iterable<V>>>> {

private final PeekingIterator<Tuple2<WindowedKey, byte[]>> inner;
private final Coder<K> keyCoder;
private final Coder<V> valueCoder;
private final WindowingStrategy<?, W> windowingStrategy;
private final FullWindowedValueCoder<byte[]> windowedValueCoder;
private boolean hasNext = true;

private WindowedKey currentKey = null;

GroupByKeyIterator(
Iterator<Tuple2<WindowedKey, byte[]>> inner,
Coder<K> keyCoder,
Coder<V> valueCoder,
WindowingStrategy<?, W> windowingStrategy,
WindowedValue.FullWindowedValueCoder<byte[]> windowedValueCoder) {
this.inner = Iterators.peekingIterator(inner);
this.keyCoder = keyCoder;
this.valueCoder = valueCoder;
this.windowingStrategy = windowingStrategy;
this.windowedValueCoder = windowedValueCoder;
}

@Override
public boolean hasNext() {
return hasNext;
}

@Override
public WindowedValue<KV<K, Iterable<V>>> next() {
while (inner.hasNext()) {
final WindowedKey nextKey = inner.peek()._1;
if (nextKey.equals(currentKey)) {
// the caller did not consume all values of the previous key; skip the leftovers
inner.next();
continue;
}
currentKey = nextKey;
final WindowedValue<KV<K, V>> decodedItem = decodeItem(inner.peek());
return decodedItem.withValue(
KV.of(decodedItem.getValue().getKey(), new ValueIterator(inner, currentKey)));
}
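// input exhausted: flip hasNext and emit a single null sentinel, which is
// dropped again by filter(Objects::nonNull) in groupByKeyAndWindow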
hasNext = false;
return null;
}

class ValueIterator implements Iterable<V> {

boolean usedAsIterable = false;
private final PeekingIterator<Tuple2<WindowedKey, byte[]>> inner;
private WindowedKey currentKey;

ValueIterator(PeekingIterator<Tuple2<WindowedKey, byte[]>> inner, WindowedKey currentKey) {
this.inner = inner;
this.currentKey = currentKey;
}

@Override
public Iterator<V> iterator() {
if (usedAsIterable) {
throw new IllegalStateException(
"ValueIterator can't be iterated more than once,"
+ "otherwise there could be data lost");
}
usedAsIterable = true;
return new AbstractIterator<V>() {
@Override
protected V computeNext() {
if (inner.hasNext() && currentKey.equals(inner.peek()._1)) {
return decodeValue(inner.next()._2);
}
return endOfData();
}
};
}
}

private V decodeValue(byte[] windowedValueBytes) {
final WindowedValue<byte[]> windowedValue =
CoderHelpers.fromByteArray(windowedValueBytes, windowedValueCoder);
return CoderHelpers.fromByteArray(windowedValue.getValue(), valueCoder);
}

private WindowedValue<KV<K, V>> decodeItem(Tuple2<WindowedKey, byte[]> item) {
final K key = CoderHelpers.fromByteArray(item._1.getKey(), keyCoder);
final WindowedValue<byte[]> windowedValue =
CoderHelpers.fromByteArray(item._2, windowedValueCoder);
final V value = CoderHelpers.fromByteArray(windowedValue.getValue(), valueCoder);
@SuppressWarnings("unchecked")
final W window = (W) Iterables.getOnlyElement(windowedValue.getWindows());
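// Re-derive the element's timestamp from its window; with
// TimestampCombiner.END_OF_WINDOW (the only combiner this translation is used
// with, see the GroupByKey translator) this yields the window's maximum
// timestamp regardless of the order values arrive in.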
final Instant timestamp =
windowingStrategy
.getTimestampCombiner()
.assign(
window,
windowingStrategy
.getWindowFn()
.getOutputTime(windowedValue.getTimestamp(), window));
return WindowedValue.of(KV.of(key, value), timestamp, window, windowedValue.getPane());
}
}

/** Composite key of the byte-encoded key and window, used in the GroupByKey translation. */
public static class WindowedKey implements Comparable<WindowedKey>, Serializable {

private final byte[] key;
private final byte[] window;

WindowedKey(byte[] key, byte[] window) {
this.key = key;
this.window = window;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
WindowedKey that = (WindowedKey) o;
return Arrays.equals(key, that.key) && Arrays.equals(window, that.window);
}

@Override
public int hashCode() {
int result = Arrays.hashCode(key);
result = 31 * result + Arrays.hashCode(window);
return result;
}

public byte[] getKey() {
return key;
}

public byte[] getWindow() {
return window;
}

@Override
public int compareTo(WindowedKey o) {
int keyCompare = UnsignedBytes.lexicographicalComparator().compare(this.getKey(), o.getKey());
if (keyCompare == 0) {
return UnsignedBytes.lexicographicalComparator().compare(this.getWindow(), o.getWindow());
}
return keyCompare;
}
}
}
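
To make the contract of GroupByKeyIterator concrete, here is a minimal sketch in plain Java (illustrative class and method names, not the Beam API): pairs arrive sorted by key, and each group's values are pulled lazily from the shared input. The simplifying assumption is that a group is fully consumed before advancing; the real iterator additionally skips unconsumed leftovers and guards against repeated iteration.

import java.util.AbstractMap.SimpleEntry;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

public class SortedGroupsSketch {

  // Lazily groups adjacent equal keys of an input iterator that is sorted by
  // key. Assumes the caller drains each group before calling next() again.
  static <K, V> Iterator<Map.Entry<K, Iterator<V>>> group(Iterator<Map.Entry<K, V>> sorted) {
    return new Iterator<Map.Entry<K, Iterator<V>>>() {
      private Map.Entry<K, V> pending = sorted.hasNext() ? sorted.next() : null;

      @Override
      public boolean hasNext() {
        return pending != null;
      }

      @Override
      public Map.Entry<K, Iterator<V>> next() {
        final K key = pending.getKey();
        // Values are not materialized: advancing this inner iterator pulls
        // directly from the shared sorted input until the key changes.
        Iterator<V> values =
            new Iterator<V>() {
              @Override
              public boolean hasNext() {
                return pending != null && pending.getKey().equals(key);
              }

              @Override
              public V next() {
                V value = pending.getValue();
                pending = sorted.hasNext() ? sorted.next() : null;
                return value;
              }
            };
        return new SimpleEntry<>(key, values);
      }
    };
  }

  public static void main(String[] args) {
    List<Map.Entry<String, Integer>> sorted =
        Arrays.<Map.Entry<String, Integer>>asList(
            new SimpleEntry<>("a", 1), new SimpleEntry<>("a", 2), new SimpleEntry<>("b", 3));
    group(sorted.iterator())
        .forEachRemaining(
            g -> {
              StringBuilder sb = new StringBuilder(g.getKey()).append(':');
              g.getValue().forEachRemaining(v -> sb.append(' ').append(v));
              System.out.println(sb); // prints "a: 1 2" then "b: 3"
            });
  }
}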
@@ -52,6 +52,7 @@
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
+ import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.transforms.windowing.Window;
import org.apache.beam.sdk.transforms.windowing.WindowFn;
import org.apache.beam.sdk.util.CombineFnUtil;
@@ -132,22 +133,32 @@ public void evaluate(GroupByKey<K, V> transform, EvaluationContext context) {
final WindowedValue.WindowedValueCoder<V> wvCoder =
WindowedValue.FullWindowedValueCoder.of(coder.getValueCoder(), windowFn.windowCoder());

- // --- group by key only.
- JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKey =
- GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder, getPartitioner(context));
-
- // --- now group also by window.
- // for batch, GroupAlsoByWindow uses an in-memory StateInternals.
- JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupedAlsoByWindow =
- groupedByKey.flatMap(
- new SparkGroupAlsoByWindowViaOutputBufferFn<>(
- windowingStrategy,
- new TranslationUtils.InMemoryStateInternalsFactory<>(),
- SystemReduceFn.buffering(coder.getValueCoder()),
- context.getSerializableOptions(),
- accum));
-
- context.putDataset(transform, new BoundedDataset<>(groupedAlsoByWindow));
+ JavaRDD<WindowedValue<KV<K, Iterable<V>>>> groupedByKey;
+ if (windowingStrategy.getWindowFn().isNonMerging()
+ && windowingStrategy.getTimestampCombiner() == TimestampCombiner.END_OF_WINDOW) {
+ // we can have a memory-sensitive translation for non-merging windows
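+ // (memory-sensitive: each key-window group's values are read lazily from the
+ // sorted partition instead of being buffered in state internals; the
+ // END_OF_WINDOW check matters because the optimized path derives every output
+ // timestamp from the window alone, while other combiners would need to see
+ // all timestamps of a group)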
+ groupedByKey =
+ GroupNonMergingWindowsFunctions.groupByKeyAndWindow(
+ inRDD, keyCoder, coder.getValueCoder(), windowingStrategy);
+ } else {
+
+ // --- group by key only.
+ Partitioner partitioner = getPartitioner(context);
+ JavaRDD<WindowedValue<KV<K, Iterable<WindowedValue<V>>>>> groupedByKeyOnly =
+ GroupCombineFunctions.groupByKeyOnly(inRDD, keyCoder, wvCoder, partitioner);
+
+ // --- now group also by window.
+ // for batch, GroupAlsoByWindow uses an in-memory StateInternals.
+ groupedByKey =
+ groupedByKeyOnly.flatMap(
+ new SparkGroupAlsoByWindowViaOutputBufferFn<>(
+ windowingStrategy,
+ new TranslationUtils.InMemoryStateInternalsFactory<>(),
+ SystemReduceFn.buffering(coder.getValueCoder()),
+ context.getSerializableOptions(),
+ accum));
+ }
+ context.putDataset(transform, new BoundedDataset<>(groupedByKey));
}

@Override
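
For context, a small sketch (illustrative class name, standard Beam SDK calls, not part of this commit) of which windowing setups satisfy the new condition:

import org.apache.beam.sdk.transforms.windowing.FixedWindows;
import org.apache.beam.sdk.transforms.windowing.Sessions;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.joda.time.Duration;

public class WindowingPathCheck {
  public static void main(String[] args) {
    // FixedWindows never merge, and WindowingStrategy.of defaults the
    // timestamp combiner to END_OF_WINDOW, so this strategy takes the
    // optimized groupByKeyAndWindow path.
    WindowingStrategy<?, ?> fixed =
        WindowingStrategy.of(FixedWindows.of(Duration.standardMinutes(5)));
    System.out.println(fixed.getWindowFn().isNonMerging()); // true

    // Session windows merge, so they keep the original
    // GroupAlsoByWindow-based translation.
    WindowingStrategy<?, ?> sessions =
        WindowingStrategy.of(Sessions.withGapDuration(Duration.standardMinutes(5)));
    System.out.println(sessions.getWindowFn().isNonMerging()); // false
  }
}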