[BEAM-5392] small CR corrections
mareksimunek committed Feb 8, 2019
1 parent 4742939 commit 5a7e8fa
Showing 4 changed files with 55 additions and 27 deletions.
Changed file 1 of 4
@@ -139,8 +139,8 @@ public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
}

/**
* Cache PCollection if {@link #isCacheDisabled()} flag is false or transform isn't GroupByKey
* transformation and PCollection is used more then once in Pipeline.
* Cache PCollection if SparkPipelineOptions.isCacheDisabled is false or transform isn't
* GroupByKey transformation and PCollection is used more than once in Pipeline.
*
* <p>PCollection is not cached in GroupByKey transformation, because Spark automatically persists
* some intermediate data in shuffle operations, even without users calling persist.
@@ -149,29 +149,52 @@ public Map<TupleTag<?>, Coder<?>> getOutputCoders() {
* @param transform
* @return if PCollection will be cached
*/
public boolean shouldCache(PValue pvalue, PTransform<?, ? extends PValue> transform) {
if (isCacheDisabled() || transform instanceof GroupByKey) {
public boolean shouldCache(PTransform<?, ? extends PValue> transform, PValue pvalue) {
if (serializableOptions.get().as(SparkPipelineOptions.class).isCacheDisabled()
|| transform instanceof GroupByKey) {
return false;
}
return pvalue instanceof PCollection && cacheCandidates.getOrDefault(pvalue, 0L) > 1;
}

/**
* Add single output of transform to context map and possibly cache if it conforms {@link
* #shouldCache(PTransform, PValue)}.
*
* @param transform from which Dataset was created
* @param dataset created Dataset from transform
*/
public void putDataset(PTransform<?, ? extends PValue> transform, Dataset dataset) {
putDataset(transform, getOutput(transform), dataset);
}

/**
* Add output of transform to context map and possibly cache if it conforms {@link
* #shouldCache(PTransform, PValue)}. Used when PTransform has multiple outputs.
*
* @param pvalue one of multiple outputs of transform
* @param dataset created Dataset from transform
*/
public void putDataset(PValue pvalue, Dataset dataset) {
putDataset(null, pvalue, dataset);
}

/**
* Add output of transform to context map and possibly cache if it conforms {@link
* #shouldCache(PTransform, PValue)}.
*
* @param transform from which Dataset was created
* @param pvalue output of transform
* @param dataset created Dataset from transform
*/
private void putDataset(
@Nullable PTransform<?, ? extends PValue> transform, PValue pvalue, Dataset dataset) {
try {
dataset.setName(pvalue.getName());
} catch (IllegalStateException e) {
// name not set, ignore
}
if (shouldCache(pvalue, transform)) {
if (shouldCache(transform, pvalue)) {
// we cache only PCollection
Coder<?> coder = ((PCollection<?>) pvalue).getCoder();
Coder<? extends BoundedWindow> wCoder =
@@ -266,8 +289,4 @@ <T> Iterable<WindowedValue<T>> getWindowedValues(PCollection<T> pcollection) {
public String storageLevel() {
return serializableOptions.get().as(SparkPipelineOptions.class).getStorageLevel();
}

public boolean isCacheDisabled() {
return serializableOptions.get().as(SparkPipelineOptions.class).isCacheDisabled();
}
}
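
The caching rule changed above is compact enough to restate on its own. The sketch below mirrors the decision shouldCache makes after this commit: never cache when SparkPipelineOptions.isCacheDisabled is set or the producing transform is a GroupByKey (Spark already persists shuffle intermediates), otherwise cache any PCollection that the cacheCandidates counts say is consumed more than once. This is an illustrative standalone class, not Beam code; the CachingDecision name and its constructor are invented for the example.

import java.util.Map;

import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PValue;

/** Illustrative restatement of the caching decision; not part of the Beam runner. */
class CachingDecision {

  private final boolean cacheDisabled;              // value of SparkPipelineOptions.isCacheDisabled()
  private final Map<PValue, Long> cacheCandidates;  // usage counts gathered during pipeline translation

  CachingDecision(boolean cacheDisabled, Map<PValue, Long> cacheCandidates) {
    this.cacheDisabled = cacheDisabled;
    this.cacheCandidates = cacheCandidates;
  }

  boolean shouldCache(PTransform<?, ? extends PValue> transform, PValue pvalue) {
    // Skip caching entirely when disabled, or when the transform is a GroupByKey,
    // since Spark persists shuffle intermediates on its own.
    if (cacheDisabled || transform instanceof GroupByKey) {
      return false;
    }
    // Only PCollections that are consumed more than once are worth caching.
    return pvalue instanceof PCollection && cacheCandidates.getOrDefault(pvalue, 0L) > 1;
  }
}
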
Changed file 2 of 4
@@ -96,8 +96,8 @@ static class GroupByKeyIterator<K, V, W extends BoundedWindow>
private final Coder<V> valueCoder;
private final WindowingStrategy<?, W> windowingStrategy;
private final FullWindowedValueCoder<byte[]> windowedValueCoder;
private boolean hasNext = true;

private boolean hasNext = true;
private WindowedKey currentKey = null;

GroupByKeyIterator(
@@ -140,7 +140,7 @@ class ValueIterator implements Iterable<V> {

boolean usedAsIterable = false;
private final PeekingIterator<Tuple2<WindowedKey, byte[]>> inner;
private WindowedKey currentKey;
private final WindowedKey currentKey;

ValueIterator(PeekingIterator<Tuple2<WindowedKey, byte[]>> inner, WindowedKey currentKey) {
this.inner = inner;
@@ -222,11 +222,11 @@ public int hashCode() {
return result;
}

public byte[] getKey() {
byte[] getKey() {
return key;
}

public byte[] getWindow() {
byte[] getWindow() {
return window;
}

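
ValueIterator's usedAsIterable flag, together with the testGbkIteratorValuesCannotBeReiterated test further down, indicates that each per-key value Iterable may only be traversed once. A generic sketch of that single-use contract, written here for illustration and not copied from the Beam sources, could look like this:

import java.util.Iterator;

/** Illustrative single-use Iterable: the second call to iterator() fails fast. */
class SingleUseIterable<T> implements Iterable<T> {

  private final Iterator<T> inner;
  private boolean usedAsIterable = false;

  SingleUseIterable(Iterator<T> inner) {
    this.inner = inner;
  }

  @Override
  public Iterator<T> iterator() {
    if (usedAsIterable) {
      // Mirrors the re-iteration guard implied by the usedAsIterable field above.
      throw new IllegalStateException("This Iterable can only be iterated once.");
    }
    usedAsIterable = true;
    return inner;
  }
}
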
Changed file 3 of 4
@@ -20,6 +20,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;

import org.apache.beam.runners.spark.translation.Dataset;
import org.apache.beam.runners.spark.translation.EvaluationContext;
@@ -31,6 +32,7 @@
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.Create.Values;
import org.apache.beam.sdk.transforms.GroupByKey;
import org.apache.beam.sdk.values.PCollection;
import org.apache.spark.api.java.JavaSparkContext;
import org.junit.Test;
@@ -43,10 +45,8 @@ public class CacheTest {
* pipeline.
*/
@Test
public void cacheCandidatesUpdaterTest() throws Exception {
SparkPipelineOptions options =
PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
options.setRunner(TestSparkRunner.class);
public void cacheCandidatesUpdaterTest() {
SparkPipelineOptions options = createOptions();
Pipeline pipeline = Pipeline.create(options);
PCollection<String> pCollection = pipeline.apply(Create.of("foo", "bar"));
// first read
@@ -65,22 +65,31 @@ public void cacheCandidatesUpdaterTest() throws Exception {
}

@Test
public void cacheDisabledOptionTest() {
SparkPipelineOptions options =
PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
options.setRunner(TestSparkRunner.class);
public void shouldCacheTest() {
SparkPipelineOptions options = createOptions();
options.setCacheDisabled(true);
Pipeline pipeline = Pipeline.create(options);
Values<String> createInput = Create.of("foo", "bar");
PCollection<String> pCollection = pipeline.apply(createInput);

Values<String> valuesTransform = Create.of("foo", "bar");
PCollection pCollection = mock(PCollection.class);

JavaSparkContext jsc = SparkContextFactory.getSparkContext(options);
EvaluationContext ctxt = new EvaluationContext(jsc, pipeline, options);
ctxt.getCacheCandidates().put(pCollection, 2L);

assertFalse(ctxt.shouldCache(pCollection, createInput));
assertFalse(ctxt.shouldCache(valuesTransform, pCollection));

options.setCacheDisabled(false);
assertTrue(ctxt.shouldCache(pCollection, createInput));
assertTrue(ctxt.shouldCache(valuesTransform, pCollection));

GroupByKey<String, String> gbkTransform = GroupByKey.create();
assertFalse(ctxt.shouldCache(gbkTransform, pCollection));
}

private SparkPipelineOptions createOptions() {
SparkPipelineOptions options =
PipelineOptionsFactory.create().as(TestSparkPipelineOptions.class);
options.setRunner(TestSparkRunner.class);
return options;
}
}
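
For context on the option that shouldCacheTest toggles: cache-disabling is an ordinary pipeline option, so a user can switch it on when building a pipeline. A minimal sketch, assuming only the SparkPipelineOptions accessors already visible in this diff plus the standard SparkRunner entry point; the class name and trivial pipeline body are illustrative.

import org.apache.beam.runners.spark.SparkPipelineOptions;
import org.apache.beam.runners.spark.SparkRunner;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;

public class DisableCachingExample {
  public static void main(String[] args) {
    SparkPipelineOptions options =
        PipelineOptionsFactory.fromArgs(args).as(SparkPipelineOptions.class);
    options.setRunner(SparkRunner.class);
    // Turn RDD caching off; PCollections consumed more than once are recomputed instead.
    options.setCacheDisabled(true);

    Pipeline pipeline = Pipeline.create(options);
    pipeline.apply(Create.of("foo", "bar"));
    pipeline.run().waitUntilFinish();
  }
}
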
Changed file 4 of 4
@@ -80,7 +80,7 @@ public void testGbkIteratorValuesCannotBeReiterated() {
}
}

GroupByKeyIterator<String, Integer, GlobalWindow> createGbkIterator() {
private GroupByKeyIterator<String, Integer, GlobalWindow> createGbkIterator() {
StringUtf8Coder keyCoder = StringUtf8Coder.of();
BigEndianIntegerCoder valueCoder = BigEndianIntegerCoder.of();
WindowingStrategy<Object, GlobalWindow> winStrategy = WindowingStrategy.of(new GlobalWindows());
@@ -108,7 +108,7 @@ private static class ItemFactory<K, V> {
private final WindowedValue.FullWindowedValueCoder<byte[]> winValCoder;
private final byte[] globalWindow;

public ItemFactory(
ItemFactory(
Coder<K> keyCoder,
Coder<V> valueCoder,
FullWindowedValueCoder<byte[]> winValCoder,
