Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add MultimapState API #23491

Merged
merged 13 commits into from
Dec 16, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ message StateSpec {
MapStateSpec map_spec = 4;
SetStateSpec set_spec = 5;
OrderedListStateSpec ordered_list_spec = 6;
MultimapStateSpec multimap_spec = 8;
}

// (Required) URN of the protocol required by this state specification to present
Expand Down Expand Up @@ -599,6 +600,11 @@ message MapStateSpec {
string value_coder_id = 2;
}

// Specification of a multimap state cell: identifies the coders for its keys and values.
message MultimapStateSpec {
// Id of the coder used for the multimap's keys.
string key_coder_id = 1;
// Id of the coder used for the multimap's values.
string value_coder_id = 2;
}

message SetStateSpec {
string element_coder_id = 1;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,17 @@ public RunnerApi.StateSpec dispatchSet(Coder<?> elementCoder) {
.setProtocol(FunctionSpec.newBuilder().setUrn(MULTIMAP_USER_STATE))
.build();
}

/**
 * Builds the portable {@link RunnerApi.StateSpec} for a multimap state.
 *
 * <p>The key coder and then the value coder are registered with {@code components}; the ids
 * returned by that registration are stored in the multimap spec.
 */
@Override
public RunnerApi.StateSpec dispatchMultimap(Coder<?> keyCoder, Coder<?> valueCoder) {
  RunnerApi.MultimapStateSpec.Builder multimapSpec = RunnerApi.MultimapStateSpec.newBuilder();
  multimapSpec.setKeyCoderId(registerCoderOrThrow(components, keyCoder));
  multimapSpec.setValueCoderId(registerCoderOrThrow(components, valueCoder));

  FunctionSpec.Builder protocol = FunctionSpec.newBuilder().setUrn(MULTIMAP_USER_STATE);
  return builder.setMultimapSpec(multimapSpec).setProtocol(protocol).build();
}
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.beam.runners.core;

import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
Expand All @@ -36,6 +37,7 @@
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.MultimapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.ReadableStates;
Expand All @@ -51,11 +53,13 @@
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.sdk.values.TimestampedValue;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ArrayListMultimap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Lists;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Maps;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Multimap;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
Expand Down Expand Up @@ -153,6 +157,14 @@ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
return new InMemoryMap<>(mapKeyCoder, mapValueCoder);
}

/**
 * Creates a new, empty {@link InMemoryMultimap} backed by the given coders. The {@code spec}
 * argument is not consulted.
 */
@Override
public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
    StateTag<MultimapState<KeyT, ValueT>> spec,
    Coder<KeyT> keyCoder,
    Coder<ValueT> valueCoder) {
  InMemoryMultimap<KeyT, ValueT> freshState = new InMemoryMultimap<>(keyCoder, valueCoder);
  return freshState;
}

@Override
public <T> OrderedListState<T> bindOrderedList(
StateTag<OrderedListState<T>> spec, Coder<T> elemCoder) {
Expand Down Expand Up @@ -463,6 +475,124 @@ public InMemoryBag<T> copy() {
}
}

/** An {@link InMemoryState} implementation of {@link MultimapState}. */
public static final class InMemoryMultimap<K, V>
implements MultimapState<K, V>, InMemoryState<InMemoryMultimap<K, V>> {
private final Coder<K> keyCoder;
private final Coder<V> valueCoder;
private Multimap<Object, V> contents = ArrayListMultimap.create();
private Map<Object, K> structuralKeysMapping = Maps.newHashMap();

/**
 * Creates an empty multimap state.
 *
 * @param keyCoder coder used to derive structural values for keys and to clone them
 * @param valueCoder coder used to clone values
 */
public InMemoryMultimap(Coder<K> keyCoder, Coder<V> valueCoder) {
  this.valueCoder = valueCoder;
  this.keyCoder = keyCoder;
}

@Override
public void clear() {
contents = ArrayListMultimap.create();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not just clear the instantiated contents?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See

// Even though we're clearing we can't remove this from the in-memory state map, since

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should create a new container instead of clear(), this is also how Bag/Map/Set do clear().

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The +1 was for your response that we need to keep stuff consistent for past reads that were returned.

structuralKeysMapping = Maps.newHashMap();
}

/**
 * Adds {@code value} to the values stored for {@code key}.
 *
 * <p>Entries are indexed by the key's structural value; the user-visible key is remembered
 * separately so it can be handed back later.
 */
@Override
public void put(K key, V value) {
  final Object lookupKey = keyCoder.structuralValue(key);
  contents.put(lookupKey, value);
  structuralKeysMapping.put(lookupKey, key);
}

/** Drops {@code key} together with every value stored under it. */
@Override
public void remove(K key) {
  final Object lookupKey = keyCoder.structuralValue(key);
  contents.removeAll(lookupKey);
  structuralKeysMapping.remove(lookupKey);
}

@Override
public @UnknownKeyFor @NonNull @Initialized ReadableState<Iterable<V>> get(K key) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these annotations auto generated?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have removed the annotations.

return CollectionViewState.of(contents.get(keyCoder.structuralValue(key)));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This returns a stale view in this order of operations:

mm.put("A", 1);
ReadableState<Iterable<V>> iter = mm.get("A");
mm.clear();
iter.read(); <-- will contain "1"

I believe there are other combinations where you won't get what you want. Please ensure that ParDoTest covers:

  • get before and after remove/put/clear.
  • containsKey before and after remove/put/clear
  • entries before and after remove/put/clear
  • isEmpty before and after remove/put/clear
  • keys before and after remove/put/clear

Currently I see you're covering entries before and after put/remove.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, updated ParDoTest to include those as well.

}

