Skip to content

Commit

Permalink
Guard keys-values multi map input with a flag. (#29690)
Browse files Browse the repository at this point in the history
This is needed as some runners (e.g. Dataflow) do not gracefully
return errors on unknown state read types.

This flag will be set via runner capabilities in a future PR.
  • Loading branch information
robertwb authored Dec 8, 2023
1 parent d23ed6f commit 98a2690
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -915,10 +915,9 @@ public Coder<V> valueCoder() {
// Expect the following requests for the first bundle:
// * one to read iterable side input
// * one to read keys from multimap side input
// * one to attempt multimap side input bulk read
// * one to read key1 iterable from multimap side input
// * one to read key2 iterable from multimap side input
assertEquals(5, stateRequestHandler.receivedRequests.size());
assertEquals(4, stateRequestHandler.receivedRequests.size());
assertEquals(
stateRequestHandler.receivedRequests.get(0).getStateKey().getIterableSideInput(),
BeamFnApi.StateKey.IterableSideInput.newBuilder()
Expand All @@ -932,20 +931,14 @@ public Coder<V> valueCoder() {
.setTransformId(transformId)
.build());
assertEquals(
stateRequestHandler.receivedRequests.get(2).getStateKey().getMultimapKeysValuesSideInput(),
BeamFnApi.StateKey.MultimapKeysValuesSideInput.newBuilder()
.setSideInputId(multimapView.getTagInternal().getId())
.setTransformId(transformId)
.build());
assertEquals(
stateRequestHandler.receivedRequests.get(3).getStateKey().getMultimapSideInput(),
stateRequestHandler.receivedRequests.get(2).getStateKey().getMultimapSideInput(),
BeamFnApi.StateKey.MultimapSideInput.newBuilder()
.setSideInputId(multimapView.getTagInternal().getId())
.setTransformId(transformId)
.setKey(encode("key1"))
.build());
assertEquals(
stateRequestHandler.receivedRequests.get(4).getStateKey().getMultimapSideInput(),
stateRequestHandler.receivedRequests.get(3).getStateKey().getMultimapSideInput(),
BeamFnApi.StateKey.MultimapSideInput.newBuilder()
.setSideInputId(multimapView.getTagInternal().getId())
.setTransformId(transformId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class MultimapSideInput<K, V> implements MultimapView<K, V> {
private final Coder<K> keyCoder;
private final Coder<V> valueCoder;
private volatile Function<ByteString, Iterable<V>> bulkReadResult;
private final boolean useBulkRead;

public MultimapSideInput(
Cache<?, ?> cache,
Expand All @@ -62,6 +63,18 @@ public MultimapSideInput(
StateKey stateKey,
Coder<K> keyCoder,
Coder<V> valueCoder) {
// TODO(robertwb): Plumb the value of useBulkRead from runner capabilities.
this(cache, beamFnStateClient, instructionId, stateKey, keyCoder, valueCoder, false);
}

public MultimapSideInput(
Cache<?, ?> cache,
BeamFnStateClient beamFnStateClient,
String instructionId,
StateKey stateKey,
Coder<K> keyCoder,
Coder<V> valueCoder,
boolean useBulkRead) {
checkArgument(
stateKey.hasMultimapKeysSideInput(),
"Expected MultimapKeysSideInput StateKey but received %s.",
Expand All @@ -72,6 +85,7 @@ public MultimapSideInput(
StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build();
this.keyCoder = keyCoder;
this.valueCoder = valueCoder;
this.useBulkRead = useBulkRead;
}

@Override
Expand All @@ -84,62 +98,70 @@ public Iterable<K> get() {
public Iterable<V> get(K k) {
ByteString encodedKey = encodeKey(k);

if (bulkReadResult == null) {
synchronized (this) {
if (bulkReadResult == null) {
Map<ByteString, Iterable<V>> bulkRead = new HashMap<>();
StateKey bulkReadStateKey =
StateKey.newBuilder()
.setMultimapKeysValuesSideInput(
StateKey.MultimapKeysValuesSideInput.newBuilder()
.setTransformId(
keysRequest.getStateKey().getMultimapKeysSideInput().getTransformId())
.setSideInputId(
keysRequest.getStateKey().getMultimapKeysSideInput().getSideInputId())
.setWindow(
keysRequest.getStateKey().getMultimapKeysSideInput().getWindow()))
.build();
if (useBulkRead) {
if (bulkReadResult == null) {
synchronized (this) {
if (bulkReadResult == null) {
Map<ByteString, Iterable<V>> bulkRead = new HashMap<>();
StateKey bulkReadStateKey =
StateKey.newBuilder()
.setMultimapKeysValuesSideInput(
StateKey.MultimapKeysValuesSideInput.newBuilder()
.setTransformId(
keysRequest
.getStateKey()
.getMultimapKeysSideInput()
.getTransformId())
.setSideInputId(
keysRequest
.getStateKey()
.getMultimapKeysSideInput()
.getSideInputId())
.setWindow(
keysRequest.getStateKey().getMultimapKeysSideInput().getWindow()))
.build();

StateRequest bulkReadRequest =
keysRequest.toBuilder().setStateKey(bulkReadStateKey).build();
try {
Iterator<KV<K, Iterable<V>>> entries =
StateFetchingIterators.readAllAndDecodeStartingFrom(
Caches.subCache(cache, "ValuesForKey", encodedKey),
beamFnStateClient,
bulkReadRequest,
KvCoder.of(keyCoder, IterableCoder.of(valueCoder)))
.iterator();
while (bulkRead.size() < BULK_READ_SIZE && entries.hasNext()) {
KV<K, Iterable<V>> entry = entries.next();
bulkRead.put(encodeKey(entry.getKey()), entry.getValue());
}
if (entries.hasNext()) {
StateRequest bulkReadRequest =
keysRequest.toBuilder().setStateKey(bulkReadStateKey).build();
try {
Iterator<KV<K, Iterable<V>>> entries =
StateFetchingIterators.readAllAndDecodeStartingFrom(
Caches.subCache(cache, "ValuesForKey", encodedKey),
beamFnStateClient,
bulkReadRequest,
KvCoder.of(keyCoder, IterableCoder.of(valueCoder)))
.iterator();
while (bulkRead.size() < BULK_READ_SIZE && entries.hasNext()) {
KV<K, Iterable<V>> entry = entries.next();
bulkRead.put(encodeKey(entry.getKey()), entry.getValue());
}
if (entries.hasNext()) {
bulkReadResult = bulkRead::get;
} else {
bulkReadResult =
key -> {
Iterable<V> result = bulkRead.get(key);
if (result == null) {
// As we read the entire set of values, we don't have to do a lookup to know
// this key doesn't exist.
// Missing keys are treated as empty iterables in this multimap.
return Collections.emptyList();
} else {
return result;
}
};
}
} catch (Exception exn) {
bulkReadResult = bulkRead::get;
} else {
bulkReadResult =
key -> {
Iterable<V> result = bulkRead.get(key);
if (result == null) {
// As we read the entire set of values, we don't have to do a lookup to know
// this key doesn't exist.
// Missing keys are treated as empty iterables in this multimap.
return Collections.emptyList();
} else {
return result;
}
};
}
} catch (Exception exn) {
bulkReadResult = bulkRead::get;
}
}
}
}

Iterable<V> bulkReadValues = bulkReadResult.apply(encodedKey);
if (bulkReadValues != null) {
return bulkReadValues;
Iterable<V> bulkReadValues = bulkReadResult.apply(encodedKey);
if (bulkReadValues != null) {
return bulkReadValues;
}
}

StateKey stateKey =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ public void testGetWithBulkRead() throws Exception {
"instructionId",
keysStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
StringUtf8Coder.of(),
true);
assertArrayEquals(
new String[] {"A1", "A2", "A3"}, Iterables.toArray(multimapSideInput.get(A), String.class));
assertArrayEquals(
Expand All @@ -94,7 +95,8 @@ public void testGet() throws Exception {
"instructionId",
keysStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
StringUtf8Coder.of(),
true);
assertArrayEquals(
new String[] {"A1", "A2", "A3"}, Iterables.toArray(multimapSideInput.get(A), String.class));
assertArrayEquals(
Expand Down Expand Up @@ -124,7 +126,8 @@ public void testGetCached() throws Exception {
"instructionId",
keysStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
StringUtf8Coder.of(),
true);
assertArrayEquals(
new String[] {"A1", "A2", "A3"},
Iterables.toArray(multimapSideInput.get(A), String.class));
Expand All @@ -147,7 +150,8 @@ public void testGetCached() throws Exception {
"instructionId",
keysStateKey(),
ByteArrayCoder.of(),
StringUtf8Coder.of());
StringUtf8Coder.of(),
true);
assertArrayEquals(
new String[] {"A1", "A2", "A3"},
Iterables.toArray(multimapSideInput.get(A), String.class));
Expand Down
62 changes: 33 additions & 29 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,8 @@ def __init__(self,
transform_id, # type: str
tag, # type: Optional[str]
side_input_data, # type: pvalue.SideInputData
coder # type: WindowedValueCoder
coder, # type: WindowedValueCoder
use_bulk_read = False, # type: bool
):
# type: (...) -> None
self._state_handler = state_handler
Expand All @@ -399,6 +400,7 @@ def __init__(self,
self._target_window_coder = coder.window_coder
# TODO(robertwb): Limit the cache size.
self._cache = {} # type: Dict[BoundedWindow, Any]
self._use_bulk_read = use_bulk_read

def __getitem__(self, window):
target_window = self._side_input_data.window_mapping_fn(window)
Expand Down Expand Up @@ -432,42 +434,44 @@ def __getitem__(self, window):
key_coder = self._element_coder.key_coder()
key_coder_impl = key_coder.get_impl()
value_coder = self._element_coder.value_coder()
use_bulk_read = self._use_bulk_read

class MultiMap(object):
_bulk_read = None
_lock = threading.Lock()

def __getitem__(self, key):
if self._bulk_read is None:
with self._lock:
if self._bulk_read is None:
try:
# Attempt to bulk read the key-values over the iterable
# protocol which, if supported, can be much more efficient
# than point lookups if it fits into memory.
for ix, (k, vs) in enumerate(_StateBackedIterable(
state_handler,
kv_iter_state_key,
coders.TupleCoder(
(key_coder, coders.IterableCoder(value_coder))))):
cache[k] = vs
if ix > StateBackedSideInputMap._BULK_READ_LIMIT:
if use_bulk_read:
if self._bulk_read is None:
with self._lock:
if self._bulk_read is None:
try:
# Attempt to bulk read the key-values over the iterable
# protocol which, if supported, can be much more efficient
# than point lookups if it fits into memory.
for ix, (k, vs) in enumerate(_StateBackedIterable(
state_handler,
kv_iter_state_key,
coders.TupleCoder(
(key_coder, coders.IterableCoder(value_coder))))):
cache[k] = vs
if ix > StateBackedSideInputMap._BULK_READ_LIMIT:
self._bulk_read = (
StateBackedSideInputMap._BULK_READ_PARTIALLY)
break
else:
# We reached the end of the iteration without breaking.
self._bulk_read = (
StateBackedSideInputMap._BULK_READ_PARTIALLY)
break
else:
# We reached the end of the iteration without breaking.
StateBackedSideInputMap._BULK_READ_FULLY)
except Exception:
_LOGGER.error(
"Iterable access of map side inputs unsupported.",
exc_info=True)
self._bulk_read = (
StateBackedSideInputMap._BULK_READ_FULLY)
except Exception:
_LOGGER.error(
"Iterable access of map side inputs unsupported.",
exc_info=True)
self._bulk_read = (
StateBackedSideInputMap._BULK_READ_PARTIALLY)

if (self._bulk_read == StateBackedSideInputMap._BULK_READ_FULLY):
return cache.get(key, [])
StateBackedSideInputMap._BULK_READ_PARTIALLY)

if (self._bulk_read == StateBackedSideInputMap._BULK_READ_FULLY):
return cache.get(key, [])

if key not in cache:
keyed_state_key = beam_fn_api_pb2.StateKey()
Expand Down

0 comments on commit 98a2690

Please sign in to comment.