Samza runner support for non unique stateId across multiple ParDos #24276

Merged · 5 commits · Dec 7, 2022

Changes from all commits
4 changes: 0 additions & 4 deletions runners/samza/build.gradle
@@ -168,10 +168,6 @@ tasks.register("validatesRunner", Test) {
excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testPairWithIndexWindowedTimestampedUnbounded'
excludeTestsMatching 'org.apache.beam.sdk.transforms.SplittableDoFnTest.testOutputAfterCheckpointUnbounded'
}
filter {
// Re-enable the test after Samza runner supports same state id across DoFn(s).
excludeTest('ParDoTest$StateTests', 'testValueStateSameId')
}
}

tasks.register("validatesRunnerSickbay", Test) {
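The dropped filter re-enables ParDoTest$StateTests.testValueStateSameId in the Samza validatesRunner suite. That test exercises the scenario this PR targets: two ParDos in one pipeline declaring the same @StateId. A minimal illustration of that shape (hypothetical DoFn and transform names, not the actual test code):

```java
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;

class SameStateIdExample {
  /** Applies two stateful ParDos that both declare a value state named "counter". */
  static void applyTwoParDosWithSameStateId(PCollection<KV<String, Integer>> input) {
    input.apply("CountA", ParDo.of(new CountingFn()));
    // Before this PR, config generation for the second ParDo failed with
    // "Duplicate StateId counter found in multiple ParDo."
    input.apply("CountB", ParDo.of(new CountingFn()));
  }

  static class CountingFn extends DoFn<KV<String, Integer>, Integer> {
    @StateId("counter")
    private final StateSpec<ValueState<Integer>> counterSpec = StateSpecs.value(VarIntCoder.of());

    @ProcessElement
    public void process(
        @StateId("counter") ValueState<Integer> counter, OutputReceiver<Integer> out) {
      Integer current = counter.read();
      int next = (current == null ? 0 : current) + 1;
      counter.write(next);
      out.output(next);
    }
  }
}
```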
@@ -22,6 +22,7 @@
import java.util.Iterator;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.runners.core.construction.SplittableParDo;
import org.apache.beam.runners.core.construction.renderer.PipelineDotRenderer;
@@ -33,6 +34,7 @@
import org.apache.beam.runners.samza.translation.SamzaPipelineTranslator;
import org.apache.beam.runners.samza.translation.SamzaPortablePipelineTranslator;
import org.apache.beam.runners.samza.translation.SamzaTransformOverrides;
import org.apache.beam.runners.samza.translation.StateIdParser;
import org.apache.beam.runners.samza.translation.TranslationContext;
import org.apache.beam.runners.samza.util.PipelineJsonRenderer;
import org.apache.beam.sdk.Pipeline;
@@ -140,9 +142,11 @@ public SamzaPipelineResult run(Pipeline pipeline) {
LOG.info("Beam pipeline JSON graph:\n{}", jsonGraph);

final Map<PValue, String> idMap = PViewToIdMapper.buildIdMap(pipeline);
final Set<String> nonUniqueStateIds = StateIdParser.scan(pipeline);
final ConfigBuilder configBuilder = new ConfigBuilder(options);

SamzaPipelineTranslator.createConfig(pipeline, options, idMap, configBuilder);
SamzaPipelineTranslator.createConfig(
pipeline, options, idMap, nonUniqueStateIds, configBuilder);
configBuilder.put(BEAM_DOT_GRAPH, dotGraph);
configBuilder.put(BEAM_JSON_GRAPH, jsonGraph);

@@ -162,7 +166,7 @@ public SamzaPipelineResult run(Pipeline pipeline) {
appDescriptor.withMetricsReporterFactories(reporterFactories);

SamzaPipelineTranslator.translate(
pipeline, new TranslationContext(appDescriptor, idMap, options));
pipeline, new TranslationContext(appDescriptor, idMap, nonUniqueStateIds, options));
};

// perform a final round of validation for the pipeline options now that all configs are
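SamzaRunner now scans the pipeline once with StateIdParser.scan(pipeline) and threads the resulting set of non-unique state ids into both config generation and translation. StateIdParser is added by this PR but its source is not part of the hunks shown here; the sketch below is only a guess at its shape, consistent with how the result is used (state ids declared by more than one ParDo), and the traversal details are assumptions rather than the actual implementation:

```java
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;

/** Sketch only; the real org.apache.beam.runners.samza.translation.StateIdParser may differ. */
public class StateIdParserSketch extends Pipeline.PipelineVisitor.Defaults {
  private final Map<String, Integer> stateIdCounts = new HashMap<>();

  /** Returns the state ids declared by more than one ParDo in the pipeline. */
  public static Set<String> scan(Pipeline pipeline) {
    StateIdParserSketch visitor = new StateIdParserSketch();
    pipeline.traverseTopologically(visitor);
    return visitor.stateIdCounts.entrySet().stream()
        .filter(e -> e.getValue() > 1)
        .map(Map.Entry::getKey)
        .collect(Collectors.toSet());
  }

  @Override
  public void visitPrimitiveTransform(TransformHierarchy.Node node) {
    PTransform<?, ?> transform = node.getTransform();
    if (transform instanceof ParDo.MultiOutput) {
      DoFn<?, ?> fn = ((ParDo.MultiOutput<?, ?>) transform).getFn();
      DoFnSignature signature = DoFnSignatures.getSignature(fn.getClass());
      for (String stateId : signature.stateDeclarations().keySet()) {
        stateIdCounts.merge(stateId, 1, Integer::sum);
      }
    }
  }
}
```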
@@ -127,6 +127,7 @@ public class DoFnOp<InT, FnOutT, OutT> implements Op<InT, OutT, Void> {

private final DoFnSchemaInformation doFnSchemaInformation;
private final Map<?, PCollectionView<?>> sideInputMapping;
private final Map<String, String> stateIdToStoreMapping;

public DoFnOp(
TupleTag<FnOutT> mainOutputTag,
@@ -148,7 +149,8 @@ public DoFnOp(
JobInfo jobInfo,
Map<String, TupleTag<?>> idToTupleTagMap,
DoFnSchemaInformation doFnSchemaInformation,
Map<?, PCollectionView<?>> sideInputMapping) {
Map<?, PCollectionView<?>> sideInputMapping,
Map<String, String> stateIdToStoreMapping) {
this.mainOutputTag = mainOutputTag;
this.doFn = doFn;
this.sideInputs = sideInputs;
@@ -171,6 +173,7 @@ public DoFnOp(
this.bundleStateId = "_samza_bundle_" + transformId;
this.doFnSchemaInformation = doFnSchemaInformation;
this.sideInputMapping = sideInputMapping;
this.stateIdToStoreMapping = stateIdToStoreMapping;
}

@Override
@@ -260,6 +263,7 @@ public void open(
outputCoders,
doFnSchemaInformation,
(Map<String, PCollectionView<?>>) sideInputMapping,
stateIdToStoreMapping,
emitter,
outputFutureCollector);
}
@@ -52,8 +52,6 @@
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.WindowedValue;
@@ -91,15 +89,19 @@ public static <InT, FnOutT> DoFnRunner<InT, FnOutT> create(
Map<TupleTag<?>, Coder<?>> outputCoders,
DoFnSchemaInformation doFnSchemaInformation,
Map<String, PCollectionView<?>> sideInputMapping,
Map<String, String> stateIdToStoreIdMapping,
OpEmitter emitter,
FutureCollector futureCollector) {
final KeyedInternals keyedInternals;
final TimerInternals timerInternals;
final StateInternals stateInternals;
final DoFnSignature signature = DoFnSignatures.getSignature(doFn.getClass());
final SamzaStoreStateInternals.Factory<?> stateInternalsFactory =
SamzaStoreStateInternals.createStateInternalsFactory(
transformId, keyCoder, context.getTaskContext(), pipelineOptions, signature);
transformId,
keyCoder,
context.getTaskContext(),
pipelineOptions,
stateIdToStoreIdMapping);

final SamzaExecutionContext executionContext =
(SamzaExecutionContext) context.getApplicationContainerContext();
@@ -26,14 +26,12 @@
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
@@ -66,7 +64,6 @@
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.CombineWithContext;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.windowing.TimestampCombiner;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
@@ -129,18 +126,7 @@ static KeyValueStore<ByteArray, StateValue<?>> getBeamStore(TaskContext context)
*/
static <K> Factory<K> createNonKeyedStateInternalsFactory(
String id, TaskContext context, SamzaPipelineOptions pipelineOptions) {
return createStateInternalsFactory(id, null, context, pipelineOptions, Collections.emptySet());
}

static <K> Factory<K> createStateInternalsFactory(
String id,
Coder<K> keyCoder,
TaskContext context,
SamzaPipelineOptions pipelineOptions,
DoFnSignature signature) {

return createStateInternalsFactory(
id, keyCoder, context, pipelineOptions, signature.stateDeclarations().keySet());
return createStateInternalsFactory(id, null, context, pipelineOptions, Collections.emptyMap());
}

static <K> Factory<K> createStateInternalsFactory(
@@ -150,31 +136,35 @@ static <K> Factory<K> createStateInternalsFactory(
SamzaPipelineOptions pipelineOptions,
ExecutableStage executableStage) {

Set<String> stateIds =
Map<String, String> stateIdToStoreMap =
executableStage.getUserStates().stream()
.map(UserStateReference::localName)
.collect(Collectors.toSet());
.collect(
Collectors.toMap(UserStateReference::localName, UserStateReference::localName));

return createStateInternalsFactory(id, keyCoder, context, pipelineOptions, stateIds);
return createStateInternalsFactory(id, keyCoder, context, pipelineOptions, stateIdToStoreMap);
}

@SuppressWarnings("unchecked")
private static <K> Factory<K> createStateInternalsFactory(
static <K> Factory<K> createStateInternalsFactory(
String id,
@Nullable Coder<K> keyCoder,
TaskContext context,
SamzaPipelineOptions pipelineOptions,
Collection<String> stateIds) {
Map<String, String> stateIdToStoreMap) {
final int batchGetSize = pipelineOptions.getStoreBatchGetSize();
final Map<String, KeyValueStore<ByteArray, StateValue<?>>> stores = new HashMap<>();
stores.put(BEAM_STORE, getBeamStore(context));

final Coder<K> stateKeyCoder;
if (keyCoder != null) {
stateIds.forEach(
stateId ->
stores.put(
stateId, (KeyValueStore<ByteArray, StateValue<?>>) context.getStore(stateId)));
stateIdToStoreMap
.keySet()
.forEach(
stateId ->
stores.put(
stateId,
(KeyValueStore<ByteArray, StateValue<?>>)
context.getStore(stateIdToStoreMap.get(stateId))));
stateKeyCoder = keyCoder;
} else {
stateKeyCoder = (Coder<K>) VoidCoder.of();
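With this change SamzaStoreStateInternals keys its store map by logical state id but fetches the physical store from the task context under the mapped store name, so two ParDos that share a state id can be backed by different Samza stores. A usage sketch; the store names are illustrative, taskContext and pipelineOptions are placeholders, and since createStateInternalsFactory is package-private the call is assumed to live in the same package:

```java
import java.util.Collections;
import java.util.Map;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.samza.context.TaskContext;

class StoreMappingSketch {
  static void buildFactories(TaskContext taskContext, SamzaPipelineOptions pipelineOptions) {
    // Hypothetical mappings for two ParDos that both declare @StateId("counter").
    Map<String, String> storesForCountA = Collections.singletonMap("counter", "counter");
    Map<String, String> storesForCountB = Collections.singletonMap("counter", "CountB-counter");

    SamzaStoreStateInternals.Factory<String> factoryA =
        SamzaStoreStateInternals.createStateInternalsFactory(
            "CountA", StringUtf8Coder.of(), taskContext, pipelineOptions, storesForCountA);
    SamzaStoreStateInternals.Factory<String> factoryB =
        SamzaStoreStateInternals.createStateInternalsFactory(
            "CountB", StringUtf8Coder.of(), taskContext, pipelineOptions, storesForCountB);

    // Reads and writes under the state id "counter" now resolve to context.getStore("counter")
    // for the first ParDo and context.getStore("CountB-counter") for the second, so the two
    // ParDos no longer collide on one physical store.
  }
}
```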
@@ -52,6 +52,7 @@
import org.apache.samza.runtime.LocalApplicationRunner;
import org.apache.samza.runtime.RemoteApplicationRunner;
import org.apache.samza.standalone.PassthroughJobCoordinatorFactory;
import org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory;
import org.apache.samza.zk.ZkJobCoordinatorFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -306,9 +307,7 @@ private static Map<String, String> createSystemConfig(
static Map<String, String> createRocksDBStoreConfig(SamzaPipelineOptions options) {
final ImmutableMap.Builder<String, String> configBuilder =
ImmutableMap.<String, String>builder()
.put(
BEAM_STORE_FACTORY,
"org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory")
.put(BEAM_STORE_FACTORY, RocksDbKeyValueStorageEngineFactory.class.getName())
.put("stores.beamStore.rocksdb.compression", "lz4");

if (options.getStateDurable()) {
@@ -17,10 +17,10 @@
*/
package org.apache.beam.runners.samza.translation;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.samza.SamzaPipelineOptions;
import org.apache.beam.runners.samza.util.StoreIdGenerator;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.PTransform;
@@ -35,12 +35,13 @@ public class ConfigContext {
private final Map<PValue, String> idMap;
private AppliedPTransform<?, ?, ?> currentTransform;
private final SamzaPipelineOptions options;
private final Set<String> stateIds;
private final StoreIdGenerator storeIdGenerator;

public ConfigContext(Map<PValue, String> idMap, SamzaPipelineOptions options) {
public ConfigContext(
Map<PValue, String> idMap, Set<String> nonUniqueStateIds, SamzaPipelineOptions options) {
this.idMap = idMap;
this.options = options;
this.stateIds = new HashSet<>();
this.storeIdGenerator = new StoreIdGenerator(nonUniqueStateIds);
}

public void setCurrentTransform(AppliedPTransform<?, ?, ?> currentTransform) {
@@ -64,8 +65,8 @@ public SamzaPipelineOptions getPipelineOptions() {
return this.options;
}

public boolean addStateId(String stateId) {
return stateIds.add(stateId);
public StoreIdGenerator getStoreIdGenerator() {
return storeIdGenerator;
}

private String getIdForPValue(PValue pvalue) {
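ConfigContext now hands out store ids through a StoreIdGenerator seeded with the non-unique state ids found by StateIdParser, replacing the old duplicate-id check. The generator's source is not in the hunks shown here; the sketch below matches its call sites, and the disambiguation scheme (prefixing the enclosing transform's full name) is an assumption:

```java
import java.util.Set;

/**
 * Sketch of a store-id generator consistent with the call sites in this PR; the real
 * org.apache.beam.runners.samza.util.StoreIdGenerator may use a different naming scheme.
 */
public class StoreIdGeneratorSketch {
  private final Set<String> nonUniqueStateIds;

  public StoreIdGeneratorSketch(Set<String> nonUniqueStateIds) {
    this.nonUniqueStateIds = nonUniqueStateIds;
  }

  /**
   * Returns the physical store id for a state id. The result must be deterministic: it is
   * computed once during config generation (createConfig) and again during translation
   * (doTranslate), and both must agree on the store name.
   */
  public String getId(String stateId, String transformFullName) {
    if (!nonUniqueStateIds.contains(stateId)) {
      // Unique state ids keep the pre-existing behavior: the store name equals the state id.
      return stateId;
    }
    // Non-unique ids are disambiguated with the enclosing transform's full name
    // (illustrative format; the real generator may escape or shorten this differently).
    return transformFullName + "-" + stateId;
  }
}
```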
@@ -72,6 +72,7 @@
import org.apache.samza.operators.MessageStream;
import org.apache.samza.operators.functions.FlatMapFunction;
import org.apache.samza.operators.functions.WatermarkFunction;
import org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory;
import org.joda.time.Instant;

/**
@@ -161,6 +162,13 @@ private static <InT, OutT> void doTranslate(
Map<String, PCollectionView<?>> sideInputMapping =
ParDoTranslation.getSideInputMapping(ctx.getCurrentTransform());

final DoFnSignature signature = DoFnSignatures.getSignature(transform.getFn().getClass());
final Map<String, String> stateIdToStoreMapping = new HashMap<>();
for (String stateId : signature.stateDeclarations().keySet()) {
final String transformFullName = node.getEnclosingNode().getFullName();
final String storeId = ctx.getStoreIdGenerator().getId(stateId, transformFullName);
stateIdToStoreMapping.put(stateId, storeId);
}
final DoFnOp<InT, OutT, RawUnionValue> op =
new DoFnOp<>(
transform.getMainOutputTag(),
@@ -182,7 +190,8 @@
null,
Collections.emptyMap(),
doFnSchemaInformation,
sideInputMapping);
sideInputMapping,
stateIdToStoreMapping);

final MessageStream<OpMessage<InT>> mergedStreams;
if (sideInputStreams.isEmpty()) {
@@ -333,7 +342,8 @@ private static <InT, OutT> void doTranslatePortable(
ctx.getJobInfo(),
idToTupleTagMap,
doFnSchemaInformation,
sideInputMapping);
sideInputMapping,
Collections.emptyMap());

final MessageStream<OpMessage<InT>> mergedStreams;
if (sideInputStreams.isEmpty()) {
@@ -376,18 +386,11 @@ public Map<String, String> createConfig(

if (signature.usesState()) {
// set up user state configs
for (DoFnSignature.StateDeclaration state : signature.stateDeclarations().values()) {
final String storeId = state.id();

// TODO: remove validation after we support same state id in different ParDo.
if (!ctx.addStateId(storeId)) {
throw new IllegalStateException(
"Duplicate StateId " + storeId + " found in multiple ParDo.");
}

for (String stateId : signature.stateDeclarations().keySet()) {
final String transformFullName = node.getEnclosingNode().getFullName();
final String storeId = ctx.getStoreIdGenerator().getId(stateId, transformFullName);
config.put(
"stores." + storeId + ".factory",
"org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory");
"stores." + storeId + ".factory", RocksDbKeyValueStorageEngineFactory.class.getName());
config.put("stores." + storeId + ".key.serde", "byteArraySerde");
config.put("stores." + storeId + ".msg.serde", "stateValueSerde");
config.put("stores." + storeId + ".rocksdb.compression", "lz4");
@@ -430,8 +433,7 @@ public Map<String, String> createPortableConfig(
final String storeId = stateId.getLocalName();

config.put(
"stores." + storeId + ".factory",
"org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory");
"stores." + storeId + ".factory", RocksDbKeyValueStorageEngineFactory.class.getName());
config.put("stores." + storeId + ".key.serde", "byteArraySerde");
config.put("stores." + storeId + ".msg.serde", "stateValueSerde");
config.put("stores." + storeId + ".rocksdb.compression", "lz4");
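For each resolved store id, createConfig emits the same RocksDB store settings that were previously keyed directly by state id. Assuming an illustrative disambiguated store id of "CountB-counter" (the real value comes from StoreIdGenerator), the generated entries would look like:

```java
// Illustrative expected entries; the store id "CountB-counter" is an assumption.
Map<String, String> expected = new HashMap<>();
expected.put(
    "stores.CountB-counter.factory",
    "org.apache.samza.storage.kv.RocksDbKeyValueStorageEngineFactory");
expected.put("stores.CountB-counter.key.serde", "byteArraySerde");
expected.put("stores.CountB-counter.msg.serde", "stateValueSerde");
expected.put("stores.CountB-counter.rocksdb.compression", "lz4");
```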
@@ -50,7 +50,7 @@ public class PortableTranslationContext extends TranslationContext {

public PortableTranslationContext(
StreamApplicationDescriptor appDescriptor, SamzaPipelineOptions options, JobInfo jobInfo) {
super(appDescriptor, Collections.emptyMap(), options);
super(appDescriptor, Collections.emptyMap(), Collections.emptySet(), options);
this.jobInfo = jobInfo;
}

@@ -23,6 +23,7 @@
import java.util.HashMap;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.TransformPayloadTranslatorRegistrar;
import org.apache.beam.runners.core.construction.graph.ExecutableStage;
@@ -78,8 +79,9 @@ public static void createConfig(
Pipeline pipeline,
SamzaPipelineOptions options,
Map<PValue, String> idMap,
Set<String> nonUniqueStateIds,
ConfigBuilder configBuilder) {
final ConfigContext ctx = new ConfigContext(idMap, options);
final ConfigContext ctx = new ConfigContext(idMap, nonUniqueStateIds, options);

final TransformVisitorFn configFn =
new TransformVisitorFn() {