Skip to content

Commit

Permalink
Samza runner support for non unique stateId across multiple ParDos (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
seungjin-an authored Dec 7, 2022
1 parent 6c24637 commit 11ed0e1
Show file tree
Hide file tree
Showing 17 changed files with 428 additions and 79 deletions.
4 changes: 0 additions & 4 deletions runners/samza/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,6 @@ tasks.register("validatesRunner", Test) {
excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testPairWithIndexWindowedTimestampedUnbounded'
excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testOutputAfterCheckpointUnbounded'
}
filter {
// Re-enable the test after Samza runner supports same state id across DoFn(s).
excludeTest('ParDoTest$StateTests', 'testValueStateSameId')
}
}

tasks.register("validatesRunnerSickbay", Test) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.util.Iterator;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.SplittableParDo;
import org.apache.beam.runners.core.construction.renderer.PipelineDotRenderer;
Expand All @@ -33,6 +34,7 @@
import org.apache.beam.runners.samza.translation.SamzaPipelineTranslator;
import org.apache.beam.runners.samza.translation.SamzaPortablePipelineTranslator;
import org.apache.beam.runners.samza.translation.SamzaTransformOverrides;
import org.apache.beam.runners.samza.translation.StateIdParser;
import org.apache.beam.runners.samza.translation.TranslationContext;
import org.apache.beam.runners.samza.util.PipelineJsonRenderer;
import org.apache.beam.sdk.Pipeline;
Expand Down Expand Up @@ -140,9 +142,11 @@ public SamzaPipelineResult run(Pipeline pipeline) {
LOG.info("Beam pipeline JSON graph:\n{}", jsonGraph);

final Map<PValue, String> idMap = PViewToIdMapper.buildIdMap(pipeline);
final Set<String> nonUniqueStateIds = StateIdParser.scan(pipeline);
final ConfigBuilder configBuilder = new ConfigBuilder(options);

SamzaPipelineTranslator.createConfig(pipeline, options, idMap, configBuilder);
SamzaPipelineTranslator.createConfig(
pipeline, options, idMap, nonUniqueStateIds, configBuilder);
configBuilder.put(BEAM_DOT_GRAPH, dotGraph);
configBuilder.put(BEAM_JSON_GRAPH, jsonGraph);

Expand All @@ -162,7 +166,7 @@ public SamzaPipelineResult run(Pipeline pipeline) {
appDescriptor.withMetricsReporterFactories(reporterFactories);

SamzaPipelineTranslator.translate(
pipeline, new TranslationContext(appDescriptor, idMap, options));
pipeline, new TranslationContext(appDescriptor, idMap, nonUniqueStateIds, options));
};

