Skip to content

Commit

Permalink
apache#13 [euphoria-flink] Fix bug in batch ReduceByKeyTranslator and…
Browse files Browse the repository at this point in the history
… add unit test
  • Loading branch information
vanekjar committed Apr 27, 2017
1 parent 7f0da3e commit 2097155
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ public DataSet translate(FlinkOperator<ReduceByKey> operator,
// ~ reduce the data now
Operator<BatchElement<Window, Pair>, ?> reduced =
tuples.groupBy(new RBKKeySelector())
.combineGroup(new RBKReducer(reducer))
.groupBy(new RBKKeySelector())
.reduceGroup(new RBKReducer(reducer))
.setParallelism(operator.getParallelism())
.name(operator.getName() + "::reduce");
Expand Down Expand Up @@ -190,13 +188,11 @@ private void doReduce(Iterable<BatchElement<Window, Pair>> values,
// Tuple2[Window, Key] => Reduced Value
Map<Tuple2, BatchElement<Window, Pair>> reducedValues = new HashMap<>();

Tuple2 kw = new Tuple2();
for (BatchElement<Window, Pair> batchElement : values) {
Object key = batchElement.getElement().getFirst();
Window window = batchElement.getWindow();

kw.f0 = window;
kw.f1 = key;
Tuple2 kw = new Tuple2<>(window, key);
BatchElement<Window, Pair> val = reducedValues.get(kw);
if (val == null) {
reducedValues.put(kw, batchElement);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -645,4 +645,81 @@ public void validate(Partitions<Integer> partitions) {
}
});
}

/**
 * Runs a windowed {@code ReduceByKey} count keyed by {@code Word}, a key type whose
 * {@code hashCode()} is the same constant for every instance, and checks the per-key
 * counts of both windows.
 */
@Test
public void testReduceByKeyWithWrongHashCodeImpl() {
  execute(new AbstractTestCase<Pair<Word, Long>, Pair<Word, Long>>() {

    @Override
    public int getNumOutputPartitions() {
      return 2;
    }

    @Override
    protected Partitions<Pair<Word, Long>> getInput() {
      // the Long of each pair is what getOutput() hands to windowBy(...)
      return Partitions
          .add(
              Pair.of(new Word("euphoria"), 300L),
              Pair.of(new Word("euphoria"), 600L),
              Pair.of(new Word("spark"), 900L),
              Pair.of(new Word("euphoria"), 1300L),
              Pair.of(new Word("flink"), 1600L),
              Pair.of(new Word("spark"), 1900L))
          .build();
    }

    @Override
    protected Dataset<Pair<Word, Long>> getOutput(Dataset<Pair<Word, Long>> input) {
      // every element contributes 1L, so the summed value is an occurrence count
      return ReduceByKey.of(input)
          .keyBy(Pair::getFirst)
          .valueBy(ignored -> 1L)
          .combineBy(Sums.ofLongs())
          .windowBy(Time.of(Duration.ofSeconds(1)), Pair::getSecond)
          .output();
    }

    @Override
    public void validate(Partitions<Pair<Word, Long>> partitions) {
      assertEquals(2, partitions.size());

      // Only the first partition is inspected.
      // NOTE(review): presumably every key lands there because Word#hashCode is
      // constant - confirm against the partitioner in use.
      List<Pair<Word, Long>> actual = partitions.get(0);

      List<Pair<Word, Long>> expected = Arrays.asList(
          // first window
          Pair.of(new Word("euphoria"), 2L),
          Pair.of(new Word("spark"), 1L),
          // second window
          Pair.of(new Word("euphoria"), 1L),
          Pair.of(new Word("spark"), 1L),
          Pair.of(new Word("flink"), 1L));

      assertUnorderedEquals(expected, actual);
    }
  });
}

/**
 * Wrapper around a string used as a test key: {@link #equals(Object)} compares the
 * wrapped strings, while {@link #hashCode()} deliberately returns the same constant
 * for every instance, so all instances collide on hash code.
 */
private static class Word {
  private final String str;

  public Word(String str) {
    this.str = str;
  }

  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (!(o instanceof Word)) {
      return false;
    }
    final String otherStr = ((Word) o).str;
    if (str == null) {
      // a null payload only equals another null payload
      return otherStr == null;
    }
    return str.equals(otherStr);
  }

  @Override
  public int hashCode() {
    // constant on purpose: distinct words still share one hash code
    return 42;
  }

  @Override
  public String toString() {
    return str;
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ List<Integer> flatten(List<Triple<Integer, Integer, Integer>> l) {

// ---------------------------------------------------------------------------------

private static class CountState extends State<String, Long> {
private static class CountState<IN> extends State<IN, Long> {
final ValueStorage<Long> count;
CountState(Context<Long> context, StorageProvider storageProvider)
{
Expand All @@ -232,7 +232,7 @@ private static class CountState extends State<String, Long> {

@Override
@SuppressWarnings("unchecked")
public void add(String element) {
public void add(IN element) {
count.set(count.get() + 1);
}

Expand All @@ -241,7 +241,7 @@ public void flush() {
getContext().collect(count.get());
}

void add(CountState other) {
void add(CountState<IN> other) {
count.set(count.get() + other.count.get());
}

Expand All @@ -250,8 +250,8 @@ public void close() {
count.clear();
}

public static void combine(CountState target, Iterable<CountState> others) {
for (CountState other : others) {
public static <IN> void combine(CountState<IN> target, Iterable<CountState<IN>> others) {
for (CountState<IN> other : others) {
target.add(other);
}
}
Expand Down Expand Up @@ -283,7 +283,7 @@ protected Partitions<Pair<String, Integer>> getInput() {
return ReduceStateByKey.of(input)
.keyBy(e -> e.getFirst().charAt(0) - '0')
.valueBy(Pair::getFirst)
.stateFactory((StateFactory<String, Long, CountState>) CountState::new)
.stateFactory((StateFactory<String, Long, CountState<String>>) CountState::new)
.mergeStatesBy(CountState::combine)
// FIXME .timedBy(Pair::getSecond) and make the assertion in the validation phase stronger
.windowBy(Count.of(3))
Expand Down Expand Up @@ -660,4 +660,84 @@ public void validate(Partitions<Integer> partitions) {
}
});
}

// ------------------------------------

/**
 * Runs a windowed {@code ReduceStateByKey} with a counting state keyed by {@code Word},
 * a key type whose {@code hashCode()} is the same constant for every instance, and
 * checks the per-key counts of both windows.
 */
@Test
public void testReduceStateByKeyWithWrongHashCodeImpl() {
  execute(new AbstractTestCase<Pair<Word, Long>, Pair<Word, Long>>() {

    @Override
    public int getNumOutputPartitions() {
      return 2;
    }

    @Override
    protected Partitions<Pair<Word, Long>> getInput() {
      // the Long of each pair is what getOutput() hands to windowBy(...)
      return Partitions
          .add(
              Pair.of(new Word("euphoria"), 300L),
              Pair.of(new Word("euphoria"), 600L),
              Pair.of(new Word("spark"), 900L),
              Pair.of(new Word("euphoria"), 1300L),
              Pair.of(new Word("flink"), 1600L),
              Pair.of(new Word("spark"), 1900L))
          .build();
    }

    @Override
    protected Dataset<Pair<Word, Long>> getOutput(Dataset<Pair<Word, Long>> input) {
      // the state is built by CountState::new and merged through CountState::combine
      return ReduceStateByKey.of(input)
          .keyBy(Pair::getFirst)
          .valueBy(Pair::getFirst)
          .stateFactory((StateFactory<Word, Long, CountState<Word>>) CountState::new)
          .mergeStatesBy(CountState::combine)
          .windowBy(Time.of(Duration.ofSeconds(1)), (Pair<Word, Long> pair) -> pair.getSecond())
          .output();
    }

    @Override
    public void validate(Partitions<Pair<Word, Long>> partitions) {
      assertEquals(2, partitions.size());

      // Only the first partition is inspected.
      // NOTE(review): presumably every key lands there because Word#hashCode is
      // constant - confirm against the partitioner in use.
      List<Pair<Word, Long>> actual = partitions.get(0);

      List<Pair<Word, Long>> expected = Arrays.asList(
          // first window
          Pair.of(new Word("euphoria"), 2L),
          Pair.of(new Word("spark"), 1L),
          // second window
          Pair.of(new Word("euphoria"), 1L),
          Pair.of(new Word("spark"), 1L),
          Pair.of(new Word("flink"), 1L));

      assertUnorderedEquals(expected, actual);
    }
  });
}

/**
 * Wrapper around a string used as a test key: {@link #equals(Object)} compares the
 * wrapped strings, while {@link #hashCode()} deliberately returns the same constant
 * for every instance, so all instances collide on hash code.
 */
private static class Word {
  private final String str;

  public Word(String str) {
    this.str = str;
  }

  @Override
  public boolean equals(Object o) {
    if (o == this) {
      return true;
    }
    if (!(o instanceof Word)) {
      return false;
    }
    final String otherStr = ((Word) o).str;
    if (str == null) {
      // a null payload only equals another null payload
      return otherStr == null;
    }
    return str.equals(otherStr);
  }

  @Override
  public int hashCode() {
    // constant on purpose: distinct words still share one hash code
    return 42;
  }

  @Override
  public String toString() {
    return str;
  }
}
}

0 comments on commit 2097155

Please sign in to comment.