Skip to content

Commit

Permalink
Add spark mapstate (#31669)
Browse files Browse the repository at this point in the history
* add isEmpty test in testMap

* add map state in spark runner

* update comment on SparkStateInternalsTest

* modify isEmpty test to assertFalse / assertTrue
  • Loading branch information
twosom authored Jun 28, 2024
1 parent 6ec1fb2 commit 3588d19
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -386,10 +386,16 @@ public void testMap() throws Exception {
value.entries().readLater().read(),
containsInAnyOrder(MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5)));

// isEmpty
assertFalse(value.isEmpty().read());

// clear
value.clear();
assertThat(value.entries().read(), Matchers.emptyIterable());
assertThat(underTest.state(NAMESPACE_1, STRING_MAP_ADDR), equalTo(value));

// isEmpty
assertTrue(value.isEmpty().read());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateTag;
Expand All @@ -28,12 +32,14 @@
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.InstantCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.MapCoder;
import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.CombiningState;
import org.apache.beam.sdk.state.MapState;
import org.apache.beam.sdk.state.MultimapState;
import org.apache.beam.sdk.state.OrderedListState;
import org.apache.beam.sdk.state.ReadableState;
import org.apache.beam.sdk.state.ReadableStates;
import org.apache.beam.sdk.state.SetState;
import org.apache.beam.sdk.state.State;
import org.apache.beam.sdk.state.StateContext;
Expand All @@ -44,6 +50,7 @@
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.sdk.util.CombineFnUtil;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.joda.time.Instant;
Expand Down Expand Up @@ -119,11 +126,10 @@ public <T> SetState<T> bindSet(StateTag<SetState<T>> spec, Coder<T> elemCoder) {

@Override
public <KeyT, ValueT> MapState<KeyT, ValueT> bindMap(
StateTag<MapState<KeyT, ValueT>> spec,
StateTag<MapState<KeyT, ValueT>> address,
Coder<KeyT> mapKeyCoder,
Coder<ValueT> mapValueCoder) {
throw new UnsupportedOperationException(
String.format("%s is not supported", MapState.class.getSimpleName()));
return new SparkMapState<>(namespace, address, MapCoder.of(mapKeyCoder, mapValueCoder));
}

@Override
Expand Down Expand Up @@ -359,6 +365,142 @@ public AccumT mergeAccumulators(Iterable<AccumT> accumulators) {
}
}

private final class SparkMapState<MapKeyT, MapValueT>
extends AbstractState<Map<MapKeyT, MapValueT>> implements MapState<MapKeyT, MapValueT> {

private SparkMapState(
StateNamespace namespace,
StateTag<? extends State> address,
Coder<Map<MapKeyT, MapValueT>> coder) {
super(namespace, address, coder);
}

@Override
public ReadableState<MapValueT> get(MapKeyT key) {
return getOrDefault(key, null);
}

@Override
public ReadableState<MapValueT> getOrDefault(MapKeyT key, @Nullable MapValueT defaultValue) {
return new ReadableState<MapValueT>() {
@Override
public MapValueT read() {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
if (sparkMapState == null) {
return defaultValue;
}
return sparkMapState.getOrDefault(key, defaultValue);
}

@Override
public ReadableState<MapValueT> readLater() {
return this;
}
};
}

@Override
public void put(MapKeyT key, MapValueT value) {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
if (sparkMapState == null) {
sparkMapState = new HashMap<>();
}
sparkMapState.put(key, value);
writeValue(sparkMapState);
}

@Override
public ReadableState<MapValueT> computeIfAbsent(
MapKeyT key, Function<? super MapKeyT, ? extends MapValueT> mappingFunction) {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
MapValueT current = sparkMapState.get(key);
if (current == null) {
put(key, mappingFunction.apply(key));
}
return ReadableStates.immediate(current);
}

@Override
public void remove(MapKeyT key) {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
sparkMapState.remove(key);
writeValue(sparkMapState);
}

@Override
public ReadableState<Iterable<MapKeyT>> keys() {
return new ReadableState<Iterable<MapKeyT>>() {
@Override
public Iterable<MapKeyT> read() {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
if (sparkMapState == null) {
return Collections.emptyList();
}
return sparkMapState.keySet();
}

@Override
public ReadableState<Iterable<MapKeyT>> readLater() {
return this;
}
};
}

@Override
public ReadableState<Iterable<MapValueT>> values() {
return new ReadableState<Iterable<MapValueT>>() {
@Override
public Iterable<MapValueT> read() {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
if (sparkMapState == null) {
return Collections.emptyList();
}
Iterable<MapValueT> result = readValue().values();
return result != null ? ImmutableList.copyOf(result) : Collections.emptyList();
}

@Override
public ReadableState<Iterable<MapValueT>> readLater() {
return this;
}
};
}

@Override
public ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>> entries() {
return new ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>>() {
@Override
public Iterable<Map.Entry<MapKeyT, MapValueT>> read() {
Map<MapKeyT, MapValueT> sparkMapState = readValue();
if (sparkMapState == null) {
return Collections.emptyList();
}
return sparkMapState.entrySet();
}

@Override
public ReadableState<Iterable<Map.Entry<MapKeyT, MapValueT>>> readLater() {
return this;
}
};
}

@Override
public ReadableState<Boolean> isEmpty() {
return new ReadableState<Boolean>() {
@Override
public Boolean read() {
return stateTable.get(namespace.stringKey(), address.getId()) == null;
}

@Override
public ReadableState<Boolean> readLater() {
return this;
}
};
}
}

private final class SparkBagState<T> extends AbstractState<List<T>> implements BagState<T> {
private SparkBagState(StateNamespace namespace, StateTag<BagState<T>> address, Coder<T> coder) {
super(namespace, address, ListCoder.of(coder));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

/**
* Tests for {@link SparkStateInternals}. This is based on {@link StateInternalsTest}. Ignore set
* and map tests.
* tests.
*/
@RunWith(JUnit4.class)
public class SparkStateInternalsTest extends StateInternalsTest {
Expand All @@ -51,15 +51,7 @@ public void testMergeSetIntoSource() {}
@Ignore
public void testMergeSetIntoNewNamespace() {}

@Override
@Ignore
public void testMap() {}

@Override
@Ignore
public void testSetReadable() {}

@Override
@Ignore
public void testMapReadable() {}
}

0 comments on commit 3588d19

Please sign in to comment.