From af31d35a0007648b0fc44d0bcb1d6f94cb88d706 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Thu, 13 Jun 2024 14:46:36 -0400 Subject: [PATCH] Implement ordered list state for FnApi. (#30317) * Add request and response proto messages for ordered list state. * Initial implementation of OrderedListState for fnApi. * Discard the use of the value coder in FakeBeamFnStateClient. * Fix the behavior of pre-existing iterators on local change of state If there are changes on a state after we obtain iterators from calling read() and readRange(), the behavior of these pre-existing iterators were incorrect in the previous implementation. The change introduced here will make sure that these iterators will still work as if no local change is made. * Support continuation token for ordered list get request in fake client * Add copyright notices to the new files * Add binding for ordered list state in fnapi state accessor * Clean up comments * Apply spotless and checkStyle to reformat * Add an encode-only coder for the use in the fake client. * Remove request and response messages for ordered list state get. * The range information is placed in the state key of ordered list * For consistency, we reuse the existing get request and response mesasages of other states like Bag, MultiMap, etc. * Remove request and response messages for ordered list state update * Reuse existing messages of clear and append. * Minor fixes based on feedbacks from reviewers * Replace String::size() > 0 with String::isEmpty() * Return this in readLater and readRangeLater instead of throwing an exception * Remove the added SupressWarnings("unchecked") * Apply spotless * Use data field in AppendRequest for ordered list state Previously, we used a repeated OrderedListEntry field in the AppendRequest particularly for ordered list state. For consistency, we now get rid of that and use the same data field as other states. * Apply spotless * Minor renaming of a variable * Create a new coder for TimestampedValue according to the notes in proto. * Address feedback from the reviewer - Add a test to cover the case when an add/clear operation happens while we are partway through an existing iterable. - Use clear() instead of clearRange(min, max) when we can. - Fix a typo. * Apply spotless * Add urn for ordered list state. * Add ordered list spec to ParDoTranslation. * Fix an edge case when async called after clear. Minor fix based on reviwer comments. * Refactor some variable names. Add a notes on the order of pendingAdds and pendingRemoves during async_close() --- .../model/pipeline/v1/beam_runner_api.proto | 6 +- .../util/construction/ParDoTranslation.java | 12 +- .../construction/ParDoTranslationTest.java | 4 + .../fn/harness/state/FnApiStateAccessor.java | 96 ++- .../harness/state/OrderedListUserState.java | 329 +++++++++ .../harness/state/StateFetchingIterators.java | 7 +- .../harness/state/FakeBeamFnStateClient.java | 205 +++++- .../state/OrderedListUserStateTest.java | 684 ++++++++++++++++++ 8 files changed, 1314 insertions(+), 29 deletions(-) create mode 100644 sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/OrderedListUserState.java create mode 100644 sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/OrderedListUserStateTest.java diff --git a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto index 4672c98fd073..422c2e1a5f7c 100644 --- a/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/org/apache/beam/model/pipeline/v1/beam_runner_api.proto @@ -479,7 +479,11 @@ message StandardUserStateTypes { // StateKey.MultimapKeysUserState or StateKey.MultimapUserState. MULTIMAP = 1 [(beam_urn) = "beam:user_state:multimap:v1"]; - // TODO(https://github.com/apache/beam/issues/20486): Add protocol to support OrderedListState + // Represents a user state specification that supports an ordered list. + // + // StateRequests performed on this user state must use + // StateKey.OrderedListUserState. + ORDERED_LIST = 2 [(beam_urn) = "beam:user_state:ordered_list:v1"]; } } diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/ParDoTranslation.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/ParDoTranslation.java index 23906c733ae3..c873a2a7b860 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/ParDoTranslation.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/util/construction/ParDoTranslation.java @@ -119,6 +119,8 @@ public class ParDoTranslation { public static final String BAG_USER_STATE = "beam:user_state:bag:v1"; /** Represents a user state specification that supports a multimap. */ public static final String MULTIMAP_USER_STATE = "beam:user_state:multimap:v1"; + /** Represents a user state specification that supports an ordered list. */ + public static final String ORDERED_LIST_USER_STATE = "beam:user_state:ordered_list:v1"; static { checkState( @@ -141,6 +143,8 @@ public class ParDoTranslation { BeamUrns.getUrn(StandardRequirements.Enum.REQUIRES_ON_WINDOW_EXPIRATION))); checkState(BAG_USER_STATE.equals(BeamUrns.getUrn(StandardUserStateTypes.Enum.BAG))); checkState(MULTIMAP_USER_STATE.equals(BeamUrns.getUrn(StandardUserStateTypes.Enum.MULTIMAP))); + checkState( + ORDERED_LIST_USER_STATE.equals(BeamUrns.getUrn(StandardUserStateTypes.Enum.ORDERED_LIST))); } /** The URN for an unknown Java {@link DoFn}. */ @@ -601,9 +605,7 @@ public RunnerApi.StateSpec dispatchOrderedList(Coder elementCoder) { .setOrderedListSpec( RunnerApi.OrderedListStateSpec.newBuilder() .setElementCoderId(registerCoderOrThrow(components, elementCoder))) - // TODO(https://github.com/apache/beam/issues/20486): Update with correct protocol - // once the protocol is defined and - // the SDK harness uses it. + .setProtocol(FunctionSpec.newBuilder().setUrn(ORDERED_LIST_USER_STATE)) .build(); } @@ -694,6 +696,10 @@ static StateSpec fromProto(RunnerApi.StateSpec stateSpec, RehydratedComponent case SET_SPEC: return StateSpecs.set(components.getCoder(stateSpec.getSetSpec().getElementCoderId())); + case ORDERED_LIST_SPEC: + return StateSpecs.orderedList( + components.getCoder(stateSpec.getOrderedListSpec().getElementCoderId())); + case SPEC_NOT_SET: default: throw new IllegalArgumentException( diff --git a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ParDoTranslationTest.java b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ParDoTranslationTest.java index fa308163ed3a..6ef836038196 100644 --- a/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ParDoTranslationTest.java +++ b/sdks/java/core/src/test/java/org/apache/beam/sdk/util/construction/ParDoTranslationTest.java @@ -245,6 +245,10 @@ public static Iterable stateSpecs() { { StateSpecs.map(StringUtf8Coder.of(), VarIntCoder.of()), FunctionSpec.newBuilder().setUrn(ParDoTranslation.MULTIMAP_USER_STATE).build() + }, + { + StateSpecs.orderedList(VarIntCoder.of()), + FunctionSpec.newBuilder().setUrn(ParDoTranslation.ORDERED_LIST_USER_STATE).build() } }); } diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java index cfc247039503..93f89301d158 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/FnApiStateAccessor.java @@ -43,6 +43,7 @@ import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; +import org.apache.beam.sdk.state.GroupingState; import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.MultimapState; import org.apache.beam.sdk.state.OrderedListState; @@ -63,12 +64,14 @@ import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.sdk.util.construction.BeamUrns; import org.apache.beam.sdk.values.PCollectionView; +import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Instant; /** Provides access to side inputs and state via a {@link BeamFnStateClient}. */ @SuppressWarnings({ @@ -600,8 +603,73 @@ public MultimapState bindMultimap( @Override public OrderedListState bindOrderedList( String id, StateSpec> spec, Coder elemCoder) { - throw new UnsupportedOperationException( - "TODO: Add support for a sorted-list state to the Fn API."); + return (OrderedListState) + stateKeyObjectCache.computeIfAbsent( + createOrderedListUserStateKey(id), + new Function() { + @Override + public Object apply(StateKey key) { + return new OrderedListState() { + private final OrderedListUserState impl = + createOrderedListUserState(key, elemCoder); + + @Override + public void clear() { + impl.clear(); + } + + @Override + public void add(TimestampedValue value) { + impl.add(value); + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public @Nullable Boolean read() { + return !impl.read().iterator().hasNext(); + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Nullable + @Override + public Iterable> read() { + return readRange( + Instant.ofEpochMilli(Long.MIN_VALUE), Instant.ofEpochMilli(Long.MAX_VALUE)); + } + + @Override + public GroupingState, Iterable>> + readLater() { + return this; + } + + @Override + public Iterable> readRange( + Instant minTimestamp, Instant limitTimestamp) { + return impl.readRange(minTimestamp, limitTimestamp); + } + + @Override + public void clearRange(Instant minTimestamp, Instant limitTimestamp) { + impl.clearRange(minTimestamp, limitTimestamp); + } + + @Override + public OrderedListState readRangeLater( + Instant minTimestamp, Instant limitTimestamp) { + return this; + } + }; + } + }); } @Override @@ -849,6 +917,30 @@ private StateKey createMultimapKeysUserStateKey(String stateId) { return builder.build(); } + private OrderedListUserState createOrderedListUserState( + StateKey stateKey, Coder valueCoder) { + OrderedListUserState rval = + new OrderedListUserState<>( + getCacheFor(stateKey), + beamFnStateClient, + processBundleInstructionId.get(), + stateKey, + valueCoder); + stateFinalizers.add(rval::asyncClose); + return rval; + } + + private StateKey createOrderedListUserStateKey(String stateId) { + StateKey.Builder builder = StateKey.newBuilder(); + builder + .getOrderedListUserStateBuilder() + .setWindow(encodedCurrentWindowSupplier.get()) + .setKey(encodedCurrentKeySupplier.get()) + .setTransformId(ptransformId) + .setUserStateId(stateId); + return builder.build(); + } + public void finalizeState() { // Persist all dirty state cells try { diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/OrderedListUserState.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/OrderedListUserState.java new file mode 100644 index 000000000000..47b5057880b9 --- /dev/null +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/OrderedListUserState.java @@ -0,0 +1,329 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.state; + +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkArgument; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkNotNull; +import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions.checkState; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.List; +import java.util.Map.Entry; +import java.util.NavigableMap; +import java.util.concurrent.CompletableFuture; +import org.apache.beam.fn.harness.Cache; +import org.apache.beam.fn.harness.Caches; +import org.apache.beam.fn.harness.state.StateFetchingIterators.CachingStateIterable; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearRequest; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.LengthPrefixCoder; +import org.apache.beam.sdk.coders.StructuredCoder; +import org.apache.beam.sdk.coders.VarLongCoder; +import org.apache.beam.sdk.fn.stream.PrefetchableIterable; +import org.apache.beam.sdk.fn.stream.PrefetchableIterables; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.sdk.values.TypeDescriptor; +import org.apache.beam.sdk.values.TypeParameter; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.BoundType; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Range; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.TreeRangeSet; +import org.joda.time.Instant; + +/** + * An implementation of an ordered list user state that utilizes the Beam Fn State API to fetch, + * clear and persist values. + * + *

Calling {@link #asyncClose()} schedules any required persistence changes. This object should + * no longer be used after it is closed. + * + *

TODO: Move to an async persist model where persistence is signalled based upon cache memory + * pressure and its need to flush. + */ +public class OrderedListUserState { + private final BeamFnStateClient beamFnStateClient; + private final StateRequest requestTemplate; + private final TimestampedValueCoder timestampedValueCoder; + // Pending updates to persistent storage + // (a) The elements in pendingAdds are the ones that should be added to the persistent storage + // during the next async_close(). It doesn't include the ones that are removed by + // clear_range() or clear() after the last add. + // (b) The elements in pendingRemoves are the sort keys that should be removed from the persistent + // storage. + // (c) When syncing local copy with persistent storage, pendingRemoves are performed first and + // then pendingAdds. Switching this order may result in wrong results, because a value added + // later could be removed from an earlier clear. + private NavigableMap> pendingAdds = Maps.newTreeMap(); + private TreeRangeSet pendingRemoves = TreeRangeSet.create(); + + private boolean isCleared = false; + private boolean isClosed = false; + + public static class TimestampedValueCoder extends StructuredCoder> { + + private final Coder valueCoder; + + // Internally, a TimestampedValue is encoded with a KvCoder, where the key is encoded with + // a VarLongCoder and the value is encoded with a LengthPrefixCoder. + // Refer to the comment in StateAppendRequest + // (org/apache/beam/model/fn_execution/v1/beam_fn_api.proto) for more detail. + private final KvCoder internalKvCoder; + + public static OrderedListUserState.TimestampedValueCoder of(Coder valueCoder) { + return new OrderedListUserState.TimestampedValueCoder<>(valueCoder); + } + + @Override + public Object structuralValue(TimestampedValue value) { + Object structuralValue = valueCoder.structuralValue(value.getValue()); + return TimestampedValue.of(structuralValue, value.getTimestamp()); + } + + @SuppressWarnings("unchecked") + TimestampedValueCoder(Coder valueCoder) { + this.valueCoder = checkNotNull(valueCoder); + this.internalKvCoder = KvCoder.of(VarLongCoder.of(), LengthPrefixCoder.of(valueCoder)); + } + + @Override + public void encode(TimestampedValue timestampedValue, OutputStream outStream) + throws IOException { + internalKvCoder.encode( + KV.of(timestampedValue.getTimestamp().getMillis(), timestampedValue.getValue()), + outStream); + } + + @Override + public TimestampedValue decode(InputStream inStream) throws IOException { + KV kv = internalKvCoder.decode(inStream); + return TimestampedValue.of(kv.getValue(), Instant.ofEpochMilli(kv.getKey())); + } + + @Override + public void verifyDeterministic() throws NonDeterministicException { + verifyDeterministic( + this, "TimestampedValueCoder requires a deterministic valueCoder", valueCoder); + } + + @Override + public List> getCoderArguments() { + return Arrays.>asList(valueCoder); + } + + public Coder getValueCoder() { + return valueCoder; + } + + @Override + public TypeDescriptor> getEncodedTypeDescriptor() { + return new TypeDescriptor>() {}.where( + new TypeParameter() {}, valueCoder.getEncodedTypeDescriptor()); + } + + @Override + public List> getComponents() { + return Collections.singletonList(valueCoder); + } + } + + public OrderedListUserState( + Cache cache, + BeamFnStateClient beamFnStateClient, + String instructionId, + StateKey stateKey, + Coder valueCoder) { + checkArgument( + stateKey.hasOrderedListUserState(), + "Expected OrderedListUserState StateKey but received %s.", + stateKey); + this.beamFnStateClient = beamFnStateClient; + this.timestampedValueCoder = TimestampedValueCoder.of(valueCoder); + this.requestTemplate = + StateRequest.newBuilder().setInstructionId(instructionId).setStateKey(stateKey).build(); + } + + public void add(TimestampedValue value) { + checkState( + !isClosed, + "OrderedList user state is no longer usable because it is closed for %s", + requestTemplate.getStateKey()); + Instant timestamp = value.getTimestamp(); + pendingAdds.putIfAbsent(timestamp, new ArrayList<>()); + pendingAdds.get(timestamp).add(value.getValue()); + } + + public Iterable> readRange(Instant minTimestamp, Instant limitTimestamp) { + checkState( + !isClosed, + "OrderedList user state is no longer usable because it is closed for %s", + requestTemplate.getStateKey()); + + // Store pendingAdds whose sort key is in the query range and values are truncated by the + // current size. The values (collections) of pendingAdds are kept, so that they will still be + // accessible in pre-existing iterables even after: + // (1) a sort key is added to or removed from pendingAdds, or + // (2) a new value is added to an existing sort key + ArrayList>> pendingAddsInRange = new ArrayList<>(); + for (Entry> kv : + pendingAdds.subMap(minTimestamp, limitTimestamp).entrySet()) { + pendingAddsInRange.add( + PrefetchableIterables.limit( + Iterables.transform(kv.getValue(), (v) -> TimestampedValue.of(v, kv.getKey())), + kv.getValue().size())); + } + Iterable> valuesInRange = Iterables.concat(pendingAddsInRange); + + if (!isCleared) { + StateRequest.Builder getRequestBuilder = this.requestTemplate.toBuilder(); + getRequestBuilder + .getStateKeyBuilder() + .getOrderedListUserStateBuilder() + .getRangeBuilder() + .setStart(minTimestamp.getMillis()) + .setEnd(limitTimestamp.getMillis()); + + // TODO: consider use cache here + CachingStateIterable> persistentValues = + StateFetchingIterators.readAllAndDecodeStartingFrom( + Caches.noop(), + this.beamFnStateClient, + getRequestBuilder.build(), + this.timestampedValueCoder); + + // Make a snapshot of the current pendingRemoves and use them to filter persistent values. + // The values of pendingRemoves are copied, so that they will still be accessible in + // pre-existing iterables even after a sort key is removed. + TreeRangeSet pendingRemovesSnapshot = TreeRangeSet.create(pendingRemoves); + Iterable> persistentValuesAfterRemoval = + Iterables.filter( + persistentValues, v -> !pendingRemovesSnapshot.contains(v.getTimestamp())); + + return Iterables.mergeSorted( + ImmutableList.of(persistentValuesAfterRemoval, valuesInRange), + Comparator.comparing(TimestampedValue::getTimestamp)); + } + + return valuesInRange; + } + + public Iterable> read() { + checkState( + !isClosed, + "OrderedList user state is no longer usable because it is closed for %s", + requestTemplate.getStateKey()); + + return readRange(Instant.ofEpochMilli(Long.MIN_VALUE), Instant.ofEpochMilli(Long.MAX_VALUE)); + } + + public void clearRange(Instant minTimestamp, Instant limitTimestamp) { + checkState( + !isClosed, + "OrderedList user state is no longer usable because it is closed for %s", + requestTemplate.getStateKey()); + + // Remove items (in a collection) in the specific range from pendingAdds. + // The old values of the removed sub map are kept, so that they will still be accessible in + // pre-existing iterables even after the sort key is cleared. + pendingAdds.subMap(minTimestamp, limitTimestamp).clear(); + if (!isCleared) { + pendingRemoves.add( + Range.range(minTimestamp, BoundType.CLOSED, limitTimestamp, BoundType.OPEN)); + } + } + + public void clear() { + checkState( + !isClosed, + "OrderedList user state is no longer usable because it is closed for %s", + requestTemplate.getStateKey()); + isCleared = true; + // Create a new object for pendingRemoves and clear the mappings in pendingAdds. + // The entire tree range set of pendingRemoves and the old values in the pendingAdds are kept, + // so that they will still be accessible in pre-existing iterables even after the state is + // cleared. + pendingRemoves = TreeRangeSet.create(); + pendingRemoves.add( + Range.range( + Instant.ofEpochMilli(Long.MIN_VALUE), + BoundType.CLOSED, + Instant.ofEpochMilli(Long.MAX_VALUE), + BoundType.OPEN)); + pendingAdds.clear(); + } + + public void asyncClose() throws Exception { + isClosed = true; + + if (!pendingRemoves.isEmpty()) { + for (Range r : pendingRemoves.asRanges()) { + StateRequest.Builder stateRequest = this.requestTemplate.toBuilder(); + stateRequest.setClear(StateClearRequest.newBuilder().build()); + stateRequest + .getStateKeyBuilder() + .getOrderedListUserStateBuilder() + .getRangeBuilder() + .setStart(r.lowerEndpoint().getMillis()) + .setEnd(r.upperEndpoint().getMillis()); + + CompletableFuture response = beamFnStateClient.handle(stateRequest); + if (!response.get().getError().isEmpty()) { + throw new IllegalStateException(response.get().getError()); + } + } + pendingRemoves.clear(); + } + + if (!pendingAdds.isEmpty()) { + ByteStringOutputStream outStream = new ByteStringOutputStream(); + + for (Entry> entry : pendingAdds.entrySet()) { + for (T v : entry.getValue()) { + TimestampedValue tv = TimestampedValue.of(v, entry.getKey()); + try { + timestampedValueCoder.encode(tv, outStream); + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } + } + StateRequest.Builder stateRequest = this.requestTemplate.toBuilder(); + stateRequest.getAppendBuilder().setData(outStream.toByteString()); + + CompletableFuture response = beamFnStateClient.handle(stateRequest); + if (!response.get().getError().isEmpty()) { + throw new IllegalStateException(response.get().getError()); + } + pendingAdds.clear(); + } + } +} diff --git a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java index 6f8622d736fe..3b9fccfa2a5e 100644 --- a/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java +++ b/sdks/java/harness/src/main/java/org/apache/beam/fn/harness/state/StateFetchingIterators.java @@ -605,13 +605,16 @@ public ByteString next() { } prefetchedResponse = null; + ByteString tokenFromResponse = stateResponse.getGet().getContinuationToken(); + // If the continuation token is empty, that means we have reached EOF. - if (ByteString.EMPTY.equals(stateResponse.getGet().getContinuationToken())) { + if (ByteString.EMPTY.equals(tokenFromResponse)) { continuationToken = null; } else { - continuationToken = stateResponse.getGet().getContinuationToken(); + continuationToken = tokenFromResponse; prefetch(); } + return stateResponse.getGet().getData(); } } diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java index 468b4f6b4251..34dec41771b0 100644 --- a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/FakeBeamFnStateClient.java @@ -17,17 +17,25 @@ */ package org.apache.beam.fn.harness.state; +import static org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.RequestCase.GET; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import java.io.IOException; +import java.io.InputStream; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.NavigableSet; +import java.util.NoSuchElementException; +import java.util.TreeSet; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import org.apache.beam.fn.harness.state.OrderedListUserState.TimestampedValueCoder; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.OrderedListRange; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateAppendResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateClearResponse; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateGetResponse; @@ -36,9 +44,14 @@ import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateRequest.RequestCase; import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateResponse; +import org.apache.beam.sdk.coders.ByteArrayCoder; import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.VarIntCoder; +import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.util.ByteStringOutputStream; import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Maps; @@ -47,6 +60,7 @@ public class FakeBeamFnStateClient implements BeamFnStateClient { private static final int DEFAULT_CHUNK_SIZE = 6; private final Map> data; private int currentId; + private final Map> orderedListSortKeysFromStateKey; public FakeBeamFnStateClient(Coder valueCoder, Map> initialData) { this(valueCoder, initialData, DEFAULT_CHUNK_SIZE); @@ -97,6 +111,27 @@ public FakeBeamFnStateClient(Map, List>> initialData, i } return chunks; })); + + List orderedListStateKeys = + initialData.keySet().stream() + .filter((k) -> k.getTypeCase() == TypeCase.ORDERED_LIST_USER_STATE) + .collect(Collectors.toList()); + + this.orderedListSortKeysFromStateKey = new HashMap<>(); + for (StateKey key : orderedListStateKeys) { + long sortKey = key.getOrderedListUserState().getRange().getStart(); + + StateKey.Builder keyBuilder = key.toBuilder(); + + // clear the range in the state key before using it as a key to store, because ordered list + // with different ranges would be mapped to the same set of sort keys. + keyBuilder.getOrderedListUserStateBuilder().clearRange(); + + this.orderedListSortKeysFromStateKey + .computeIfAbsent(keyBuilder.build(), (unused) -> new TreeSet<>()) + .add(sortKey); + } + this.data = new ConcurrentHashMap<>( Maps.filterValues(encodedData, byteStrings -> !byteStrings.isEmpty())); @@ -134,7 +169,7 @@ public CompletableFuture handle(StateRequest.Builder requestBuild assertNotEquals(TypeCase.TYPE_NOT_SET, key.getTypeCase()); // multimap side input and runner based state keys only support get requests if (key.getTypeCase() == TypeCase.MULTIMAP_SIDE_INPUT || key.getTypeCase() == TypeCase.RUNNER) { - assertEquals(RequestCase.GET, request.getRequestCase()); + assertEquals(GET, request.getRequestCase()); } if (key.getTypeCase() == TypeCase.MULTIMAP_KEYS_VALUES_SIDE_INPUT && !data.containsKey(key)) { // Allow testing this not being supported rather than blindly returning the empty list. @@ -143,34 +178,162 @@ public CompletableFuture handle(StateRequest.Builder requestBuild switch (request.getRequestCase()) { case GET: - List byteStrings = - data.getOrDefault(request.getStateKey(), Collections.singletonList(ByteString.EMPTY)); - int block = 0; - if (request.getGet().getContinuationToken().size() > 0) { - block = Integer.parseInt(request.getGet().getContinuationToken().toStringUtf8()); - } - ByteString returnBlock = byteStrings.get(block); - ByteString continuationToken = ByteString.EMPTY; - if (byteStrings.size() > block + 1) { - continuationToken = ByteString.copyFromUtf8(Integer.toString(block + 1)); + if (key.getTypeCase() == TypeCase.ORDERED_LIST_USER_STATE) { + long start = key.getOrderedListUserState().getRange().getStart(); + long end = key.getOrderedListUserState().getRange().getEnd(); + + KvCoder coder = KvCoder.of(VarLongCoder.of(), VarIntCoder.of()); + long sortKey = start; + int index = 0; + if (!request.getGet().getContinuationToken().isEmpty()) { + try { + // The continuation format here is the sort key (long) followed by an index (int) + KV cursor = + coder.decode(request.getGet().getContinuationToken().newInput()); + sortKey = cursor.getKey(); + index = cursor.getValue(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + ByteString continuationToken; + ByteString returnBlock = ByteString.EMPTY; + try { + if (sortKey < start || sortKey >= end) { + throw new IndexOutOfBoundsException("sort key out of range"); + } + + StateKey.Builder stateKeyWithoutRange = request.getStateKey().toBuilder(); + stateKeyWithoutRange.getOrderedListUserStateBuilder().clearRange(); + NavigableSet subset = + orderedListSortKeysFromStateKey + .getOrDefault(stateKeyWithoutRange.build(), new TreeSet<>()) + .subSet(sortKey, true, end, false); + + // get the effective sort key currently, can throw NoSuchElementException + Long nextSortKey = subset.first(); + + StateKey.Builder keyBuilder = request.getStateKey().toBuilder(); + keyBuilder + .getOrderedListUserStateBuilder() + .getRangeBuilder() + .setStart(nextSortKey) + .setEnd(nextSortKey + 1); + List byteStrings = + data.getOrDefault(keyBuilder.build(), Collections.singletonList(ByteString.EMPTY)); + + // get the block specified in continuation token, can throw IndexOutOfBoundsException + returnBlock = byteStrings.get(index); + + if (byteStrings.size() > index + 1) { + // more blocks from this sort key + index += 1; + } else { + // finish navigating the current sort key and need to find the next one, + // can throw NoSuchElementException + nextSortKey = subset.tailSet(nextSortKey, false).first(); + index = 0; + } + + ByteStringOutputStream outputStream = new ByteStringOutputStream(); + try { + KV cursor = KV.of(nextSortKey, index); + coder.encode(cursor, outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + continuationToken = outputStream.toByteString(); + } catch (NoSuchElementException | IndexOutOfBoundsException e) { + continuationToken = ByteString.EMPTY; + } + response = + StateResponse.newBuilder() + .setGet( + StateGetResponse.newBuilder() + .setData(returnBlock) + .setContinuationToken(continuationToken)); + } else { + List byteStrings = + data.getOrDefault(request.getStateKey(), Collections.singletonList(ByteString.EMPTY)); + int block = 0; + if (!request.getGet().getContinuationToken().isEmpty()) { + block = Integer.parseInt(request.getGet().getContinuationToken().toStringUtf8()); + } + ByteString returnBlock = byteStrings.get(block); + ByteString continuationToken = ByteString.EMPTY; + if (byteStrings.size() > block + 1) { + continuationToken = ByteString.copyFromUtf8(Integer.toString(block + 1)); + } + response = + StateResponse.newBuilder() + .setGet( + StateGetResponse.newBuilder() + .setData(returnBlock) + .setContinuationToken(continuationToken)); } - response = - StateResponse.newBuilder() - .setGet( - StateGetResponse.newBuilder() - .setData(returnBlock) - .setContinuationToken(continuationToken)); break; case CLEAR: - data.remove(request.getStateKey()); + if (key.getTypeCase() == TypeCase.ORDERED_LIST_USER_STATE) { + OrderedListRange r = request.getStateKey().getOrderedListUserState().getRange(); + StateKey.Builder stateKeyWithoutRange = request.getStateKey().toBuilder(); + stateKeyWithoutRange.getOrderedListUserStateBuilder().clearRange(); + + List keysToRemove = + new ArrayList<>( + orderedListSortKeysFromStateKey + .getOrDefault(stateKeyWithoutRange.build(), new TreeSet<>()) + .subSet(r.getStart(), true, r.getEnd(), false)); + for (Long l : keysToRemove) { + StateKey.Builder keyBuilder = request.getStateKey().toBuilder(); + keyBuilder.getOrderedListUserStateBuilder().getRangeBuilder().setStart(l).setEnd(l + 1); + data.remove(keyBuilder.build()); + orderedListSortKeysFromStateKey.get(stateKeyWithoutRange.build()).remove(l); + } + } else { + data.remove(request.getStateKey()); + } response = StateResponse.newBuilder().setClear(StateClearResponse.getDefaultInstance()); break; case APPEND: - List previousValue = - data.computeIfAbsent(request.getStateKey(), (unused) -> new ArrayList<>()); - previousValue.add(request.getAppend().getData()); + if (key.getTypeCase() == TypeCase.ORDERED_LIST_USER_STATE) { + InputStream inStream = request.getAppend().getData().newInput(); + TimestampedValueCoder coder = TimestampedValueCoder.of(ByteArrayCoder.of()); + try { + while (inStream.available() > 0) { + TimestampedValue tv = coder.decode(inStream); + ByteStringOutputStream outStream = new ByteStringOutputStream(); + coder.encode(tv, outStream); + ByteString output = outStream.toByteString(); + + StateKey.Builder keyBuilder = request.getStateKey().toBuilder(); + long sortKey = tv.getTimestamp().getMillis(); + keyBuilder + .getOrderedListUserStateBuilder() + .getRangeBuilder() + .setStart(sortKey) + .setEnd(sortKey + 1); + + List previousValues = + data.computeIfAbsent(keyBuilder.build(), (unused) -> new ArrayList<>()); + previousValues.add(output); + + StateKey.Builder stateKeyWithoutRange = request.getStateKey().toBuilder(); + stateKeyWithoutRange.getOrderedListUserStateBuilder().clearRange(); + orderedListSortKeysFromStateKey + .computeIfAbsent(stateKeyWithoutRange.build(), (unused) -> new TreeSet<>()) + .add(sortKey); + } + } catch (IOException ex) { + throw new RuntimeException(ex); + } + } else { + List previousValue = + data.computeIfAbsent(request.getStateKey(), (unused) -> new ArrayList<>()); + previousValue.add(request.getAppend().getData()); + } response = StateResponse.newBuilder().setAppend(StateAppendResponse.getDefaultInstance()); break; diff --git a/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/OrderedListUserStateTest.java b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/OrderedListUserStateTest.java new file mode 100644 index 000000000000..efade508843b --- /dev/null +++ b/sdks/java/harness/src/test/java/org/apache/beam/fn/harness/state/OrderedListUserStateTest.java @@ -0,0 +1,684 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.fn.harness.state; + +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.util.Collections; +import java.util.Iterator; +import org.apache.beam.fn.harness.Caches; +import org.apache.beam.fn.harness.state.OrderedListUserState.TimestampedValueCoder; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.OrderedListRange; +import org.apache.beam.model.fnexecution.v1.BeamFnApi.StateKey; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.StringUtf8Coder; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.TimestampedValue; +import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Iterables; +import org.joda.time.Instant; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class OrderedListUserStateTest { + private static final TimestampedValue A1 = + TimestampedValue.of("A1", Instant.ofEpochMilli(1)); + private static final TimestampedValue B1 = + TimestampedValue.of("B1", Instant.ofEpochMilli(1)); + private static final TimestampedValue C1 = + TimestampedValue.of("C1", Instant.ofEpochMilli(1)); + private static final TimestampedValue A2 = + TimestampedValue.of("A2", Instant.ofEpochMilli(2)); + private static final TimestampedValue B2 = + TimestampedValue.of("B2", Instant.ofEpochMilli(2)); + private static final TimestampedValue A3 = + TimestampedValue.of("A3", Instant.ofEpochMilli(3)); + private static final TimestampedValue A4 = + TimestampedValue.of("A4", Instant.ofEpochMilli(4)); + + private final String pTransformId = "pTransformId"; + private final String stateId = "stateId"; + private final String encodedWindow = "encodedWindow"; + private final Coder> timestampedValueCoder = + TimestampedValueCoder.of(StringUtf8Coder.of()); + + @Test + public void testNoPersistedValues() throws Exception { + FakeBeamFnStateClient fakeClient = new FakeBeamFnStateClient(Collections.emptyMap()); + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + assertThat(userState.read(), is(emptyIterable())); + } + + @Test + public void testRead() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of(createOrderedListStateKey("A", 1), asList(A1, B1))); + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + assertArrayEquals( + asList(A1, B1).toArray(), Iterables.toArray(userState.read(), TimestampedValue.class)); + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.read()); + } + + @Test + public void testReadRange() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), asList(A1, B1), + createOrderedListStateKey("A", 4), Collections.singletonList(A4), + createOrderedListStateKey("A", 2), Collections.singletonList(A2))); + + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + Iterable> stateBeforeB2 = + userState.readRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(4)); + assertArrayEquals( + Collections.singletonList(A2).toArray(), + Iterables.toArray(stateBeforeB2, TimestampedValue.class)); + + // Add a new value to an existing sort key + userState.add(B2); + assertArrayEquals( + Collections.singletonList(A2).toArray(), + Iterables.toArray(stateBeforeB2, TimestampedValue.class)); + assertArrayEquals( + asList(A2, B2).toArray(), + Iterables.toArray( + userState.readRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(4)), + TimestampedValue.class)); + + // Add a new value to a new sort key + userState.add(A3); + assertArrayEquals( + Collections.singletonList(A2).toArray(), + Iterables.toArray(stateBeforeB2, TimestampedValue.class)); + assertArrayEquals( + asList(A2, B2, A3).toArray(), + Iterables.toArray( + userState.readRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(4)), + TimestampedValue.class)); + + userState.asyncClose(); + assertThrows( + IllegalStateException.class, + () -> userState.readRange(Instant.ofEpochMilli(1), Instant.ofEpochMilli(2))); + } + + @Test + public void testAdd() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + Collections.singletonList(A1), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4), + createOrderedListStateKey("A", 2), + asList(A2, B2))); + + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + // add to an existing timestamp + userState.add(B1); + assertArrayEquals( + asList(A1, B1, A2, B2, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + + // add to a nonexistent timestamp + userState.add(A3); + assertArrayEquals( + asList(A1, B1, A2, B2, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + + // add a duplicated value + userState.add(B1); + assertArrayEquals( + asList(A1, B1, B1, A2, B2, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.add(A1)); + } + + @Test + public void testClearRange() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + asList(A1, B1), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4), + createOrderedListStateKey("A", 2), + asList(A2, B2), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3))); + + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + Iterable> initStateFrom2To3 = + userState.readRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(4)); + + // clear range below the current timestamp range + userState.clearRange(Instant.ofEpochMilli(-1), Instant.ofEpochMilli(0)); + assertArrayEquals( + asList(A2, B2, A3).toArray(), Iterables.toArray(initStateFrom2To3, TimestampedValue.class)); + assertArrayEquals( + asList(A1, B1, A2, B2, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + + // clear range above the current timestamp range + userState.clearRange(Instant.ofEpochMilli(5), Instant.ofEpochMilli(10)); + assertArrayEquals( + asList(A2, B2, A3).toArray(), Iterables.toArray(initStateFrom2To3, TimestampedValue.class)); + assertArrayEquals( + asList(A1, B1, A2, B2, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + + // clear range that falls inside the current timestamp range + userState.clearRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(4)); + assertArrayEquals( + asList(A2, B2, A3).toArray(), Iterables.toArray(initStateFrom2To3, TimestampedValue.class)); + assertArrayEquals( + asList(A1, B1, A4).toArray(), Iterables.toArray(userState.read(), TimestampedValue.class)); + + // clear range that partially covers the current timestamp range + userState.clearRange(Instant.ofEpochMilli(3), Instant.ofEpochMilli(5)); + assertArrayEquals( + asList(A2, B2, A3).toArray(), Iterables.toArray(initStateFrom2To3, TimestampedValue.class)); + assertArrayEquals( + asList(A1, B1).toArray(), Iterables.toArray(userState.read(), TimestampedValue.class)); + + // clear range that fully covers the current timestamp range + userState.clearRange(Instant.ofEpochMilli(-1), Instant.ofEpochMilli(10)); + assertArrayEquals( + asList(A2, B2, A3).toArray(), Iterables.toArray(initStateFrom2To3, TimestampedValue.class)); + assertThat(userState.read(), is(emptyIterable())); + + userState.asyncClose(); + assertThrows( + IllegalStateException.class, + () -> userState.clearRange(Instant.ofEpochMilli(1), Instant.ofEpochMilli(2))); + } + + @Test + public void testClear() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + asList(A1, B1), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4), + createOrderedListStateKey("A", 2), + asList(A2, B2), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3))); + + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + Iterable> stateBeforeClear = userState.read(); + userState.clear(); + assertArrayEquals( + asList(A1, B1, A2, B2, A3, A4).toArray(), + Iterables.toArray(stateBeforeClear, TimestampedValue.class)); + assertThat(userState.read(), is(emptyIterable())); + + userState.asyncClose(); + assertThrows(IllegalStateException.class, () -> userState.clear()); + } + + @Test + public void testAddAndClearRange() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + Collections.singletonList(A1), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4))); + + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + // add to a non-existing timestamp, clear, and then add + userState.add(A2); + Iterable> stateBeforeFirstClearRange = userState.read(); + userState.clearRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(3)); + assertArrayEquals( + asList(A1, A2, A3, A4).toArray(), + Iterables.toArray(stateBeforeFirstClearRange, TimestampedValue.class)); + assertArrayEquals( + asList(A1, A3, A4).toArray(), Iterables.toArray(userState.read(), TimestampedValue.class)); + userState.add(B2); + assertArrayEquals( + asList(A1, A2, A3, A4).toArray(), + Iterables.toArray(stateBeforeFirstClearRange, TimestampedValue.class)); + assertArrayEquals( + asList(A1, B2, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + + // add to an existing timestamp, clear, and then add + userState.add(B1); + userState.clearRange(Instant.ofEpochMilli(1), Instant.ofEpochMilli(2)); + assertArrayEquals( + asList(A1, A2, A3, A4).toArray(), + Iterables.toArray(stateBeforeFirstClearRange, TimestampedValue.class)); + assertArrayEquals( + asList(B2, A3, A4).toArray(), Iterables.toArray(userState.read(), TimestampedValue.class)); + userState.add(B1); + assertArrayEquals( + asList(A1, A2, A3, A4).toArray(), + Iterables.toArray(stateBeforeFirstClearRange, TimestampedValue.class)); + assertArrayEquals( + asList(B1, B2, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + + // add a duplicated value, clear, and then add + userState.add(A3); + userState.clearRange(Instant.ofEpochMilli(3), Instant.ofEpochMilli(4)); + assertArrayEquals( + asList(A1, A2, A3, A4).toArray(), + Iterables.toArray(stateBeforeFirstClearRange, TimestampedValue.class)); + assertArrayEquals( + asList(B1, B2, A4).toArray(), Iterables.toArray(userState.read(), TimestampedValue.class)); + userState.add(A3); + assertArrayEquals( + asList(A1, A2, A3, A4).toArray(), + Iterables.toArray(stateBeforeFirstClearRange, TimestampedValue.class)); + assertArrayEquals( + asList(B1, B2, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + } + + @Test + public void testAddAndClearRangeAfterClear() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + Collections.singletonList(A1), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4))); + + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + userState.clear(); + userState.clearRange(Instant.ofEpochMilli(0), Instant.ofEpochMilli(5)); + assertThat(userState.read(), is(emptyIterable())); + + userState.add(A1); + assertArrayEquals( + Collections.singletonList(A1).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + + userState.add(A2); + userState.add(A3); + assertArrayEquals( + asList(A1, A2, A3).toArray(), Iterables.toArray(userState.read(), TimestampedValue.class)); + + userState.clearRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(3)); + assertArrayEquals( + asList(A1, A3).toArray(), Iterables.toArray(userState.read(), TimestampedValue.class)); + } + + @Test + public void testNoopAsyncCloseAndRead() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + Collections.singletonList(A1), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4))); + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + userState.asyncClose(); + } + + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + assertArrayEquals( + asList(A1, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + } + } + + @Test + public void testAddAsyncCloseAndRead() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + Collections.singletonList(A1), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4))); + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + userState.add(B1); + userState.add(A2); + userState.asyncClose(); + } + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + assertArrayEquals( + asList(A1, B1, A2, A3, A4).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + } + } + + @Test + public void testClearRangeAsyncCloseAndRead() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + Collections.singletonList(A1), + createOrderedListStateKey("A", 2), + Collections.singletonList(A2), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4))); + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + userState.clearRange(Instant.ofEpochMilli(1), Instant.ofEpochMilli(3)); + userState.clearRange(Instant.ofEpochMilli(4), Instant.ofEpochMilli(5)); + userState.asyncClose(); + } + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + assertArrayEquals( + Collections.singletonList(A3).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + } + } + + @Test + public void testAddClearRangeAsyncCloseAndRead() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + Collections.singletonList(A1), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4))); + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + userState.add(B1); + userState.add(A2); + userState.add(A3); + userState.clearRange(Instant.ofEpochMilli(1), Instant.ofEpochMilli(3)); + userState.clearRange(Instant.ofEpochMilli(4), Instant.ofEpochMilli(5)); + userState.asyncClose(); + } + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + assertArrayEquals( + Collections.singletonList(A3).toArray(), + Iterables.toArray(userState.read(), TimestampedValue.class)); + } + } + + @Test + public void testClearAsyncCloseAndRead() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + Collections.singletonList(A1), + createOrderedListStateKey("A", 2), + Collections.singletonList(A2), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4))); + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + userState.clear(); + userState.asyncClose(); + } + { + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + assertThat(userState.read(), is(emptyIterable())); + } + } + + @Test + public void testOperationsDuringNavigatingIterable() throws Exception { + FakeBeamFnStateClient fakeClient = + new FakeBeamFnStateClient( + timestampedValueCoder, + ImmutableMap.of( + createOrderedListStateKey("A", 1), + asList(A1, B1), + createOrderedListStateKey("A", 2), + asList(A2, B2), + createOrderedListStateKey("A", 3), + Collections.singletonList(A3), + createOrderedListStateKey("A", 4), + Collections.singletonList(A4))); + + OrderedListUserState userState = + new OrderedListUserState<>( + Caches.noop(), + fakeClient, + "instructionId", + createOrderedListStateKey("A"), + StringUtf8Coder.of()); + + Iterator> iter = userState.read().iterator(); + assertEquals(iter.next(), A1); + + // Adding a C1 locally, but it should not be returned after B1 in the existing iterable. + userState.add(C1); + assertEquals(iter.next(), B1); + assertEquals(iter.next(), A2); + + // Clearing range [2,4) locally, but B2 and A3 should still be returned. + userState.clearRange(Instant.ofEpochMilli(2), Instant.ofEpochMilli(4)); + assertEquals(iter.next(), B2); + assertEquals(iter.next(), A3); + + // Clearing all ranges locally, but A4 should still be returned. + userState.clear(); + assertEquals(iter.next(), A4); + } + + private ByteString encode(String... values) throws IOException { + ByteStringOutputStream out = new ByteStringOutputStream(); + for (String value : values) { + StringUtf8Coder.of().encode(value, out); + } + return out.toByteString(); + } + + private StateKey createOrderedListStateKey(String key) throws IOException { + return StateKey.newBuilder() + .setOrderedListUserState( + StateKey.OrderedListUserState.newBuilder() + .setWindow(encode(encodedWindow)) + .setTransformId(pTransformId) + .setUserStateId(stateId) + .setKey(encode(key))) + .build(); + } + + private StateKey createOrderedListStateKey(String key, long sortKey) throws IOException { + return StateKey.newBuilder() + .setOrderedListUserState( + StateKey.OrderedListUserState.newBuilder() + .setWindow(encode(encodedWindow)) + .setTransformId(pTransformId) + .setUserStateId(stateId) + .setKey(encode(key)) + .setRange( + OrderedListRange.newBuilder().setStart(sortKey).setEnd(sortKey + 1).build())) + .build(); + } +}