// perform a final round of validation for the pipeline options now that all configs are
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {

private final DoFnSchemaInformation doFnSchemaInformation;
private final Map<?, PCollectionView<?>> sideInputMapping;
private final Map<String, String> stateIdToStoreMapping;

public DoFnOp(
TupleTag<FnOutT> mainOutputTag,
Expand All @@ -148,7 +149,8 @@ public DoFnOp(
JobInfo jobInfo,
Map<String, TupleTag<?>> idToTupleTagMap,
DoFnSchemaInformation doFnSchemaInformation,
Map<?, PCollectionView<?>> sideInputMapping) {
Map<?, PCollectionView<?>> sideInputMapping,
Map<String, String> stateIdToStoreMapping) {
this.mainOutputTag = mainOutputTag;
this.doFn = doFn;
this.sideInputs = sideInputs;
Expand All @@ -171,6 +173,7 @@ public DoFnOp(
this.bundleStateId = "_samza_bundle_" + transformId;
this.doFnSchemaInformation = doFnSchemaInformation;
this.sideInputMapping = sideInputMapping;
this.stateIdToStoreMapping = stateIdToStoreMapping;
}

@Override
Expand Down Expand Up @@ -260,6 +263,7 @@ public void open(
outputCoders,
doFnSchemaInformation,
(Map<String, PCollectionView<?>>) sideInputMapping,
stateIdToStoreMapping,
emitter,
outputFutureCollector);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
Expand Down Expand Up @@ -91,15 +89,19 @@ public static <InT, FnOutT> DoFnRunner<InT, FnOutT> create(
Map<TupleTag<?>, Coder<?>> outputCoders,
DoFnSchemaInformation doFnSchemaInformation,
Map<String, PCollectionView<?>> sideInputMapping,
Map<String, String> stateIdToStoreIdMapping,
OpEmitter emitter,
FutureCollector futureCollector) {
final KeyedInternals keyedInternals;
final TimerInternals timerInternals;
final StateInternals stateInternals;
final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
final SamzaStoreStateInternals.Factory<?> stateInternalsFactory =
SamzaStoreStateInternals.createStateInternalsFactory(
transformId, keyCoder, context.getTaskContext(), pipelineOptions, signature);
transformId,
keyCoder,
context.getTaskContext(),
pipelineOptions,
stateIdToStoreIdMapping);

final SamzaExecutionContext executionContext =
(SamzaExecutionContext) context.getApplicationContainerContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,12 @@
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
Expand Down Expand Up @@ -66,7 +64,6 @@
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
Expand Down Expand Up @@ -129,18 +126,7 @@ static KeyValueStore<ByteArray, StateValue<?>> getBeamStore(TaskContext context)
*/
static <K> Factory<K> createNonKeyedStateInternalsFactory(
String id, TaskContext context, SamzaPipelineOptions pipelineOptions) {
return createStateInternalsFactory(id, null, context, pipelineOptions, Collections.emptySet());
}

static <K> Factory<K> createStateInternalsFactory(
String id,
Coder<K> keyCoder,
TaskContext context,
SamzaPipelineOptions pipelineOptions,
DoFnSignature signature) {

return createStateInternalsFactory(
id, keyCoder, context, pipelineOptions, signature.stateDeclarations().keySet());
return createStateInternalsFactory(id, null, context, pipelineOptions, Collections.emptyMap());
}

static <K> Factory<K> createStateInternalsFactory(
Expand All @@ -150,31 +136,35 @@ static <K> Factory<K> createStateInternalsFactory(
SamzaPipelineOptions pipelineOptions,
ExecutableStage executableStage) {

Set<String> stateIds =
Map<String, String> stateIdToStoreMap =
executableStage.getUserStates().stream()
.map(UserStateReference::localName)
.collect(Collectors.toSet());
.collect(
Collectors.toMap(UserStateReference::localName, UserStateReference::localName));

return createStateInternalsFactory(id, keyCoder, context, pipelineOptions, stateIds);
return createStateInternalsFactory(id, keyCoder, context, pipelineOptions, stateIdToStoreMap);
}

@SuppressWarnings("unchecked")
private static <K> Factory<K> createStateInternalsFactory(
static <K> Factory<K> createStateInternalsFactory(
String id,
@Nullable Coder<K> keyCoder,
TaskContext context,
SamzaPipelineOptions pipelineOptions,
Collection<String> stateIds) {
Map<String, String> stateIdToStoreMap) {
final int batchGetSize = pipelineOptions.getStoreBatchGetSize();
final Map<String, KeyValueStore<ByteArray, StateValue<?>>> stores = new HashMap<>();
stores.put(BEAM_STORE, getBeamStore(context));

final Coder<K> stateKeyCoder;
if (keyCoder != null) {
stateIds.forEach(
stateId ->
stores.put(
stateId, (KeyValueStore<ByteArray, StateValue<?>>) context.getStore(stateId)));
stateIdToStoreMap
.keySet()
.forEach(
stateId ->
stores.put(
stateId,
(KeyValueStore<ByteArray, StateValue<?>>)
context.getStore(stateIdToStoreMap.get(stateId))));
stateKeyCoder = keyCoder;
} else {
stateKeyCoder = (Coder<K>) VoidCoder.of();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import org.apache.samza.runtime.LocalApplicationRunner;
import org.apache.samza.runtime.RemoteApplicationRunner;
import org.apache.samza.standalone.PassthroughJobCoordinatorFactory;
import org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory;
import org.apache.samza.zk.ZkJobCoordinatorFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand Down Expand Up @@ -306,9 +307,7 @@ private static Map<String, String> createSystemConfig(
static Map<String, String> createRocksDBStoreConfig(SamzaPipelineOptions options) {
final ImmutableMap.Builder<String, String> configBuilder =
ImmutableMap.<String, String>builder()
.put(
BEAM_STORE_FACTORY,
"org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory")
.put(BEAM_STORE_FACTORY, RocksDbKeyValueStorageEngineFactory.class.getName())
.put("stores.beamStore.rocksdb.compression", "lz4");

if (options.getStateDurable()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
*/
package org.apache.beam.runners.samza.translation;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
import org.apache.beam.runners.samza.util.StoreIdGenerator;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.PTransform;
Expand All @@ -35,12 +35,13 @@ public class ConfigContext {
private final Map<PValue, String> idMap;
private AppliedPTransform<?, ?, ?> currentTransform;
private final SamzaPipelineOptions options;
private final Set<String> stateIds;
private final StoreIdGenerator storeIdGenerator;

public ConfigContext(Map<PValue, String> idMap, SamzaPipelineOptions options) {
public ConfigContext(
Map<PValue, String> idMap, Set<String> nonUniqueStateIds, SamzaPipelineOptions options) {
this.idMap = idMap;
this.options = options;
this.stateIds = new HashSet<>();
this.storeIdGenerator = new StoreIdGenerator(nonUniqueStateIds);
}

public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) {
Expand All @@ -64,8 +65,8 @@ public SamzaPipelineOptions getPipelineOptions() {
return this.options;
}

public boolean addStateId(String stateId) {
return stateIds.add(stateId);
public StoreIdGenerator getStoreIdGenerator() {
return storeIdGenerator;
}

private String getIdForPValue(PValue pvalue) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import org.apache.samza.operators.MessageStream;
import org.apache.samza.operators.functions.FlatMapFunction;
import org.apache.samza.operators.functions.WatermarkFunction;
import org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory;
import org.joda.time.Instant;

/**
Expand Down Expand Up @@ -161,6 +162,13 @@ private static <InT, OutT> void doTranslate(
Map<String, PCollectionView<?>> sideInputMapping =
ParDoTranslation.getSideInputMapping(ctx.getCurrentTransform());

final DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
final Map<String, String> stateIdToStoreMapping = new HashMap<>();
for (String stateId : signature.stateDeclarations().keySet()) {
final String transformFullName = node.getEnclosingNode().getFullName();
final String storeId = ctx.getStoreIdGenerator().getId(stateId, transformFullName);
stateIdToStoreMapping.put(stateId, storeId);
}
final DoFnOp<InT, OutT, RawUnionValue> op =
new DoFnOp<>(
transform.getMainOutputTag(),
Expand All @@ -182,7 +190,8 @@ private static <InT, OutT> void doTranslate(
null,
Collections.emptyMap(),
doFnSchemaInformation,
sideInputMapping);
sideInputMapping,
stateIdToStoreMapping);

final MessageStream<OpMessage<InT>> mergedStreams;
if (sideInputStreams.isEmpty()) {
Expand Down Expand Up @@ -333,7 +342,8 @@ private static <InT, OutT> void doTranslatePortable(
ctx.getJobInfo(),
idToTupleTagMap,
doFnSchemaInformation,
sideInputMapping);
sideInputMapping,
Collections.emptyMap());

final MessageStream<OpMessage<InT>> mergedStreams;
if (sideInputStreams.isEmpty()) {
Expand Down Expand Up @@ -376,18 +386,11 @@ public Map<String, String> createConfig(

if (signature.usesState()) {
// set up user state configs
for (DoFnSignature.StateDeclaration state : signature.stateDeclarations().values()) {
final String storeId = state.id();

// TODO: remove validation after we support same state id in different ParDo.
if (!ctx.addStateId(storeId)) {
throw new IllegalStateException(
"Duplicate StateId " + storeId + " found in multiple ParDo.");
}

for (String stateId : signature.stateDeclarations().keySet()) {
final String transformFullName = node.getEnclosingNode().getFullName();
final String storeId = ctx.getStoreIdGenerator().getId(stateId, transformFullName);
config.put(
"stores." + storeId + ".factory",
"org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory");
"stores." + storeId + ".factory", RocksDbKeyValueStorageEngineFactory.class.getName());
config.put("stores." + storeId + ".key.serde", "byteArraySerde");
config.put("stores." + storeId + ".msg.serde", "stateValueSerde");
config.put("stores." + storeId + ".rocksdb.compression", "lz4");
Expand Down Expand Up @@ -430,8 +433,7 @@ public Map<String, String> createPortableConfig(
final String storeId = stateId.getLocalName();

config.put(
"stores." + storeId + ".factory",
"org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory");
"stores." + storeId + ".factory", RocksDbKeyValueStorageEngineFactory.class.getName());
config.put("stores." + storeId + ".key.serde", "byteArraySerde");
config.put("stores." + storeId + ".msg.serde", "stateValueSerde");
config.put("stores." + storeId + ".rocksdb.compression", "lz4");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class PortableTranslationContext extends TranslationContext {

public PortableTranslationContext(
StreamApplicationDescriptor appDescriptor, SamzaPipelineOptions options, JobInfo jobInfo) {
super(appDescriptor, Collections.emptyMap(), options);
super(appDescriptor, Collections.emptyMap(), Collections.emptySet(), options);
this.jobInfo = jobInfo;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.TransformPayloadTranslatorRegistrar;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
Expand Down Expand Up @@ -78,8 +79,9 @@ public static void createConfig(
Pipeline pipeline,
SamzaPipelineOptions options,
Map<PValue, String> idMap,
Set<String> nonUniqueStateIds,
ConfigBuilder configBuilder) {
final ConfigContext ctx = new ConfigContext(idMap, options);
final ConfigContext ctx = new ConfigContext(idMap, nonUniqueStateIds, options);

final TransformVisitorFn configFn =
new TransformVisitorFn() {
Expand Down
Loading

0 comments on commit 11ed0e1

Please sign in to comment.