/**
 * Returns a deferred view of the keys currently present in this multimap.
 *
 * <p>The backing map is looked up on every {@code read()} instead of being captured when this
 * method is called. {@link #clear()} replaces {@code structuralKeysMapping} with a new map, so a
 * view bound to the old map's values would keep reporting keys that were cleared after {@code
 * keys()} returned.
 */
@Override
public ReadableState<Iterable<K>> keys() {
  return new ReadableState<Iterable<K>>() {
    @Override
    public Iterable<K> read() {
      // Snapshot, so the returned iterable is unaffected by later mutations.
      return ImmutableList.copyOf(structuralKeysMapping.values());
    }

    @Override
    public ReadableState<Iterable<K>> readLater() {
      return this;
    }
  };
}

@Override
public ReadableState<Iterable<Map.Entry<K, Iterable<V>>>> entries() {
return new ReadableState<Iterable<Map.Entry<K, Iterable<V>>>>() {
@Override
public Iterable<Map.Entry<K, Iterable<V>>> read() {
List<Map.Entry<K, Iterable<V>>> result = new ArrayList<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

= Lists.arrayListWithExpectedSize(structuralKeysMapping.size())

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

for (Map.Entry<Object, K> entry : structuralKeysMapping.entrySet()) {
result.add(
new AbstractMap.SimpleEntry(
entry.getValue(), ImmutableList.copyOf(contents.get(entry.getKey()))));
}
return ImmutableList.copyOf(result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Collections.unmodifiableList, since it doesn't copy

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}

@Override
public ReadableState<Iterable<Map.Entry<K, Iterable<V>>>> readLater() {
return this;
}
};
}

@Override
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto on annotations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

public ReadableState<Boolean> containsKey(K key) {
  // The explicit checker qualifiers were dropped so this signature matches keys() and entries().
  // The lookup is deferred: the structural key is derived and tested on each read(), so the
  // answer reflects the multimap at read time rather than at the time of this call.
  return new ReadableState<Boolean>() {
    @Override
    public Boolean read() {
      return structuralKeysMapping.containsKey(keyCoder.structuralValue(key));
    }

    @Override
    public ReadableState<Boolean> readLater() {
      return this;
    }
  };
}

@Override
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

annotations

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

public ReadableState<Boolean> isEmpty() {
  // The explicit checker qualifiers were dropped so this signature matches keys() and entries().
  // The check is deferred: emptiness is evaluated against the current contents on each read().
  return new ReadableState<Boolean>() {
    @Override
    public Boolean read() {
      return contents.isEmpty();
    }

    @Override
    public ReadableState<Boolean> readLater() {
      return this;
    }
  };
}

/** Reports whether this state currently holds no key/value pairs. */
@Override
public boolean isCleared() {
  final boolean nothingStored = contents.isEmpty();
  return nothingStored;
}

/**
 * Produces a copy of this multimap that shares the coders but none of the stored keys or values:
 * every key and value is cloned through its coder.
 */
@Override
public InMemoryMultimap<K, V> copy() {
  InMemoryMultimap<K, V> duplicate = new InMemoryMultimap<>(keyCoder, valueCoder);
  for (K originalKey : structuralKeysMapping.values()) {
    // Derive the structural key afresh instead of reusing this instance's key object.
    Object lookupKey = keyCoder.structuralValue(originalKey);
    duplicate.structuralKeysMapping.put(lookupKey, uncheckedClone(keyCoder, originalKey));
    for (V originalValue : contents.get(lookupKey)) {
      duplicate.contents.put(lookupKey, uncheckedClone(valueCoder, originalValue));
    }
  }
  return duplicate;
}
}

