From 230a03f8bfc9148b68e865fcd36f3d127ace4837 Mon Sep 17 00:00:00 2001 From: zhengbuqian Date: Fri, 19 Aug 2022 22:58:59 +0000 Subject: [PATCH] Add windmill APIs proto and windmill legacy worker implementations for multimap state --- .../beam/runners/dataflow/DataflowRunner.java | 11 +- .../worker/WindmillStateInternals.java | 657 +++++++++- .../dataflow/worker/WindmillStateReader.java | 318 ++++- .../worker/WindmillStateInternalsTest.java | 1094 +++++++++++++++++ .../worker/WindmillStateReaderTest.java | 721 +++++++++++ .../windmill/src/main/proto/windmill.proto | 80 +- 6 files changed, 2856 insertions(+), 25 deletions(-) diff --git a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java index e6848cd9d0613..2667065d1cf1f 100644 --- a/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java +++ b/runners/google-cloud-dataflow-java/src/main/java/org/apache/beam/runners/dataflow/DataflowRunner.java @@ -2443,10 +2443,10 @@ static boolean useStreamingEngine(DataflowPipelineOptions options) { static void verifyDoFnSupported( DoFn fn, boolean streaming, DataflowPipelineOptions options) { - if (DoFnSignatures.usesMultimapState(fn)) { + if (!streaming && DoFnSignatures.usesMultimapState(fn)) { throw new UnsupportedOperationException( String.format( - "%s does not currently support %s", + "%s does not currently support %s in batch mode", DataflowRunner.class.getSimpleName(), MultimapState.class.getSimpleName())); } if (streaming && DoFnSignatures.requiresTimeSortedInput(fn)) { @@ -2458,6 +2458,13 @@ static void verifyDoFnSupported( boolean streamingEngine = useStreamingEngine(options); boolean isUnifiedWorker = useUnifiedWorker(options); + + if (DoFnSignatures.usesMultimapState(fn) && isUnifiedWorker) { + throw new UnsupportedOperationException( + String.format( + "%s does not currently support %s running using streaming on unified worker", + DataflowRunner.class.getSimpleName(), MultimapState.class.getSimpleName())); + } if (DoFnSignatures.usesSetState(fn)) { if (streaming && (isUnifiedWorker || streamingEngine)) { throw new UnsupportedOperationException( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java index 46da47027fecc..ba875bd5ea0ab 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternals.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Objects; import java.util.Random; import java.util.Set; import java.util.SortedSet; @@ -40,7 +41,9 @@ import java.util.concurrent.Future; import java.util.function.BiConsumer; import java.util.function.Function; +import java.util.stream.Collectors; import javax.annotation.concurrent.NotThreadSafe; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Triple; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTable; @@ -198,8 +201,15 @@ public MultimapState bindMultimap( StateTag> spec, Coder keyCoder, Coder valueCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", MultimapState.class.getSimpleName())); + WindmillMultimap result = + (WindmillMultimap) cache.get(namespace, spec); + if (result == null) { + result = + new WindmillMultimap<>( + namespace, spec, stateFamily, keyCoder, valueCoder, isNewKey); + } + result.initializeForWorkItem(reader, scopedReadStateSupplier); + return result; } @Override @@ -1588,7 +1598,648 @@ private Future>> getFuture() { return reader.valuePrefixFuture(stateKeyPrefix, stateFamily, valueCoder); } } - }; + } + + private static class WindmillMultimap extends SimpleWindmillState + implements MultimapState { + + private final StateNamespace namespace; + private final StateTag> address; + private final ByteString stateKey; + private final String stateFamily; + private final Coder keyCoder; + private final Coder valueCoder; + + private class KeyState { + final K originalKey; + KeyExistence existence; + // valuesCached can be true if only existence == KNOWN_EXIST and all values of this key is + // cached (both values and localAdditions). + boolean valuesCached; + // Represents the values in windmill. When new values are added during user processing, they + // are added to localAdditions but not values. Those new values will be added to values only + // after they are persisted into windmill and removed from localAdditions + ConcatIterables values; + int valuesSize; + + // When new values are added during user processing, they are added to localAdditions, so that + // we can later try to persist them in windmill. When a key is removed during user processing, + // we mark removedLocally to be true so that we can later try to delete it from windmill. If + // localAdditions is not empty and removedLocally is true, values in localAdditions will be + // added to windmill after old values in windmill are removed. + List localAdditions; + boolean removedLocally; + + KeyState(K originalKey) { + this.originalKey = originalKey; + existence = KeyExistence.UNKNOWN_EXISTENCE; + valuesCached = complete; + values = new ConcatIterables<>(); + valuesSize = 0; + localAdditions = Lists.newArrayList(); + removedLocally = false; + } + } + + private enum KeyExistence { + // this key is known to exist + KNOWN_EXIST, + // this key is known to be nonexistent + KNOWN_NONEXISTENT, + // we don't know if this key is in this multimap, this is just to provide a mapping between + // the original key and the structural key. + UNKNOWN_EXISTENCE + } + + private boolean cleared = false; + // We use the structural value of the keys as the key in keyStateMap, so that different java + // Objects with the same content will be treated as the same Multimap key. + private Map keyStateMap = Maps.newHashMap(); + // If true, all keys are cached in keyStateMap with existence == KNOWN_EXIST. + private boolean allKeysKnown = false; + + private boolean complete = false; + // hasLocalAdditions and hasLocalRemovals track whether there are local changes that needs to be + // propagated to windmill. + private boolean hasLocalAdditions = false; + private boolean hasLocalRemovals = false; + + private WindmillMultimap( + StateNamespace namespace, + StateTag> address, + String stateFamily, + Coder keyCoder, + Coder valueCoder, + boolean isNewShardingKey) { + this.namespace = namespace; + this.address = address; + this.stateKey = encodeKey(namespace, address); + this.stateFamily = stateFamily; + this.keyCoder = keyCoder; + this.valueCoder = valueCoder; + this.complete = isNewShardingKey; + this.allKeysKnown = isNewShardingKey; + } + + @Override + public void put(K key, V value) { + final Object structuralKey = keyCoder.structuralValue(key); + hasLocalAdditions = true; + keyStateMap.compute( + structuralKey, + (k, v) -> { + if (v == null) v = new KeyState(key); + v.existence = KeyExistence.KNOWN_EXIST; + v.localAdditions.add(value); + return v; + }); + } + + // Initiates a backend state read to fetch all entries if necessary. + private Future>>> getFuture(boolean omitValues) { + if (complete) { + return Futures.immediateFuture(Collections.emptyList()); + } else { + return reader.multimapFetchAllFuture(omitValues, stateKey, stateFamily, valueCoder); + } + } + + // Initiates a backend state read to fetch a single entry if necessary. + private Future> getFutureForKey(K key) { + try { + ByteStringOutputStream keyStream = new ByteStringOutputStream(); + keyCoder.encode(key, keyStream, Context.OUTER); + return reader.multimapFetchSingleEntryFuture( + keyStream.toByteString(), stateKey, stateFamily, valueCoder); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public ReadableState> get(K key) { + return new ReadableState>() { + final Object structuralKey = keyCoder.structuralValue(key); + + @Override + public Iterable read() { + KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key)); + if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) { + return Collections.emptyList(); + } + if (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) { + keyStateMap.remove(structuralKey); + return Collections.emptyList(); + } + Iterable localNewValues = + Iterables.limit(keyState.localAdditions, keyState.localAdditions.size()); + if (keyState.removedLocally) { + // this key has been removed locally but the removal hasn't been sent to windmill, + // thus values in windmill(if any) are obsolete, and we only care about local values. + return Iterables.unmodifiableIterable(localNewValues); + } + if (keyState.valuesCached || complete) { + return Iterables.unmodifiableIterable( + Iterables.concat( + Iterables.limit(keyState.values, keyState.valuesSize), localNewValues)); + } + Future> persistedData = getFutureForKey(key); + try (Closeable scope = scopedReadState()) { + final Iterable persistedValues = persistedData.get(); + // Iterables.isEmpty() is O(1). + if (Iterables.isEmpty(persistedValues)) { + if (keyState.localAdditions.isEmpty()) { + // empty in both cache and windmill, mark key as KNOWN_NONEXISTENT. + keyState.existence = KeyExistence.KNOWN_NONEXISTENT; + return Collections.emptyList(); + } + return Iterables.unmodifiableIterable(localNewValues); + } + keyState.existence = KeyExistence.KNOWN_EXIST; + if (persistedValues instanceof Weighted) { + keyState.valuesCached = true; + ConcatIterables it = new ConcatIterables<>(); + it.extendWith(persistedValues); + keyState.values = it; + keyState.valuesSize = Iterables.size(persistedValues); + } + return Iterables.unmodifiableIterable( + Iterables.concat(persistedValues, localNewValues)); + } catch (InterruptedException | ExecutionException | IOException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("Unable to read Multimap state", e); + } + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public ReadableState> readLater() { + WindmillMultimap.this.getFutureForKey(key); + return this; + } + }; + } + + @Override + protected WorkItemCommitRequest persistDirectly(WindmillStateCache.ForKeyAndFamily cache) + throws IOException { + if (!cleared && !hasLocalAdditions && !hasLocalRemovals) { + cache.put(namespace, address, this, 1); + return WorkItemCommitRequest.newBuilder().buildPartial(); + } + WorkItemCommitRequest.Builder commitBuilder = WorkItemCommitRequest.newBuilder(); + Windmill.TagMultimapUpdateRequest.Builder builder = commitBuilder.addMultimapUpdatesBuilder(); + builder.setTag(stateKey).setStateFamily(stateFamily); + + if (cleared) { + builder.setDeleteAll(true); + } + if (hasLocalRemovals || hasLocalAdditions) { + ByteStringOutputStream keyStream = new ByteStringOutputStream(); + ByteStringOutputStream valueStream = new ByteStringOutputStream(); + Iterator> iterator = keyStateMap.entrySet().iterator(); + while (iterator.hasNext()) { + KeyState keyState = iterator.next().getValue(); + if (!keyState.removedLocally && keyState.localAdditions.isEmpty()) { + if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT) iterator.remove(); + continue; + } + keyCoder.encode(keyState.originalKey, keyStream, Context.OUTER); + ByteString encodedKey = keyStream.toByteStringAndReset(); + Windmill.TagMultimapEntry.Builder entryBuilder = builder.addUpdatesBuilder(); + entryBuilder.setEntryName(encodedKey); + entryBuilder.setDeleteAll(keyState.removedLocally); + keyState.removedLocally = false; + for (V value : keyState.localAdditions) { + valueCoder.encode(value, valueStream, Context.OUTER); + ByteString encodedValue = valueStream.toByteStringAndReset(); + entryBuilder.addValues(encodedValue); + } + // Move newly added values from localAdditions to keyState.values as those new values now + // are also persisted in Windmill. If a key now has no more values and is not KNOWN_EXIST, + // remove it from cache. + if (keyState.valuesCached) { + keyState.values.extendWith(keyState.localAdditions); + keyState.valuesSize += keyState.localAdditions.size(); + } + // Create a new localAdditions so that the cached values are unaffected. + keyState.localAdditions = Lists.newArrayList(); + if (!keyState.valuesCached && keyState.existence != KeyExistence.KNOWN_EXIST) { + iterator.remove(); + } + } + } + + hasLocalAdditions = false; + hasLocalRemovals = false; + cleared = false; + + cache.put(namespace, address, this, 1); + return commitBuilder.buildPartial(); + } + + @Override + public void remove(K key) { + final Object structuralKey = keyCoder.structuralValue(key); + KeyState keyState = keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(key)); + if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT + || (allKeysKnown && keyState.existence == KeyExistence.UNKNOWN_EXISTENCE)) { + return; + } + if (!keyState.valuesCached || (keyState.valuesCached && keyState.valuesSize > 0)) { + // there may be data in windmill that need to be removed. + hasLocalRemovals = true; + keyState.removedLocally = true; + keyState.values = new ConcatIterables<>(); + keyState.valuesSize = 0; + keyState.valuesCached = false; + keyState.existence = KeyExistence.KNOWN_NONEXISTENT; + } else { + // no data in windmill, deleting from local cache is sufficient. + keyStateMap.remove(structuralKey); + } + if (!keyState.localAdditions.isEmpty()) { + keyState.localAdditions = Lists.newArrayList(); + } + } + + @Override + public void clear() { + keyStateMap = Maps.newHashMap(); + cleared = true; + complete = true; + allKeysKnown = true; + } + + @Override + public ReadableState> keys() { + return new ReadableState>() { + + private Map cachedExistKeys() { + return keyStateMap.entrySet().stream() + .filter(entry -> entry.getValue().existence == KeyExistence.KNOWN_EXIST) + .collect(Collectors.toMap(Entry::getKey, e -> e.getValue().originalKey)); + } + + @Override + public Iterable read() { + if (allKeysKnown) { + return Iterables.unmodifiableIterable(cachedExistKeys().values()); + } + Future>>> persistedData = getFuture(true); + try (Closeable scope = scopedReadState()) { + Iterable>> entries = persistedData.get(); + if (entries instanceof Weighted) { + // This is a known amount of data, cache them all. + entries.forEach( + entry -> { + try { + K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER); + KeyState keyState = + keyStateMap.computeIfAbsent( + keyCoder.structuralValue(originalKey), + stk -> new KeyState(originalKey)); + if (keyState.existence == KeyExistence.UNKNOWN_EXISTENCE) { + keyState.existence = KeyExistence.KNOWN_EXIST; + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + allKeysKnown = true; + keyStateMap + .values() + .removeIf(keyState -> keyState.existence != KeyExistence.KNOWN_EXIST); + return Iterables.unmodifiableIterable(cachedExistKeys().values()); + } else { + Map cachedExistKeys = Maps.newHashMap(); + Set cachedNonExistKeys = Sets.newHashSet(); + keyStateMap.forEach( + (structuralKey, keyState) -> { + switch (keyState.existence) { + case KNOWN_EXIST: + cachedExistKeys.put(structuralKey, keyState.originalKey); + break; + case KNOWN_NONEXISTENT: + cachedNonExistKeys.add(structuralKey); + break; + default: + break; + } + }); + // keysOnlyInWindmill is lazily loaded. + Iterable keysOnlyInWindmill = + Iterables.filter( + Iterables.transform( + entries, + entry -> { + try { + K originalKey = + keyCoder.decode(entry.getKey().newInput(), Context.OUTER); + Object structuralKey = keyCoder.structuralValue(originalKey); + if (cachedExistKeys.containsKey(structuralKey) + || cachedNonExistKeys.contains(structuralKey)) return null; + return originalKey; + } catch (IOException e) { + throw new RuntimeException(e); + } + }), + Objects::nonNull); + return Iterables.unmodifiableIterable( + Iterables.concat(cachedExistKeys.values(), keysOnlyInWindmill)); + } + } catch (InterruptedException | ExecutionException | IOException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("Unable to read state", e); + } + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public ReadableState> readLater() { + WindmillMultimap.this.getFuture(true); + return this; + } + }; + } + + @Override + public ReadableState>> entries() { + return new ReadableState>>() { + @Override + public Iterable> read() { + if (complete) { + return Iterables.unmodifiableIterable( + unnestCachedEntries(mergedCachedEntries(null).entrySet())); + } + Future>>> persistedData = getFuture(false); + try (Closeable scope = scopedReadState()) { + Iterable>> entries = persistedData.get(); + if (Iterables.isEmpty(entries)) { + complete = true; + allKeysKnown = true; + return Iterables.unmodifiableIterable( + unnestCachedEntries(mergedCachedEntries(null).entrySet())); + } + if (!(entries instanceof Weighted)) { + return nonWeightedEntries(entries); + } + // This is a known amount of data, cache them all. + entries.forEach( + entry -> { + try { + final K originalKey = keyCoder.decode(entry.getKey().newInput(), Context.OUTER); + final Object structuralKey = keyCoder.structuralValue(originalKey); + KeyState keyState = + keyStateMap.computeIfAbsent(structuralKey, k -> new KeyState(originalKey)); + // Ignore any key from windmill that has been marked pending deletion or is + // fully cached. + if (keyState.existence == KeyExistence.KNOWN_NONEXISTENT + || (keyState.existence == KeyExistence.KNOWN_EXIST + && keyState.valuesCached)) return; + // Or else cache contents from windmill. + keyState.existence = KeyExistence.KNOWN_EXIST; + keyState.values.extendWith(entry.getValue()); + keyState.valuesSize += Iterables.size(entry.getValue()); + // We can't set keyState.valuesCached to true here, because there may be more + // paginated values that should not be filtered out in above if statement. + // keyState.valuesCached will be set to true in later call of + // mergedCachedEntries. + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + allKeysKnown = true; + complete = true; + return Iterables.unmodifiableIterable( + unnestCachedEntries(mergedCachedEntries(null).entrySet())); + } catch (InterruptedException | ExecutionException | IOException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + throw new RuntimeException("Unable to read state", e); + } + } + + @Override + @SuppressWarnings("FutureReturnValueIgnored") + public ReadableState>> readLater() { + WindmillMultimap.this.getFuture(false); + return this; + } + + // Collect all cached entries into a map and all KNOWN_NONEXISTENT keys to + // knownNonexistentKeys(if not null). Note that this method is not side-effect-free: it + // unloads any key that is not KNOWN_EXIST and not pending deletion from cache; also if + // complete it marks the valuesCached of any key that is KNOWN_EXIST to true, entries() + // depends on this behavior when the fetched result is weighted to iterate the whole + // keyStateMap one less time. + private Map>> mergedCachedEntries( + Map knownNonexistentKeys) { + Map>> cachedEntries = Maps.newHashMap(); + keyStateMap + .entrySet() + .removeIf( + (entry -> { + Object structuralKey = entry.getKey(); + KeyState keyState = entry.getValue(); + if (complete && keyState.existence == KeyExistence.KNOWN_EXIST) { + keyState.valuesCached = true; + } + ConcatIterables it = null; + if (!keyState.localAdditions.isEmpty()) { + it = new ConcatIterables<>(); + it.extendWith( + Iterables.limit(keyState.localAdditions, keyState.localAdditions.size())); + } + if (keyState.valuesCached) { + if (it == null) it = new ConcatIterables<>(); + it.extendWith(Iterables.limit(keyState.values, keyState.valuesSize)); + } + if (it != null) + cachedEntries.put( + structuralKey, + Triple.of(keyState.originalKey, keyState.valuesCached, it)); + if (knownNonexistentKeys != null + && keyState.existence == KeyExistence.KNOWN_NONEXISTENT) + knownNonexistentKeys.put(structuralKey, keyState.originalKey); + return (keyState.existence == KeyExistence.KNOWN_NONEXISTENT + && !keyState.removedLocally) + || keyState.existence == KeyExistence.UNKNOWN_EXISTENCE; + })); + return cachedEntries; + } + + private Iterable> unnestCachedEntries( + Iterable>>> cachedEntries) { + return Iterables.unmodifiableIterable( + () -> + Iterators.concat( + Iterables.transform( + cachedEntries, + entry -> + Iterables.transform( + entry.getValue().getRight(), + v -> + new AbstractMap.SimpleEntry<>( + entry.getValue().getLeft(), v)) + .iterator()) + .iterator())); + } + + private Iterable> nonWeightedEntries( + Iterable>> lazyWindmillEntries) { + class ResultIterable implements Iterable> { + private final Iterable>> lazyWindmillEntries; + private final Map>> cachedEntries; + private final Map knownNonexistentKeys; + + ResultIterable( + Map>> cachedEntries, + Iterable>> lazyWindmillEntries, + Map knownNonexistentKeys) { + this.cachedEntries = cachedEntries; + this.lazyWindmillEntries = lazyWindmillEntries; + this.knownNonexistentKeys = knownNonexistentKeys; + } + + @Override + public Iterator> iterator() { + // Each time when the Iterable returned by entries() is iterated, a new Iterator is + // created. Every iterator must keep its own copy of seenCachedKeys so that if a key + // is paginated into multiple iterables from windmill, the cached values of this key + // will only be returned once. + Set seenCachedKeys = Sets.newHashSet(); + // notFullyCachedEntries returns all entries from windmill that are not fully cached + // and combines them with localAdditions. If a key is fully cached, contents of this + // key from windmill are ignored. + Iterable>> notFullyCachedEntries = + Iterables.filter( + Iterables.transform( + lazyWindmillEntries, + entry -> { + try { + final K key = + keyCoder.decode(entry.getKey().newInput(), Context.OUTER); + final Object structuralKey = keyCoder.structuralValue(key); + // key is deleted in cache thus fully cached. + if (knownNonexistentKeys.containsKey(structuralKey)) return null; + Triple> triple = + cachedEntries.get(structuralKey); + // no record of key in cache, return content in windmill. + if (triple == null) { + return Triple.of(structuralKey, key, entry.getValue()); + } + // key is fully cached in cache. + if (triple.getMiddle()) return null; + + // key is not fully cached, combine the content with local additions + if (seenCachedKeys.contains(structuralKey)) { + return Triple.of(structuralKey, key, entry.getValue()); + } else { + seenCachedKeys.add(structuralKey); + ConcatIterables it = new ConcatIterables<>(); + it.extendWith(triple.getRight()); + it.extendWith(entry.getValue()); + return Triple.of(structuralKey, key, it); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + }), + Objects::nonNull); + Iterator> unnestWindmill = + Iterators.concat( + Iterables.transform( + notFullyCachedEntries, + entry -> + Iterables.transform( + entry.getRight(), + v -> new AbstractMap.SimpleEntry<>(entry.getMiddle(), v)) + .iterator()) + .iterator()); + Iterator> fullyCached = + unnestCachedEntries( + Iterables.filter( + cachedEntries.entrySet(), + entry -> !seenCachedKeys.contains(entry.getKey()))) + .iterator(); + return Iterators.concat(unnestWindmill, fullyCached); + } + } + + Map knownNonexistentKeys = Maps.newHashMap(); + Map>> cachedEntries = + mergedCachedEntries(knownNonexistentKeys); + return Iterables.unmodifiableIterable( + new ResultIterable(cachedEntries, lazyWindmillEntries, knownNonexistentKeys)); + } + }; + } + + @Override + public ReadableState containsKey(K key) { + return new ReadableState() { + ReadableState> values = null; + final Object structuralKey = keyCoder.structuralValue(key); + + @Override + public Boolean read() { + KeyState keyState = keyStateMap.getOrDefault(structuralKey, null); + if (keyState != null && keyState.existence != KeyExistence.UNKNOWN_EXISTENCE) { + return keyState.existence == KeyExistence.KNOWN_EXIST; + } + if (values == null) { + values = WindmillMultimap.this.get(key); + } + return !Iterables.isEmpty(values.read()); + } + + @Override + public ReadableState readLater() { + if (values == null) { + values = WindmillMultimap.this.get(key); + } + values.readLater(); + return this; + } + }; + } + + // Currently, isEmpty is implemented by reading all keys and could potentially be optimized. + // But note that if isEmpty is often followed by iterating over keys then maybe not too bad; if + // isEmpty is followed by iterating over both keys and values then it won't help much. + @Override + public ReadableState isEmpty() { + return new ReadableState() { + ReadableState> keys = null; + + @Override + public Boolean read() { + for (KeyState keyState : keyStateMap.values()) { + if (keyState.existence == KeyExistence.KNOWN_EXIST) return false; + } + if (keys == null) { + keys = WindmillMultimap.this.keys(); + } + return Iterables.isEmpty(keys.read()); + } + + @Override + public ReadableState readLater() { + if (keys == null) { + keys = WindmillMultimap.this.keys(); + } + keys.readLater(); + return this; + } + }; + } + } private static class WindmillBag extends SimpleWindmillState implements BagState { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java index 45708d1337010..2e1902553064f 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillStateReader.java @@ -38,6 +38,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.function.Supplier; +import java.util.stream.Collectors; import javax.annotation.Nonnull; import javax.annotation.Nullable; import org.apache.beam.runners.dataflow.worker.WindmillStateReader.StateTag.Kind; @@ -88,6 +89,14 @@ class WindmillStateReader { public static final long CONTINUATION_MAX_BAG_BYTES = 32L << 20; // 32MB + /** + * Ideal maximum bytes in a TagMultimapFetchResponse response. However, Windmill will always + * return at least one value if possible irrespective of this limit. + */ + public static final long INITIAL_MAX_MULTIMAP_BYTES = 8L << 20; // 8MB + + public static final long CONTINUATION_MAX_MULTIMAP_BYTES = 32L << 20; // 32MB + /** * Ideal maximum bytes in a TagSortedList response. However, Windmill will always return at least * one value if possible irrespective of this limit. @@ -119,7 +128,9 @@ enum Kind { BAG, WATERMARK, ORDERED_LIST, - VALUE_PREFIX + VALUE_PREFIX, + MULTIMAP_SINGLE_ENTRY, + MULTIMAP_ALL } abstract Kind getKind(); @@ -129,9 +140,10 @@ enum Kind { abstract String getStateFamily(); /** - * For {@link Kind#BAG, Kind#ORDERED_LIST, Kind#VALUE_PREFIX} kinds: A previous - * 'continuation_position' returned by Windmill to signal the resulting bag was incomplete. - * Sending that position will request the next page of values. Null for first request. + * For {@link Kind#BAG, Kind#ORDERED_LIST, Kind#VALUE_PREFIX, KIND#MULTIMAP_SINGLE_ENTRY, + * KIND#MULTIMAP_ALL} kinds: A previous 'continuation_position' returned by Windmill to signal + * the resulting state was incomplete. Sending that position will request the next page of + * values. Null for first request. * *

Null for other kinds. */ @@ -142,6 +154,17 @@ enum Kind { @Nullable abstract Range getSortedListRange(); + /** For {@link Kind#MULTIMAP_SINGLE_ENTRY} kinds: the key in the multimap to fetch or delete. */ + @Nullable + abstract ByteString getMultimapKey(); + + /** + * For {@link Kind#MULTIMAP_ALL} kinds: will only return the keys of the multimap and not the + * values if true. + */ + @Nullable + abstract Boolean getOmitValues(); + static StateTag of( Kind kind, ByteString tag, String stateFamily, @Nullable RequestPositionT requestPosition) { return new AutoValue_WindmillStateReader_StateTag.Builder() @@ -172,6 +195,10 @@ abstract Builder setRequestPosition( abstract Builder setSortedListRange(@Nullable Range sortedListRange); + abstract Builder setMultimapKey(@Nullable ByteString encodedMultimapKey); + + abstract Builder setOmitValues(Boolean omitValues); + abstract StateTag build(); } } @@ -278,7 +305,7 @@ private Future stateFuture(StateTag stateTag, @Nullable Co CoderAndFuture existingCoderAndFutureWildcard = waiting.putIfAbsent(stateTag, coderAndFuture); if (existingCoderAndFutureWildcard == null) { - // Schedule a new request. It's response is guaranteed to find the future and coder. + // Schedule a new request. Its response is guaranteed to find the future and coder. pendingLookups.add(stateTag); } else { // Piggy-back on the pending or already answered request. @@ -330,8 +357,27 @@ public Future>> orderedListFuture( .toBuilder() .setSortedListRange(Preconditions.checkNotNull(range)) .build(); - return Preconditions.checkNotNull( - valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder))); + return valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder)); + } + + public Future>>> multimapFetchAllFuture( + boolean omitValues, ByteString encodedTag, String stateFamily, Coder elemCoder) { + StateTag stateTag = + StateTag.of(Kind.MULTIMAP_ALL, encodedTag, stateFamily) + .toBuilder() + .setOmitValues(omitValues) + .build(); + return valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder)); + } + + public Future> multimapFetchSingleEntryFuture( + ByteString encodedKey, ByteString encodedTag, String stateFamily, Coder elemCoder) { + StateTag stateTag = + StateTag.of(Kind.MULTIMAP_SINGLE_ENTRY, encodedTag, stateFamily) + .toBuilder() + .setMultimapKey(encodedKey) + .build(); + return valuesToPagingIterableFuture(stateTag, elemCoder, this.stateFuture(stateTag, elemCoder)); } public Future>> valuePrefixFuture( @@ -339,8 +385,8 @@ public Future>> valuePrefixFuture( // First request has no continuation position. StateTag stateTag = StateTag.of(Kind.VALUE_PREFIX, prefix, stateFamily).toBuilder().build(); - return Preconditions.checkNotNull( - valuesToPagingIterableFuture(stateTag, valueCoder, this.stateFuture(stateTag, valueCoder))); + return valuesToPagingIterableFuture( + stateTag, valueCoder, this.stateFuture(stateTag, valueCoder)); } /** @@ -437,18 +483,24 @@ public Iterable apply( return valuesAndContPosition.values; } else { // Return an iterable which knows how to come back for more. - StateTag contStateTag = + StateTag.Builder continuationTBuilder = StateTag.of( - stateTag.getKind(), - stateTag.getTag(), - stateTag.getStateFamily(), - valuesAndContPosition.continuationPosition); + stateTag.getKind(), + stateTag.getTag(), + stateTag.getStateFamily(), + valuesAndContPosition.continuationPosition) + .toBuilder(); if (stateTag.getSortedListRange() != null) { - contStateTag = - contStateTag.toBuilder().setSortedListRange(stateTag.getSortedListRange()).build(); + continuationTBuilder.setSortedListRange(stateTag.getSortedListRange()).build(); + } + if (stateTag.getMultimapKey() != null) { + continuationTBuilder.setMultimapKey(stateTag.getMultimapKey()).build(); + } + if (stateTag.getOmitValues() != null) { + continuationTBuilder.setOmitValues(stateTag.getOmitValues()).build(); } return new PagingIterable( - reader, valuesAndContPosition.values, contStateTag, coder); + reader, valuesAndContPosition.values, continuationTBuilder.build(), coder); } } } @@ -466,20 +518,69 @@ private Future> valuesToPagingIterabl return Futures.lazyTransform(future, toIterable); } + private void delayUnbatchableMultimapFetches( + List> multimapTags, HashSet> toFetch) { + // Each KeyedGetDataRequest can have at most 1 TagMultimapFetchRequest per + // pair, thus we need to delay unbatchable multimap requests of the same stateFamily and tag + // into later batches. There's no priority between get()/entries()/keys(), they will be fetched + // based on the order they occur in pendingLookups, so that all requests can eventually be + // fetched and none starves. + + Map>>> groupedTags = + multimapTags.stream() + .collect( + Collectors.groupingBy( + StateTag::getStateFamily, Collectors.groupingBy(StateTag::getTag))); + + for (Map>> familyTags : groupedTags.values()) { + for (List> tags : familyTags.values()) { + StateTag first = tags.remove(0); + toFetch.add(first); + if (tags.isEmpty()) continue; + + if (first.getKind() == Kind.MULTIMAP_ALL) { + // first is keys()/entries(), no more TagMultimapFetchRequests allowed in current batch. + pendingLookups.addAll(tags); + continue; + } + // first is get(key), no keys()/entries() allowed in current batch; each different key can + // have at most one fetch request in this batch. + Set addedKeys = Sets.newHashSet(); + addedKeys.add(first.getMultimapKey()); + for (StateTag other : tags) { + if (other.getKind() == Kind.MULTIMAP_ALL || addedKeys.contains(other.getMultimapKey())) { + pendingLookups.add(other); + } else { + toFetch.add(other); + addedKeys.add(other.getMultimapKey()); + } + } + } + } + } + public void startBatchAndBlock() { // First, drain work out of the pending lookups into a set. These will be the items we fetch. HashSet> toFetch = Sets.newHashSet(); try { + List> multimapTags = Lists.newArrayList(); while (!pendingLookups.isEmpty()) { StateTag stateTag = pendingLookups.poll(); if (stateTag == null) { break; } - + if (stateTag.getKind() == Kind.MULTIMAP_ALL + || stateTag.getKind() == Kind.MULTIMAP_SINGLE_ENTRY) { + multimapTags.add(stateTag); + continue; + } if (!toFetch.add(stateTag)) { throw new IllegalStateException("Duplicate tags being fetched."); } } + if (!multimapTags.isEmpty()) { + delayUnbatchableMultimapFetches(multimapTags, toFetch); + } // If we failed to drain anything, some other thread pulled it off the queue. We have no work // to do. @@ -521,6 +622,7 @@ private Windmill.KeyedGetDataRequest createRequest(Iterable> toFetch boolean continuation = false; List> orderedListsToFetch = Lists.newArrayList(); + List> multimapSingleEntryToFetch = Lists.newArrayList(); for (StateTag stateTag : toFetch) { switch (stateTag.getKind()) { case BAG: @@ -569,10 +671,65 @@ private Windmill.KeyedGetDataRequest createRequest(Iterable> toFetch } break; + case MULTIMAP_SINGLE_ENTRY: + multimapSingleEntryToFetch.add(stateTag); + break; + + case MULTIMAP_ALL: + Windmill.TagMultimapFetchRequest.Builder multimapFetchBuilder = + keyedDataBuilder + .addMultimapsToFetchBuilder() + .setTag(stateTag.getTag()) + .setStateFamily(stateTag.getStateFamily()) + .setFetchEntryNamesOnly(stateTag.getOmitValues()); + if (stateTag.getRequestPosition() == null) { + multimapFetchBuilder.setFetchMaxBytes(INITIAL_MAX_MULTIMAP_BYTES); + } else { + multimapFetchBuilder.setFetchMaxBytes(CONTINUATION_MAX_MULTIMAP_BYTES); + multimapFetchBuilder.setRequestPosition((ByteString) stateTag.getRequestPosition()); + continuation = true; + } + break; + default: throw new RuntimeException("Unknown kind of tag requested: " + stateTag.getKind()); } } + if (!multimapSingleEntryToFetch.isEmpty()) { + Map>>> multimapTags = + multimapSingleEntryToFetch.stream() + .collect( + Collectors.groupingBy( + StateTag::getStateFamily, Collectors.groupingBy(StateTag::getTag))); + for (Map.Entry>>> entry : multimapTags.entrySet()) { + String stateFamily = entry.getKey(); + Map>> familyTags = multimapTags.get(stateFamily); + for (Map.Entry>> tagEntry : familyTags.entrySet()) { + ByteString tag = tagEntry.getKey(); + List> stateTags = tagEntry.getValue(); + Windmill.TagMultimapFetchRequest.Builder multimapFetchBuilder = + keyedDataBuilder + .addMultimapsToFetchBuilder() + .setTag(tag) + .setStateFamily(stateFamily) + .setFetchEntryNamesOnly(false); + for (StateTag stateTag : stateTags) { + Windmill.TagMultimapEntry.Builder entryBuilder = + multimapFetchBuilder + .addEntriesToFetchBuilder() + .setEntryName(stateTag.getMultimapKey()); + if (stateTag.getRequestPosition() == null) { + entryBuilder.setFetchMaxBytes(INITIAL_MAX_BAG_BYTES); + } else { + entryBuilder.setFetchMaxBytes(CONTINUATION_MAX_BAG_BYTES); + entryBuilder.setRequestPosition((Long) stateTag.getRequestPosition()); + continuation = true; + } + } + } + } + } + orderedListsToFetch.sort( Comparator.>comparingLong(s -> s.getSortedListRange().lowerEndpoint()) .thenComparingLong(s -> s.getSortedListRange().upperEndpoint())); @@ -679,6 +836,49 @@ private void consumeResponse(Windmill.KeyedGetDataResponse response, Set builder = + StateTag.of( + Kind.MULTIMAP_ALL, + tagMultimap.getTag(), + tagMultimap.getStateFamily(), + tagMultimap.hasRequestPosition() ? tagMultimap.getRequestPosition() : null) + .toBuilder(); + StateTag tag = builder.setOmitValues(true).build(); + if (toFetch.contains(tag)) { + // this is keys() + toFetch.remove(tag); + consumeMultimapAll(tagMultimap, tag); + continue; + } + tag = builder.setOmitValues(false).build(); + if (toFetch.contains(tag)) { + // this is entries() + toFetch.remove(tag); + consumeMultimapAll(tagMultimap, tag); + continue; + } + // this is get() + StateTag.Builder entryTagBuilder = + StateTag.of( + Kind.MULTIMAP_SINGLE_ENTRY, tagMultimap.getTag(), tagMultimap.getStateFamily()) + .toBuilder(); + StateTag entryTag = null; + for (Windmill.TagMultimapEntry entry : tagMultimap.getEntriesList()) { + entryTag = + entryTagBuilder + .setMultimapKey(entry.getEntryName()) + .setRequestPosition(entry.hasRequestPosition() ? entry.getRequestPosition() : null) + .build(); + if (!toFetch.remove(entryTag)) { + throw new IllegalStateException( + "Received response for unrequested tag " + entryTag + ". Pending tags: " + toFetch); + } + consumeMultimapSingleEntry(entry, entryTag); + } + } + if (!toFetch.isEmpty()) { throw new IllegalStateException( "Didn't receive responses for all pending fetches. Missing: " + toFetch); @@ -778,6 +978,39 @@ private List> tagPrefixPageTagValues( return entryList; } + private WeightedList multimapEntryPageValues( + Windmill.TagMultimapEntry entry, Coder valueCoder) { + if (entry.getValuesCount() == 0) { + return new WeightedList<>(Collections.emptyList()); + } + WeightedList valuesList = new WeightedList<>(new ArrayList<>(entry.getValuesCount())); + for (ByteString value : entry.getValuesList()) { + try { + V decoded = valueCoder.decode(value.newInput(), Context.OUTER); + valuesList.addWeighted(decoded, value.size()); + } catch (IOException e) { + throw new IllegalStateException("Unable to decode multimap value " + e); + } + } + return valuesList; + } + + private List>> multimapPageValues( + Windmill.TagMultimapFetchResponse response, Coder valueCoder) { + if (response.getEntriesCount() == 0) { + return new WeightedList<>(Collections.emptyList()); + } + WeightedList>> entriesList = + new WeightedList<>(new ArrayList<>(response.getEntriesCount())); + for (Windmill.TagMultimapEntry entry : response.getEntriesList()) { + WeightedList values = multimapEntryPageValues(entry, valueCoder); + entriesList.addWeighted( + new AbstractMap.SimpleEntry<>(entry.getEntryName(), values), + entry.getEntryName().size() + values.getWeight()); + } + return entriesList; + } + private void consumeBag(TagBag bag, StateTag stateTag) { boolean shouldRemove; if (stateTag.getRequestPosition() == null) { @@ -850,7 +1083,8 @@ private void consumeTagPrefixResponse( Windmill.TagValuePrefixResponse tagValuePrefixResponse, StateTag stateTag) { boolean shouldRemove; if (stateTag.getRequestPosition() == null) { - // This is the response for the first page.// Leave the future in the cache so subsequent + // This is the response for the first page. + // Leave the future in the cache so subsequent // requests for the first page // can return immediately. shouldRemove = false; @@ -912,6 +1146,47 @@ private void consumeSortedList( future.setException(new RuntimeException("Error parsing ordered list", e)); } } + + private void consumeMultimapAll( + Windmill.TagMultimapFetchResponse response, StateTag stateTag) { + // Leave the future in the cache for the first page; do not cache for subsequent pages. + boolean shouldRemove = stateTag.getRequestPosition() != null; + CoderAndFuture>, ByteString>> + coderAndFuture = getWaiting(stateTag, shouldRemove); + SettableFuture>, ByteString>> future = + coderAndFuture.getNonDoneFuture(stateTag); + Coder valueCoder = coderAndFuture.getAndClearCoder(); + try { + List>> entries = + this.multimapPageValues(response, valueCoder); + future.set( + new ValuesAndContPosition<>( + entries, + response.hasContinuationPosition() ? response.getContinuationPosition() : null)); + } catch (Exception e) { + future.setException(new RuntimeException("Error parsing multimap fetch all result", e)); + } + } + + private void consumeMultimapSingleEntry( + Windmill.TagMultimapEntry entry, StateTag stateTag) { + // Leave the future in the cache for the first page; do not cache for subsequent pages. + boolean shouldRemove = stateTag.getRequestPosition() != null; + CoderAndFuture> coderAndFuture = + getWaiting(stateTag, shouldRemove); + SettableFuture> future = + coderAndFuture.getNonDoneFuture(stateTag); + Coder valueCoder = coderAndFuture.getAndClearCoder(); + try { + List values = this.multimapEntryPageValues(entry, valueCoder); + Long continuationToken = + entry.hasContinuationPosition() ? entry.getContinuationPosition() : null; + future.set(new ValuesAndContPosition<>(values, continuationToken)); + } catch (Exception e) { + future.setException(new RuntimeException("Error parsing multimap single entry result", e)); + } + } + /** * An iterable over elements backed by paginated GetData requests to Windmill. The iterable may be * iterated over an arbitrary number of times and multiple iterators may be active simultaneously. @@ -994,6 +1269,11 @@ protected ResultT computeNext() { .toBuilder(); if (secondPagePos.getSortedListRange() != null) { nextPageBuilder.setSortedListRange(secondPagePos.getSortedListRange()); + if (secondPagePos.getOmitValues() != null) { + nextPageBuilder.setOmitValues(secondPagePos.getOmitValues()); + } + if (secondPagePos.getMultimapKey() != null) { + nextPageBuilder.setMultimapKey(secondPagePos.getMultimapKey()); } nextPagePos = nextPageBuilder.build(); pendingNextPage = diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java index a72da7f7817f3..632ed9d1db20e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateInternalsTest.java @@ -25,6 +25,8 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; @@ -32,12 +34,17 @@ import java.io.Closeable; import java.io.IOException; +import java.nio.charset.StandardCharsets; import java.util.AbstractMap; import java.util.AbstractMap.SimpleEntry; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; +import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Random; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import javax.annotation.Nullable; @@ -52,6 +59,7 @@ import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagBag; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagSortedListUpdateRequest; import org.apache.beam.runners.dataflow.worker.windmill.Windmill.TagValue; +import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.Coder.Context; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -61,6 +69,7 @@ import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.GroupingState; import org.apache.beam.sdk.state.MapState; +import org.apache.beam.sdk.state.MultimapState; import org.apache.beam.sdk.state.OrderedListState; import org.apache.beam.sdk.state.ReadableState; import org.apache.beam.sdk.state.ValueState; @@ -80,6 +89,7 @@ import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.Futures; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.util.concurrent.SettableFuture; import org.hamcrest.Matchers; +import org.hamcrest.core.CombinableMatcher; import org.joda.time.Instant; import org.junit.After; import org.junit.Before; @@ -645,6 +655,1090 @@ public void testMapComplexPersist() throws Exception { assertEquals(0, commitBuilder.getValueUpdatesCount()); } + private static ByteString encodeWithCoder(T key, Coder coder) throws IOException { + ByteStringOutputStream out = new ByteStringOutputStream(); + coder.encode(key, out, Context.OUTER); + return out.toByteString(); + } + + // We use the structural value of the Multimap keys to differentiate between different keys. So we + // mix using the original key object and a duplicate but same key object so make sure the + // correctness. + private static byte[] dup(byte[] key) { + byte[] res = new byte[key.length]; + System.arraycopy(key, 0, res, 0, key.length); + return res; + } + + @Test + public void testMultimapGet() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + ReadableState> result = multimapState.get(dup(key)).readLater(); + waitAndSet(future, Arrays.asList(1, 2, 3), 30); + assertThat(result.read(), Matchers.containsInAnyOrder(1, 2, 3)); + } + + @Test + public void testMultimapPutAndGet() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + multimapState.put(key, 1); + ReadableState> result = multimapState.get(dup(key)).readLater(); + waitAndSet(future, Arrays.asList(1, 2, 3), 30); + assertThat(result.read(), Matchers.containsInAnyOrder(1, 1, 2, 3)); + } + + @Test + public void testMultimapRemoveAndGet() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + ReadableState> result1 = multimapState.get(key).readLater(); + ReadableState> result2 = multimapState.get(dup(key)).readLater(); + waitAndSet(future, Arrays.asList(1, 2, 3), 30); + + assertTrue(multimapState.containsKey(key).read()); + assertThat(result1.read(), Matchers.containsInAnyOrder(1, 2, 3)); + + multimapState.remove(key); + assertFalse(multimapState.containsKey(dup(key)).read()); + assertThat(result2.read(), Matchers.emptyIterable()); + } + + @Test + public void testMultimapRemoveThenPut() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + ReadableState> result = multimapState.get(key).readLater(); + waitAndSet(future, Arrays.asList(1, 2, 3), 30); + multimapState.remove(dup(key)); + multimapState.put(key, 4); + multimapState.put(dup(key), 5); + assertThat(result.read(), Matchers.containsInAnyOrder(4, 5)); + } + + @Test + public void testMultimapGetLocalCombineStorage() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + ReadableState> result = multimapState.get(dup(key)).readLater(); + waitAndSet(future, Arrays.asList(1, 2), 30); + multimapState.put(key, 3); + multimapState.put(dup(key), 4); + assertFalse(multimapState.isEmpty().read()); + assertThat(result.read(), Matchers.containsInAnyOrder(1, 2, 3, 4)); + } + + @Test + public void testMultimapLocalRemoveOverrideStorage() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + + ReadableState> result = multimapState.get(key).readLater(); + waitAndSet(future, Arrays.asList(1, 2), 30); + multimapState.remove(dup(key)); + assertThat(result.read(), Matchers.emptyIterable()); + multimapState.put(key, 3); + multimapState.put(dup(key), 4); + assertFalse(multimapState.isEmpty().read()); + assertThat(result.read(), Matchers.containsInAnyOrder(3, 4)); + } + + @Test + public void testMultimapLocalClearOverrideStorage() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + SettableFuture> future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key1, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future); + SettableFuture> future2 = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key2, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(future2); + + ReadableState> result1 = multimapState.get(key1).readLater(); + ReadableState> result2 = multimapState.get(dup(key2)).readLater(); + multimapState.clear(); + waitAndSet(future, Arrays.asList(1, 2), 30); + assertThat(result1.read(), Matchers.emptyIterable()); + assertThat(result2.read(), Matchers.emptyIterable()); + assertThat(multimapState.keys().read(), Matchers.emptyIterable()); + assertThat(multimapState.entries().read(), Matchers.emptyIterable()); + assertTrue(multimapState.isEmpty().read()); + } + + private static Map.Entry> multimapEntry( + byte[] key, Integer... values) throws IOException { + return new AbstractMap.SimpleEntry<>( + encodeWithCoder(key, ByteArrayCoder.of()), Arrays.asList(values)); + } + + @SafeVarargs + private static List weightedList(T... entries) { + WindmillStateReader.WeightedList list = + new WindmillStateReader.WeightedList<>(new ArrayList<>()); + for (T entry : entries) { + list.addWeighted(entry, 1); + } + return list; + } + + @Test + public void testMultimapBasicEntriesAndKeys() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + ReadableState>> entriesResult = + multimapState.entries().readLater(); + ReadableState> keysResult = multimapState.keys().readLater(); + waitAndSet( + entriesFuture, + Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)), + 30); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + Iterable> entries = entriesResult.read(); + assertEquals(6, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key1, 1), + multimapEntryMatcher(key1, 2), + multimapEntryMatcher(key1, 3), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key2, 2), + multimapEntryMatcher(key2, 3))); + + Iterable keys = keysResult.read(); + assertEquals(2, Iterables.size(keys)); + assertThat(keys, Matchers.containsInAnyOrder(key1, key2)); + } + + private static CombinableMatcher multimapEntryMatcher(byte[] key, Integer value) { + return Matchers.both(Matchers.hasProperty("key", Matchers.equalTo(key))) + .and(Matchers.hasProperty("value", Matchers.equalTo(value))); + } + + @Test + public void testMultimapEntriesAndKeysMergeLocalAdd() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + ReadableState>> entriesResult = + multimapState.entries().readLater(); + ReadableState> keysResult = multimapState.keys().readLater(); + waitAndSet( + entriesFuture, + Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)), + 30); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + multimapState.put(key1, 7); + multimapState.put(dup(key2), 8); + multimapState.put(dup(key3), 8); + + Iterable> entries = entriesResult.read(); + assertEquals(9, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key1, 1), + multimapEntryMatcher(key1, 2), + multimapEntryMatcher(key1, 3), + multimapEntryMatcher(key1, 7), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key2, 2), + multimapEntryMatcher(key2, 3), + multimapEntryMatcher(key2, 8), + multimapEntryMatcher(key3, 8))); + + Iterable keys = keysResult.read(); + assertEquals(3, Iterables.size(keys)); + assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3)); + } + + @Test + public void testMultimapEntriesAndKeysMergeLocalRemove() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + ReadableState>> entriesResult = + multimapState.entries().readLater(); + ReadableState> keysResult = multimapState.keys().readLater(); + waitAndSet( + entriesFuture, + Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)), + 30); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + multimapState.remove(dup(key1)); + multimapState.put(key2, 8); + multimapState.put(dup(key3), 8); + + Iterable> entries = entriesResult.read(); + assertEquals(5, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key2, 2), + multimapEntryMatcher(key2, 3), + multimapEntryMatcher(key2, 8), + multimapEntryMatcher(key3, 8))); + + Iterable keys = keysResult.read(); + assertThat(keys, Matchers.containsInAnyOrder(key2, key3)); + } + + @Test + public void testMultimapEntriesPaginated() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + ReadableState>> entriesResult = + multimapState.entries().readLater(); + ReadableState> keysResult = multimapState.keys().readLater(); + waitAndSet( + entriesFuture, + weightedList( + multimapEntry(key1, 1, 2, 3), + // entry key2 is returned in 2 separate responses due to pagination. + multimapEntry(key2, 2, 3, 4), + multimapEntry(key2, 4, 5)), + 30); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + multimapState.remove(dup(key1)); + multimapState.put(key2, 8); + multimapState.put(dup(key3), 8); + + Iterable> entries = entriesResult.read(); + assertEquals(7, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key2, 2), + multimapEntryMatcher(key2, 3), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key2, 5), + multimapEntryMatcher(key2, 8), + multimapEntryMatcher(key3, 8))); + + Iterable keys = keysResult.read(); + assertThat(keys, Matchers.containsInAnyOrder(key2, key3)); + } + + @Test + public void testMultimapCacheComplete() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + + // to set up the multimap as cache complete + waitAndSet(entriesFuture, weightedList(multimapEntry(key, 1, 2, 3)), 30); + multimapState.entries().read(); + + multimapState.put(key, 2); + + when(mockReader.multimapFetchAllFuture( + anyBoolean(), eq(key(NAMESPACE, tag)), eq(STATE_FAMILY), eq(VarIntCoder.of()))) + .thenThrow( + new RuntimeException( + "The multimap is cache complete and should not perform any windmill read.")); + when(mockReader.multimapFetchSingleEntryFuture( + any(), eq(key(NAMESPACE, tag)), eq(STATE_FAMILY), eq(VarIntCoder.of()))) + .thenThrow( + new RuntimeException( + "The multimap is cache complete and should not perform any windmill read.")); + + Iterable> entries = multimapState.entries().read(); + assertEquals(4, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key, 1), + multimapEntryMatcher(key, 2), + multimapEntryMatcher(key, 3), + multimapEntryMatcher(key, 2))); + + Iterable keys = multimapState.keys().read(); + assertThat(keys, Matchers.containsInAnyOrder(key)); + + Iterable values = multimapState.get(dup(key)).read(); + assertThat(values, Matchers.containsInAnyOrder(1, 2, 2, 3)); + } + + @Test + public void testMultimapCachedSingleEntry() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key = "key".getBytes(StandardCharsets.UTF_8); + + SettableFuture> entryFuture = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(entryFuture); + + // to set up the entry key as cache complete and add some local changes + waitAndSet(entryFuture, weightedList(1, 2, 3), 30); + multimapState.get(key).read(); + multimapState.put(key, 2); + + when(mockReader.multimapFetchSingleEntryFuture( + eq(encodeWithCoder(key, ByteArrayCoder.of())), + eq(key(NAMESPACE, tag)), + eq(STATE_FAMILY), + eq(VarIntCoder.of()))) + .thenThrow( + new RuntimeException( + "The multimap is cache complete for " + + Arrays.toString(key) + + " and should not perform any windmill read.")); + + Iterable values = multimapState.get(dup(key)).read(); + assertThat(values, Matchers.containsInAnyOrder(1, 2, 2, 3)); + assertTrue(multimapState.containsKey(key).read()); + } + + @Test + public void testMultimapCachedPartialEntry() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + + SettableFuture> entryFuture = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key1, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(entryFuture); + + // to set up the entry key1 as cache complete and add some local changes + waitAndSet(entryFuture, weightedList(1, 2, 3), 30); + multimapState.get(key1).read(); + multimapState.put(key1, 2); + multimapState.put(key3, 20); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + + // windmill contains extra entry key2 + waitAndSet( + entriesFuture, + weightedList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 4, 5, 6)), + 30); + + // key1 exist in both cache and windmill; key2 exists only in windmill; key3 exists only in + // cache. They should all be merged. + Iterable> entries = multimapState.entries().read(); + + assertEquals(8, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key1, 1), + multimapEntryMatcher(key1, 2), + multimapEntryMatcher(key1, 2), + multimapEntryMatcher(key1, 3), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key2, 5), + multimapEntryMatcher(key2, 6), + multimapEntryMatcher(key3, 20))); + + assertThat(multimapState.keys().read(), Matchers.containsInAnyOrder(key1, key2, key3)); + } + + @Test + public void testMultimapCachedPartialEntryCannotCachePolled() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + + SettableFuture> entryFuture = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key1, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(entryFuture); + + // to set up the entry key1 as cache complete and add some local changes + waitAndSet(entryFuture, weightedList(1, 2, 3), 30); + multimapState.get(key1).read(); + multimapState.put(dup(key1), 2); + multimapState.put(dup(key3), 20); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + // windmill contains extra entry key2, and this time the entries returned should not be cached. + waitAndSet( + entriesFuture, + Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 4, 5, 6)), + 30); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 30); + + // key1 exist in both cache and windmill; key2 exists only in windmill; key3 exists only in + // cache. They should all be merged. + Iterable> entries = multimapState.entries().read(); + + assertEquals(8, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key1, 1), + multimapEntryMatcher(key1, 2), + multimapEntryMatcher(key1, 2), + multimapEntryMatcher(key1, 3), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key2, 5), + multimapEntryMatcher(key2, 6), + multimapEntryMatcher(key3, 20))); + + assertThat(multimapState.keys().read(), Matchers.containsInAnyOrder(key1, key2, key3)); + } + + @Test + public void testMultimapModifyAfterReadDoesNotAffectResult() throws IOException { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final byte[] key1 = "key1".getBytes(StandardCharsets.UTF_8); + final byte[] key2 = "key2".getBytes(StandardCharsets.UTF_8); + final byte[] key3 = "key3".getBytes(StandardCharsets.UTF_8); + final byte[] key4 = "key4".getBytes(StandardCharsets.UTF_8); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + SettableFuture> getKey1Future = SettableFuture.create(); + SettableFuture> getKey2Future = SettableFuture.create(); + SettableFuture> getKey4Future = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key1, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(getKey1Future); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key2, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(getKey2Future); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key4, ByteArrayCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + VarIntCoder.of())) + .thenReturn(getKey4Future); + + ReadableState>> entriesResult = + multimapState.entries().readLater(); + ReadableState> keysResult = multimapState.keys().readLater(); + waitAndSet( + entriesFuture, + Arrays.asList(multimapEntry(key1, 1, 2, 3), multimapEntry(key2, 2, 3, 4)), + 200); + waitAndSet(keysFuture, Arrays.asList(multimapEntry(key1), multimapEntry(key2)), 200); + + // make key4 to be known nonexistent. + multimapState.remove(key4); + + ReadableState> key1Future = multimapState.get(key1).readLater(); + waitAndSet(getKey1Future, Arrays.asList(1, 2, 3), 200); + ReadableState> key2Future = multimapState.get(key2).readLater(); + waitAndSet(getKey2Future, Arrays.asList(2, 3, 4), 200); + ReadableState> key4Future = multimapState.get(key4).readLater(); + waitAndSet(getKey4Future, Collections.emptyList(), 200); + + multimapState.put(key1, 7); + multimapState.put(dup(key2), 8); + multimapState.put(dup(key3), 8); + + Iterable> entries = entriesResult.read(); + Iterable keys = keysResult.read(); + Iterable key1Values = key1Future.read(); + Iterable key2Values = key2Future.read(); + Iterable key4Values = key4Future.read(); + + // values added/removed after read should not be reflected in result + multimapState.remove(key1); + multimapState.put(key2, 9); + multimapState.put(key4, 10); + + assertEquals(9, Iterables.size(entries)); + assertThat( + entries, + Matchers.containsInAnyOrder( + multimapEntryMatcher(key1, 1), + multimapEntryMatcher(key1, 2), + multimapEntryMatcher(key1, 3), + multimapEntryMatcher(key1, 7), + multimapEntryMatcher(key2, 4), + multimapEntryMatcher(key2, 2), + multimapEntryMatcher(key2, 3), + multimapEntryMatcher(key2, 8), + multimapEntryMatcher(key3, 8))); + + assertEquals(3, Iterables.size(keys)); + assertThat(keys, Matchers.containsInAnyOrder(key1, key2, key3)); + + assertEquals(4, Iterables.size(key1Values)); + assertThat(key1Values, Matchers.containsInAnyOrder(1, 2, 3, 7)); + + assertEquals(4, Iterables.size(key2Values)); + assertThat(key2Values, Matchers.containsInAnyOrder(2, 3, 4, 8)); + + assertTrue(Iterables.isEmpty(key4Values)); + } + + @Test + public void testMultimapLazyIterateHugeEntriesResult() { + // A multimap with 1 million keys with a total of 10GBs data + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + + waitAndSet( + entriesFuture, + () -> + new Iterator>>() { + int returnedEntries = 0; + byte[] entryKey = new byte[10_000]; // each key is 10KB + final int targetEntries = 1_000_000; // return 1 million entries, which is 10 GBs + Random rand = new Random(); + + @Override + public boolean hasNext() { + return returnedEntries < targetEntries; + } + + @Override + public Map.Entry> next() { + returnedEntries++; + rand.nextBytes(entryKey); + try { + return multimapEntry(entryKey, 1); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }, + 200); + Iterable> entries = multimapState.entries().read(); + assertEquals(1_000_000, Iterables.size(entries)); + } + + @Test + public void testMultimapLazyIterateHugeKeysResult() { + // A multimap with 1 million keys with a total of 10GBs data + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, ByteArrayCoder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + SettableFuture>>> keysFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + true, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(keysFuture); + + waitAndSet( + keysFuture, + () -> + new Iterator>>() { + int returnedEntries = 0; + byte[] entryKey = new byte[10_000]; // each key is 10KB + final int targetEntries = 1_000_000; // return 1 million entries, which is 10 GBs + Random rand = new Random(); + + @Override + public boolean hasNext() { + return returnedEntries < targetEntries; + } + + @Override + public Map.Entry> next() { + returnedEntries++; + rand.nextBytes(entryKey); + try { + return multimapEntry(entryKey); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + }, + 200); + Iterable keys = multimapState.keys().read(); + assertEquals(1_000_000, Iterables.size(keys)); + } + + @Test + public void testMultimapLazyIterateHugeEntriesResultSingleEntry() throws IOException { + // A multimap with 1 key and 1 million values and a total of 10GBs data + final String tag = "multimap"; + final Integer key = 100; + StateTag> addr = + StateTags.multimap(tag, VarIntCoder.of(), ByteArrayCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, ByteArrayCoder.of())) + .thenReturn(entriesFuture); + SettableFuture> getKeyFuture = SettableFuture.create(); + when(mockReader.multimapFetchSingleEntryFuture( + encodeWithCoder(key, VarIntCoder.of()), + key(NAMESPACE, tag), + STATE_FAMILY, + ByteArrayCoder.of())) + .thenReturn(getKeyFuture); + + // a not weighted iterators that returns tons of data + Iterable values = + () -> + new Iterator() { + int returnedValues = 0; + byte[] value = new byte[10_000]; // each value is 10KB + final int targetValues = 1_000_000; // return 1 million values, which is 10 GBs + Random rand = new Random(); + + @Override + public boolean hasNext() { + return returnedValues < targetValues; + } + + @Override + public byte[] next() { + returnedValues++; + rand.nextBytes(value); + return value; + } + }; + + waitAndSet( + entriesFuture, + Arrays.asList( + new AbstractMap.SimpleEntry<>(encodeWithCoder(key, VarIntCoder.of()), values)), + 200); + waitAndSet(getKeyFuture, values, 200); + + Iterable> entries = multimapState.entries().read(); + assertEquals(1_000_000, Iterables.size(entries)); + + Iterable valueResult = multimapState.get(key).read(); + assertEquals(1_000_000, Iterables.size(valueResult)); + } + + private static class MultimapEntryUpdate { + String key; + Iterable values; + boolean deleteAll; + + public MultimapEntryUpdate(String key, Iterable values, boolean deleteAll) { + this.key = key; + this.values = values; + this.deleteAll = deleteAll; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof MultimapEntryUpdate)) return false; + MultimapEntryUpdate that = (MultimapEntryUpdate) o; + return deleteAll == that.deleteAll + && Objects.equals(key, that.key) + && Objects.equals(values, that.values); + } + + @Override + public int hashCode() { + return Objects.hash(key, values, deleteAll); + } + } + + private static MultimapEntryUpdate decodeTagMultimapEntry(Windmill.TagMultimapEntry entryProto) { + try { + String key = StringUtf8Coder.of().decode(entryProto.getEntryName().newInput(), Context.OUTER); + List values = new ArrayList<>(); + for (ByteString value : entryProto.getValuesList()) { + values.add(VarIntCoder.of().decode(value.newInput(), Context.OUTER)); + } + return new MultimapEntryUpdate(key, values, entryProto.getDeleteAll()); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static void assertTagMultimapUpdates( + Windmill.TagMultimapUpdateRequest.Builder updates, MultimapEntryUpdate... expected) { + assertThat( + updates.getUpdatesList().stream() + .map(WindmillStateInternalsTest::decodeTagMultimapEntry) + .collect(Collectors.toList()), + Matchers.containsInAnyOrder(expected)); + } + + @Test + public void testMultimapPutAndPersist() { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, StringUtf8Coder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final String key1 = "key1"; + final String key2 = "key2"; + + multimapState.put(key1, 1); + multimapState.put(key1, 2); + multimapState.put(key2, 2); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(1, commitBuilder.getMultimapUpdatesCount()); + Windmill.TagMultimapUpdateRequest.Builder builder = + Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); + assertTagMultimapUpdates( + builder, + new MultimapEntryUpdate(key1, Arrays.asList(1, 2), false), + new MultimapEntryUpdate(key2, Arrays.asList(2), false)); + } + + @Test + public void testMultimapRemovePutAndPersist() { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, StringUtf8Coder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final String key1 = "key1"; + final String key2 = "key2"; + + // we should add 1 and 2 to key1 + multimapState.remove(key1); + multimapState.put(key1, 1); + multimapState.put(key1, 2); + // we should not add 2 to key 2 + multimapState.put(key2, 2); + multimapState.remove(key2); + // we should add 4 to key 2 + multimapState.put(key2, 4); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(1, commitBuilder.getMultimapUpdatesCount()); + Windmill.TagMultimapUpdateRequest.Builder builder = + Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); + assertTagMultimapUpdates( + builder, + new MultimapEntryUpdate(key1, Arrays.asList(1, 2), true), + new MultimapEntryUpdate(key2, Arrays.asList(4), true)); + } + + @Test + public void testMultimapRemoveAndPersist() { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, StringUtf8Coder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final String key1 = "key1"; + final String key2 = "key2"; + + multimapState.remove(key1); + multimapState.remove(key2); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(1, commitBuilder.getMultimapUpdatesCount()); + Windmill.TagMultimapUpdateRequest.Builder builder = + Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); + assertTagMultimapUpdates( + builder, + new MultimapEntryUpdate(key1, Collections.emptyList(), true), + new MultimapEntryUpdate(key2, Collections.emptyList(), true)); + } + + @Test + public void testMultimapPutRemoveClearAndPersist() { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, StringUtf8Coder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + final String key1 = "key1"; + final String key2 = "key2"; + final String key3 = "key3"; + + // no need to send any put/remove if clear is called later + multimapState.put(key1, 1); + multimapState.put(key2, 2); + multimapState.remove(key2); + multimapState.clear(); + // remove without put sent after clear should also not be added: we are cache complete after + // clear, so we know we can skip unnecessary remove. + multimapState.remove(key3); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(1, commitBuilder.getMultimapUpdatesCount()); + Windmill.TagMultimapUpdateRequest.Builder builder = + Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); + assertEquals(0, builder.getUpdatesCount()); + assertTrue(builder.getDeleteAll()); + } + + @Test + public void testMultimapPutRemoveAndPersistWhenComplete() { + final String tag = "multimap"; + StateTag> addr = + StateTags.multimap(tag, StringUtf8Coder.of(), VarIntCoder.of()); + MultimapState multimapState = underTest.state(NAMESPACE, addr); + + SettableFuture>>> entriesFuture = + SettableFuture.create(); + when(mockReader.multimapFetchAllFuture( + false, key(NAMESPACE, tag), STATE_FAMILY, VarIntCoder.of())) + .thenReturn(entriesFuture); + + // to set up the multimap as cache complete + waitAndSet(entriesFuture, Collections.emptyList(), 30); + multimapState.entries().read(); + + final String key1 = "key1"; + final String key2 = "key2"; + + // put when complete should be sent + multimapState.put(key1, 4); + + // put-then-remove when complete should not be sent + multimapState.put(key2, 5); + multimapState.remove(key2); + + Windmill.WorkItemCommitRequest.Builder commitBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + underTest.persist(commitBuilder); + + assertEquals(1, commitBuilder.getMultimapUpdatesCount()); + Windmill.TagMultimapUpdateRequest.Builder builder = + Iterables.getOnlyElement(commitBuilder.getMultimapUpdatesBuilderList()); + assertTagMultimapUpdates(builder, new MultimapEntryUpdate(key1, Arrays.asList(4), false)); + } + public static final Range FULL_ORDERED_LIST_RANGE = Range.closedOpen(WindmillOrderedList.MIN_TS_MICROS, WindmillOrderedList.MAX_TS_MICROS); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java index ef3f470ed2a79..de853e86ff5f0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/WindmillStateReaderTest.java @@ -21,9 +21,16 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.when; +import com.google.api.client.util.Lists; +import com.google.common.collect.Maps; import java.io.IOException; import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.Map; import java.util.concurrent.Future; import org.apache.beam.runners.dataflow.worker.windmill.Windmill; @@ -37,6 +44,7 @@ import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.vendor.grpc.v1p48p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Charsets; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Range; import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.BaseEncoding; import org.hamcrest.Matchers; @@ -46,6 +54,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; @@ -65,6 +74,11 @@ public class WindmillStateReaderTest { private static final long CONT_POSITION = 1391631351L; private static final ByteString STATE_KEY_PREFIX = ByteString.copyFromUtf8("key"); + private static final ByteString STATE_MULTIMAP_KEY_1 = ByteString.copyFromUtf8("multimapkey1"); + private static final ByteString STATE_MULTIMAP_KEY_2 = ByteString.copyFromUtf8("multimapkey2"); + private static final ByteString STATE_MULTIMAP_KEY_3 = ByteString.copyFromUtf8("multimapkey3"); + private static final ByteString STATE_MULTIMAP_CONT_1 = ByteString.copyFromUtf8("continuation_1"); + private static final ByteString STATE_MULTIMAP_CONT_2 = ByteString.copyFromUtf8("continuation_2"); private static final ByteString STATE_KEY_1 = ByteString.copyFromUtf8("key1"); private static final ByteString STATE_KEY_2 = ByteString.copyFromUtf8("key2"); private static final String STATE_FAMILY = "family"; @@ -99,6 +113,713 @@ private ByteString intData(int value) throws IOException { return output.toByteString(); } + @Test + public void testReadMultimapSingleEntry() throws Exception { + Future> future = + underTest.multimapFetchSingleEntryFuture( + STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES) + .build())); + + Windmill.KeyedGetDataResponse.Builder response = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addAllValues(Arrays.asList(intData(5), intData(6))))); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build())) + .thenReturn(response.build()); + + Iterable results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); + for (Integer unused : results) { + // Iterate over the results to force loading all the pages. + } + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat(results, Matchers.containsInAnyOrder(5, 6)); + assertNoReader(future); + } + + @Test + public void testReadMultimapSingleEntryPaginated() throws Exception { + Future> future = + underTest.multimapFetchSingleEntryFuture( + STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES) + .build())); + + Windmill.KeyedGetDataResponse.Builder response1 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addAllValues(Arrays.asList(intData(5), intData(6))) + .setContinuationPosition(500))); + Windmill.KeyedGetDataRequest.Builder expectedRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES) + .setRequestPosition(500) + .build())); + + Windmill.KeyedGetDataResponse.Builder response2 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addAllValues(Arrays.asList(intData(7), intData(8))) + .setContinuationPosition(800) + .setRequestPosition(500))); + Windmill.KeyedGetDataRequest.Builder expectedRequest3 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES) + .setRequestPosition(800) + .build())); + + Windmill.KeyedGetDataResponse.Builder response3 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addAllValues(Arrays.asList(intData(9), intData(10))) + .setRequestPosition(800))); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build())) + .thenReturn(response1.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build())) + .thenReturn(response2.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build())) + .thenReturn(response3.build()); + + Iterable results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build()); + for (Integer unused : results) { + // Iterate over the results to force loading all the pages. + } + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build()); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build()); + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat(results, Matchers.contains(5, 6, 7, 8, 9, 10)); + // NOTE: The future will still contain a reference to the underlying reader. + } + + // check whether the two TagMultimapFetchRequests equal to each other, ignoring the order of + // entries and the order of values in each entry. + private static void assertMultimapFetchRequestEqual( + Windmill.TagMultimapFetchRequest req1, Windmill.TagMultimapFetchRequest req2) { + assertMultimapEntriesEqual(req1.getEntriesToFetchList(), req2.getEntriesToFetchList()); + assertEquals( + req1.toBuilder().clearEntriesToFetch().build(), + req2.toBuilder().clearEntriesToFetch().build()); + } + + private static void assertMultimapEntriesEqual( + List left, List right) { + Map map = Maps.newHashMap(); + for (Windmill.TagMultimapEntry entry : left) { + map.put(entry.getEntryName(), entry); + } + for (Windmill.TagMultimapEntry entry : right) { + assertTrue(map.containsKey(entry.getEntryName())); + Windmill.TagMultimapEntry that = map.remove(entry.getEntryName()); + if (entry.getValuesCount() == 0) { + assertEquals(0, that.getValuesCount()); + } else { + assertThat(entry.getValuesList(), Matchers.containsInAnyOrder(that.getValuesList())); + } + assertEquals(entry.toBuilder().clearValues().build(), that.toBuilder().clearValues().build()); + } + assertTrue(map.isEmpty()); + } + + @Test + public void testReadMultimapMultipleEntries() throws Exception { + Future> future1 = + underTest.multimapFetchSingleEntryFuture( + STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Future> future2 = + underTest.multimapFetchSingleEntryFuture( + STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES) + .build()) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_2) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES) + .build())); + + Windmill.KeyedGetDataResponse.Builder response = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addAllValues(Arrays.asList(intData(5), intData(6)))) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_2) + .addAllValues(Arrays.asList(intData(15), intData(16))))); + when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any())) + .thenReturn(response.build()); + + Iterable results1 = future1.get(); + Iterable results2 = future2.get(); + + final ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class); + Mockito.verify(mockWindmill) + .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture()); + assertMultimapFetchRequestEqual( + expectedRequest.build().getMultimapsToFetch(0), + requestCaptor.getValue().getMultimapsToFetch(0)); + + // Iterate over the results to force loading all the pages. + for (Integer unused : results1) {} + for (Integer unused : results2) {} + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat(results1, Matchers.containsInAnyOrder(5, 6)); + assertThat(results2, Matchers.containsInAnyOrder(15, 16)); + assertNoReader(future1); + assertNoReader(future2); + } + + @Test + public void testReadMultimapMultipleEntriesWithPagination() throws Exception { + Future> future1 = + underTest.multimapFetchSingleEntryFuture( + STATE_MULTIMAP_KEY_1, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Future> future2 = + underTest.multimapFetchSingleEntryFuture( + STATE_MULTIMAP_KEY_2, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES) + .build()) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_2) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES) + .build())); + + Windmill.KeyedGetDataResponse.Builder response1 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addAllValues(Arrays.asList(intData(5), intData(6))) + .setContinuationPosition(800)) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_2) + .addAllValues(Arrays.asList(intData(15), intData(16))))); + Windmill.KeyedGetDataRequest.Builder expectedRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .addEntriesToFetch( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES) + .setRequestPosition(800) + .build())); + Windmill.KeyedGetDataResponse.Builder response2 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addAllValues(Arrays.asList(intData(7), intData(8))) + .setRequestPosition(800))); + when(mockWindmill.getStateData(ArgumentMatchers.eq(COMPUTATION), ArgumentMatchers.any())) + .thenReturn(response1.build()) + .thenReturn(response2.build()); + + Iterable results1 = future1.get(); + Iterable results2 = future2.get(); + + // Iterate over the results to force loading all the pages. + for (Integer unused : results1) {} + for (Integer unused : results2) {} + + final ArgumentCaptor requestCaptor = + ArgumentCaptor.forClass(Windmill.KeyedGetDataRequest.class); + Mockito.verify(mockWindmill, times(2)) + .getStateData(ArgumentMatchers.eq(COMPUTATION), requestCaptor.capture()); + assertMultimapFetchRequestEqual( + expectedRequest1.build().getMultimapsToFetch(0), + requestCaptor.getAllValues().get(0).getMultimapsToFetch(0)); + assertMultimapFetchRequestEqual( + expectedRequest2.build().getMultimapsToFetch(0), + requestCaptor.getAllValues().get(1).getMultimapsToFetch(0)); + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat(results1, Matchers.containsInAnyOrder(5, 6, 7, 8)); + assertThat(results2, Matchers.containsInAnyOrder(15, 16)); + // NOTE: The future will still contain a reference to the underlying reader. + } + + @Test + public void testReadMultimapKeys() throws Exception { + Future>>> future = + underTest.multimapFetchAllFuture(true, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(true) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_1)) + .addEntries( + Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_2))); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build())) + .thenReturn(response.build()); + + Iterable>> results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); + List keys = Lists.newArrayList(); + for (Map.Entry> entry : results) { + keys.add(entry.getKey()); + assertEquals(0, Iterables.size(entry.getValue())); + } + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat(keys, Matchers.containsInAnyOrder(STATE_MULTIMAP_KEY_1, STATE_MULTIMAP_KEY_2)); + assertNoReader(future); + } + + @Test + public void testReadMultimapKeysPaginated() throws Exception { + Future>>> future = + underTest.multimapFetchAllFuture(true, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(true) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response1 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_1)) + .setContinuationPosition(STATE_MULTIMAP_CONT_1)); + + Windmill.KeyedGetDataRequest.Builder expectedRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(true) + .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES) + .setRequestPosition(STATE_MULTIMAP_CONT_1)); + + Windmill.KeyedGetDataResponse.Builder response2 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_2)) + .setRequestPosition(STATE_MULTIMAP_CONT_1) + .setContinuationPosition(STATE_MULTIMAP_CONT_2)); + Windmill.KeyedGetDataRequest.Builder expectedRequest3 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(true) + .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES) + .setRequestPosition(STATE_MULTIMAP_CONT_2)); + + Windmill.KeyedGetDataResponse.Builder response3 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder().setEntryName(STATE_MULTIMAP_KEY_3)) + .setRequestPosition(STATE_MULTIMAP_CONT_2)); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build())) + .thenReturn(response1.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build())) + .thenReturn(response2.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build())) + .thenReturn(response3.build()); + + Iterable>> results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build()); + List keys = Lists.newArrayList(); + for (Map.Entry> entry : results) { + keys.add(entry.getKey()); + assertEquals(0, Iterables.size(entry.getValue())); + } + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build()); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build()); + Mockito.verifyNoMoreInteractions(mockWindmill); + + assertThat( + keys, + Matchers.containsInAnyOrder( + STATE_MULTIMAP_KEY_1, STATE_MULTIMAP_KEY_2, STATE_MULTIMAP_KEY_3)); + // NOTE: The future will still contain a reference to the underlying reader. + } + + @Test + public void testReadMultimapAllEntries() throws Exception { + Future>>> future = + underTest.multimapFetchAllFuture(false, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addValues(intData(1)) + .addValues(intData(2))) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_2) + .addValues(intData(10)) + .addValues(intData(20)))); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest.build())) + .thenReturn(response.build()); + + Iterable>> results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest.build()); + int foundEntries = 0; + for (Map.Entry> entry : results) { + if (entry.getKey().equals(STATE_MULTIMAP_KEY_1)) { + foundEntries++; + assertThat(entry.getValue(), Matchers.containsInAnyOrder(1, 2)); + } else { + foundEntries++; + assertEquals(STATE_MULTIMAP_KEY_2, entry.getKey()); + assertThat(entry.getValue(), Matchers.containsInAnyOrder(10, 20)); + } + } + assertEquals(2, foundEntries); + Mockito.verifyNoMoreInteractions(mockWindmill); + assertNoReader(future); + } + + private static void assertMultimapEntries( + Iterable>> expected, + List>> actual) { + Map> expectedMap = Maps.newHashMap(); + for (Map.Entry> entry : expected) { + ByteString key = entry.getKey(); + if (!expectedMap.containsKey(key)) expectedMap.put(key, new ArrayList<>()); + entry.getValue().forEach(expectedMap.get(key)::add); + } + for (Map.Entry> entry : actual) { + assertThat( + entry.getValue(), + Matchers.containsInAnyOrder(expectedMap.remove(entry.getKey()).toArray())); + } + assertTrue(expectedMap.isEmpty()); + } + + @Test + public void testReadMultimapEntriesPaginated() throws Exception { + Future>>> future = + underTest.multimapFetchAllFuture(false, STATE_KEY_1, STATE_FAMILY, INT_CODER); + Mockito.verifyNoMoreInteractions(mockWindmill); + + Windmill.KeyedGetDataRequest.Builder expectedRequest1 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .setFetchMaxBytes(WindmillStateReader.INITIAL_MAX_MULTIMAP_BYTES)); + + Windmill.KeyedGetDataResponse.Builder response1 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_1) + .addValues(intData(1)) + .addValues(intData(2))) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_2) + .addValues(intData(3)) + .addValues(intData(3))) + .setContinuationPosition(STATE_MULTIMAP_CONT_1)); + + Windmill.KeyedGetDataRequest.Builder expectedRequest2 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES) + .setRequestPosition(STATE_MULTIMAP_CONT_1)); + + Windmill.KeyedGetDataResponse.Builder response2 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_2) + .addValues(intData(2))) + .setRequestPosition(STATE_MULTIMAP_CONT_1) + .setContinuationPosition(STATE_MULTIMAP_CONT_2)); + Windmill.KeyedGetDataRequest.Builder expectedRequest3 = + Windmill.KeyedGetDataRequest.newBuilder() + .setKey(DATA_KEY) + .setShardingKey(SHARDING_KEY) + .setWorkToken(WORK_TOKEN) + .setMaxBytes(WindmillStateReader.MAX_CONTINUATION_KEY_BYTES) + .addMultimapsToFetch( + Windmill.TagMultimapFetchRequest.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .setFetchEntryNamesOnly(false) + .setFetchMaxBytes(WindmillStateReader.CONTINUATION_MAX_MULTIMAP_BYTES) + .setRequestPosition(STATE_MULTIMAP_CONT_2)); + + Windmill.KeyedGetDataResponse.Builder response3 = + Windmill.KeyedGetDataResponse.newBuilder() + .setKey(DATA_KEY) + .addTagMultimaps( + Windmill.TagMultimapFetchResponse.newBuilder() + .setTag(STATE_KEY_1) + .setStateFamily(STATE_FAMILY) + .addEntries( + Windmill.TagMultimapEntry.newBuilder() + .setEntryName(STATE_MULTIMAP_KEY_2) + .addValues(intData(4))) + .setRequestPosition(STATE_MULTIMAP_CONT_2)); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest1.build())) + .thenReturn(response1.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest2.build())) + .thenReturn(response2.build()); + Mockito.when(mockWindmill.getStateData(COMPUTATION, expectedRequest3.build())) + .thenReturn(response3.build()); + + Iterable>> results = future.get(); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest1.build()); + assertMultimapEntries( + results, + Arrays.asList( + new AbstractMap.SimpleEntry<>(STATE_MULTIMAP_KEY_1, Arrays.asList(1, 2)), + new AbstractMap.SimpleEntry<>(STATE_MULTIMAP_KEY_2, Arrays.asList(3, 3, 2, 4)))); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest2.build()); + Mockito.verify(mockWindmill).getStateData(COMPUTATION, expectedRequest3.build()); + Mockito.verifyNoMoreInteractions(mockWindmill); + // NOTE: The future will still contain a reference to the underlying reader. + } + @Test public void testReadBag() throws Exception { Future> future = underTest.bagFuture(STATE_KEY_1, STATE_FAMILY, INT_CODER); diff --git a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto index b0e4dba698bcf..38fd5cf0dd162 100644 --- a/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto +++ b/runners/google-cloud-dataflow-java/worker/windmill/src/main/proto/windmill.proto @@ -146,6 +146,79 @@ message TagBag { optional int64 fetch_max_bytes = 6 [default = 0x7fffffffffffffff]; } +// For a given sharding key and state family, a TagMultimap is a collection of +// `entry_name, bag of values` pairs, each pair is a TagMultimapEntry. +// +// The request_position, continuation_position and fetch_max_bytes fields in +// TagMultimapEntry are used for the pagination and byte limiting of individual +// entry fetch requests get(entry_name); while those fields in +// TagMultimapFetchRequest and TagMultimapFetchResponse are used for full +// multimap fetch requests entry_names() and entries(). +// Do not set both in a TagMultimapFetchRequest at the same time. +message TagMultimapEntry { + optional bytes entry_name = 1; + // In update request: if true all values associated with this entry_name will + // be deleted. If new values are present they will be written. + optional bool delete_all = 2; + // In update request: The given values will be added to the collection and + // associated with entry_name. + // In fetch response: Values that are associated with this entry_name in the + // multimap. + repeated bytes values = 3; + // In fetch request: A previously returned continuation_position from an + // earlier read response. Indicates we wish to fetch the next page of values. + // If this is the first request, set to empty. + // In fetch response: copied from request. + optional int64 request_position = 4; + // In fetch response: Set when there are values after those returned above, + // but they were suppressed to respect the fetch_max_bytes limit. Subsequent + // requests should copy this to request_position to retrieve the next page of + // values. + optional int64 continuation_position = 5; + // In fetch request: Limits the size of the fetched values to this byte limit. + // A lower limit may be imposed by the service. + optional int64 fetch_max_bytes = 6 [default = 0x7fffffffffffffff]; +} + +message TagMultimapFetchRequest { + optional bytes tag = 1; + optional string state_family = 2; + // If true, values will be omitted in the response. + optional bool fetch_entry_names_only = 3; + // Limits the size of the fetched entries to this byte limit. A lower limit + // may be imposed by the service. + optional int64 fetch_max_bytes = 4 [default = 0x7fffffffffffffff]; + // A previously returned continuation_position from an earlier fetch response. + // Indicates we wish to fetch the next page of entries. If this is the first + // request, set to empty. + optional bytes request_position = 5; + // Fetch the requested subset of entries only. Will fetch all entries if left + // empty. Entries in entries_to_fetch should only have the entry_name, + // request_position and fetch_max_bytes set. + repeated TagMultimapEntry entries_to_fetch = 6; +} + +message TagMultimapFetchResponse { + optional bytes tag = 1; + optional string state_family = 2; + repeated TagMultimapEntry entries = 3; + // Will be set only if the entries_to_fetch in the request is empty when + // we want to fetch all entries. + optional bytes continuation_position = 4; + // Request position copied from request. + optional bytes request_position = 5; +} + +message TagMultimapUpdateRequest { + optional bytes tag = 1; + optional string state_family = 2; + // All entries including the values in the multimap will be deleted. + optional bool delete_all = 3; + // If delete_all is true and updates are not empty, the multimap will first be + // cleared and then those updates will be applied. + repeated TagMultimapEntry updates = 4; +} + // A single entry in a sorted list message SortedListEntry { // The value payload. @@ -297,6 +370,8 @@ message KeyedGetDataRequest { repeated TagBag bags_to_fetch = 8; // Must be at most one sorted_list_to_fetch for a given state family and tag. repeated TagSortedListFetchRequest sorted_lists_to_fetch = 9; + // Must be at most one multimaps_to_fetch for a given state family and tag. + repeated TagMultimapFetchRequest multimaps_to_fetch = 12; repeated WatermarkHold watermark_holds_to_fetch = 5; repeated LatencyAttribution latency_attribution = 13; @@ -329,6 +404,8 @@ message KeyedGetDataResponse { repeated TagBag bags = 6; // There is one TagSortedListFetchResponse per state-family, tag pair. repeated TagSortedListFetchResponse tag_sorted_lists = 8; + // There is one TagMultimapFetchResponse per state-family, tag pair. + repeated TagMultimapFetchResponse tag_multimaps = 10; repeated WatermarkHold watermark_holds = 5; reserved 4; @@ -377,7 +454,7 @@ message GlobalDataRequest { optional string state_family = 3; } -// next id: 24 +// next id: 27 message WorkItemCommitRequest { required bytes key = 1; required fixed64 work_token = 2; @@ -394,6 +471,7 @@ message WorkItemCommitRequest { repeated TagValuePrefix tag_value_prefix_deletes = 25; repeated TagBag bag_updates = 18; repeated TagSortedListUpdateRequest sorted_list_updates = 24; + repeated TagMultimapUpdateRequest multimap_updates = 26; repeated Counter counter_updates = 8; repeated GlobalDataRequest global_data_requests = 11; repeated GlobalData global_data_updates = 10;