/** An {@link InMemoryState} implementation of {@link OrderedListState}. */
public static final class InMemoryOrderedList<T>
implements OrderedListState<T>, InMemoryState<InMemoryOrderedList<T>> {
Expand Down Expand Up @@ -625,6 +755,28 @@ public InMemorySet<T> copy() {
}
}

/**
 * A {@link ReadableState} over a backing collection. Each {@code read()} returns an immutable
 * copy of whatever the backing collection holds at that moment.
 */
private static class CollectionViewState<T> implements ReadableState<Iterable<T>> {
  private final Collection<T> backing;

  private CollectionViewState(Collection<T> backing) {
    this.backing = backing;
  }

  /** Wraps {@code collection} without copying it. */
  public static <T> CollectionViewState<T> of(Collection<T> collection) {
    CollectionViewState<T> view = new CollectionViewState<>(collection);
    return view;
  }

  @Override
  public Iterable<T> read() {
    ImmutableList<T> snapshot = ImmutableList.copyOf(backing);
    return snapshot;
  }

  @Override
  public ReadableState<Iterable<T>> readLater() {
    // Nothing to prefetch for in-memory data.
    return this;
  }
}

/** An {@link InMemoryState} implementation of {@link MapState}. */
public static final class InMemoryMap<K, V>
implements MapState<K, V>, InMemoryState<InMemoryMap<K, V>> {
Expand Down Expand Up @@ -685,28 +837,6 @@ public void remove(K key) {
contents.remove(key);
}

/**
 * A {@link ReadableState} over a backing collection. Each {@code read()} returns an immutable
 * copy of whatever the backing collection holds at that moment.
 */
private static class CollectionViewState<T> implements ReadableState<Iterable<T>> {
  private final Collection<T> backing;

  private CollectionViewState(Collection<T> backing) {
    this.backing = backing;
  }

  /** Wraps {@code collection} without copying it. */
  public static <T> CollectionViewState<T> of(Collection<T> collection) {
    CollectionViewState<T> view = new CollectionViewState<>(collection);
    return view;
  }

  @Override
  public Iterable<T> read() {
    ImmutableList<T> snapshot = ImmutableList.copyOf(backing);
    return snapshot;
  }

  @Override
  public ReadableState<Iterable<T>> readLater() {
    // Nothing to prefetch for in-memory data.
    return this;
  }
}

@Override
public ReadableState<Iterable<K>> keys() {
return CollectionViewState.of(contents.keySet());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.MultimapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
Expand Down Expand Up @@ -84,6 +85,9 @@ <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
Coder<KeyT> mapKeyCoder,
Coder<ValueT> mapValueCoder);

/**
 * Binds the {@link MultimapState} identified by {@code spec}, given the coders for its keys and
 * values.
 */
<KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
StateTag<MultimapState<KeyT, ValueT>> spec, Coder<KeyT> keyCoder, Coder<ValueT> valueCoder);

<T> OrderedListState<T> bindOrderedList(StateTag<OrderedListState<T>> spec, Coder<T> elemCoder);

<InputT, AccumT, OutputT> CombiningState<InputT, AccumT, OutputT> bindCombiningValue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.MultimapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
Expand Down Expand Up @@ -87,6 +88,15 @@ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
return binder.bindMap(tagForSpec(id, spec), mapKeyCoder, mapValueCoder);
}

/** Converts the {@code (id, spec)} pair into a tag and forwards the binding to {@code binder}. */
@Override
public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
    String id,
    StateSpec<MultimapState<KeyT, ValueT>> spec,
    Coder<KeyT> keyCoder,
    Coder<ValueT> valueCoder) {
  StateTag<MultimapState<KeyT, ValueT>> tag = tagForSpec(id, spec);
  return binder.bindMultimap(tag, keyCoder, valueCoder);
}

@Override
public <T> OrderedListState<T> bindOrderedList(
String id, StateSpec<OrderedListState<T>> spec, Coder<T> elemCoder) {
Expand Down Expand Up @@ -203,6 +213,11 @@ public static <K, V> StateTag<MapState<K, V>> map(
return new SimpleStateTag<>(new StructuredId(id), StateSpecs.map(keyCoder, valueCoder));
}

/** Creates a state tag with the given id for a {@link MultimapState} using the given coders. */
public static <K, V> StateTag<MultimapState<K, V>> multimap(
    String id, Coder<K> keyCoder, Coder<V> valueCoder) {
  StructuredId structuredId = new StructuredId(id);
  StateSpec<MultimapState<K, V>> spec = StateSpecs.multimap(keyCoder, valueCoder);
  return new SimpleStateTag<>(structuredId, spec);
}

public static <T> StateTag<OrderedListState<T>> orderedList(String id, Coder<T> elemCoder) {
return new SimpleStateTag<>(new StructuredId(id), StateSpecs.orderedList(elemCoder));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryBag;
import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryCombiningState;
import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryMap;
import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryMultimap;
import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryOrderedList;
import org.apache.beam.runners.core.InMemoryStateInternals.InMemorySet;
import org.apache.beam.runners.core.InMemoryStateInternals.InMemoryState;
Expand All @@ -41,6 +42,7 @@
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.MultimapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
Expand Down Expand Up @@ -370,6 +372,22 @@ public <T> OrderedListState<T> bindOrderedList(
}
}

/**
 * Returns a copy of the multimap held by the underlying state when one exists for this namespace
 * and address; otherwise returns a new, empty multimap.
 */
@Override
public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
    StateTag<MultimapState<KeyT, ValueT>> address,
    Coder<KeyT> keyCoder,
    Coder<ValueT> valueCoder) {
  if (!containedInUnderlying(namespace, address)) {
    // Nothing to copy from, so start empty.
    return new InMemoryMultimap<>(keyCoder, valueCoder);
  }

  @SuppressWarnings("unchecked")
  InMemoryState<? extends MultimapState<KeyT, ValueT>> existingState =
      (InMemoryState<? extends MultimapState<KeyT, ValueT>>)
          underlying.get().get(namespace, address, c);
  return existingState.copy();
}

@Override
public <InputT, AccumT, OutputT>
CombiningState<InputT, AccumT, OutputT> bindCombiningValueWithContext(
Expand Down Expand Up @@ -458,6 +476,14 @@ public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
return underlying.get(namespace, address, c);
}

/** Looks the multimap up directly in the underlying state; the coders are not used here. */
@Override
public <KeyT, ValueT> MultimapState<KeyT, ValueT> bindMultimap(
    StateTag<MultimapState<KeyT, ValueT>> address,
    Coder<KeyT> keyCoder,
    Coder<ValueT> valueCoder) {
  MultimapState<KeyT, ValueT> bound = underlying.get(namespace, address, c);
  return bound;
}

@Override
public <T> OrderedListState<T> bindOrderedList(
StateTag<OrderedListState<T>> address, Coder<T> elemCoder) {
Expand Down
1 change: 1 addition & 0 deletions runners/flink/flink_runner.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ def createValidatesRunnerTask(Map m) {
excludeCategories 'org.apache.beam.sdk.testing.UsesCommittedMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesSystemMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering'
excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
excludeCategories 'org.apache.beam.sdk.testing.UsesLoopingTimer'
if (config.streaming) {
excludeCategories 'org.apache.beam.sdk.testing.UsesTimerMap'
Expand Down
2 changes: 1 addition & 1 deletion runners/flink/job-server/flink_job_server.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@ def portableValidatesRunnerTask(String name, boolean streaming, boolean checkpoi
excludeCategories 'org.apache.beam.sdk.testing.UsesGaugeMetrics'
excludeCategories 'org.apache.beam.sdk.testing.UsesParDoLifecycle'
excludeCategories 'org.apache.beam.sdk.testing.UsesMapState'
excludeCategories 'org.apache.beam.sdk.testing.UsesMultimapState'
excludeCategories 'org.apache.beam.sdk.testing.UsesSetState'
excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'
excludeCategories 'org.apache.beam.sdk.testing.UsesStrictTimerOrdering'
excludeCategories 'org.apache.beam.sdk.testing.UsesOnWindowExpiration'
excludeCategories 'org.apache.beam.sdk.testing.UsesBundleFinalizer'
excludeCategories 'org.apache.beam.sdk.testing.UsesOrderedListState'
if (streaming) {
excludeCategories 'org.apache.beam.sdk.testing.UsesBoundedSplittableParDo'
excludeCategories 'org.apache.beam.sdk.testing.UsesTestStreamWithProcessingTime'
Expand Down
Loading