Commit: working tests
je-ik committed Sep 25, 2024
1 parent b0eab68 commit ace470d
Showing 3 changed files with 122 additions and 94 deletions.
@@ -30,6 +30,7 @@
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
@@ -122,7 +123,7 @@ public class ExternalStateExpander {
static final String EXPANDER_TIMER_SPEC = "expanderTimerSpec";
static final String EXPANDER_TIMER_NAME = "_expanderTimer";

static final TupleTag<StateValue> stateValueTupleTag = new TupleTag<>() {};
static final TupleTag<StateValue> STATE_TUPLE_TAG = new StateTupleTag() {};

/**
* Expand the given {@link Pipeline} to support external state store and restore
@@ -133,7 +134,7 @@ public class ExternalStateExpander {
* @param nextFlushInstantFn function that returns instant of next flush from current time
* @param stateSink transform to store outputs
*/
public static void expand(
public static Pipeline expand(
Pipeline pipeline,
PTransform<PBegin, PCollection<KV<String, StateValue>>> inputs,
Instant stateWriteInstant,
@@ -143,12 +144,35 @@ public static void expand(
validatePipeline(pipeline);
pipeline.getCoderRegistry().registerCoderForClass(StateValue.class, StateValue.coder());
PCollection<KV<String, StateValue>> inputsMaterialized = pipeline.apply(inputs);

// replace all MultiParDos
pipeline.replaceAll(
Collections.singletonList(
statefulParMultiDoOverride(
inputsMaterialized, stateWriteInstant, nextFlushInstantFn, stateSink)));
statefulParMultiDoOverride(inputsMaterialized, stateWriteInstant, nextFlushInstantFn)));
// collect all StateValues
List<Pair<String, PCollection<StateValue>>> stateValues = new ArrayList<>();
pipeline.traverseTopologically(
new PipelineVisitor.Defaults() {
@Override
public void visitPrimitiveTransform(TransformHierarchy.Node node) {
if (node.getTransform() instanceof ParDo.MultiOutput<?, ?>) {
node.getOutputs().entrySet().stream()
.filter(e -> e.getKey() instanceof StateTupleTag)
.map(Entry::getValue)
.findAny()
.ifPresent(
p ->
stateValues.add(
Pair.of(node.getFullName(), (PCollection<StateValue>) p)));
}
}
});
PCollectionList<KV<String, StateValue>> list = PCollectionList.empty(pipeline);
for (Pair<String, PCollection<StateValue>> p : stateValues) {
list = list.and(p.getSecond().apply(WithKeys.of(p.getFirst())));
}
list.apply(Flatten.pCollections()).apply(stateSink);

return pipeline;
}
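
With this change, expand() returns the mutated Pipeline and applies the stateSink exactly once, to the flattened, name-keyed StateValues collected during traversal, instead of inside every replaced ParDo. A minimal caller sketch, assuming hypothetical stateSource and stateSink transforms and the parameter order implied by the Javadoc above (all names besides ExternalStateExpander, StateValue and expand are invented):

import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.joda.time.Duration;
import org.joda.time.Instant;

// Assumes this class sits in the same package as ExternalStateExpander and StateValue.
class ExpandUsageSketch {
  static void run(
      Pipeline pipeline, // user pipeline with stateful ParDos already applied
      PTransform<PBegin, PCollection<KV<String, StateValue>>> stateSource, // hypothetical
      PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) { // hypothetical
    Pipeline expanded =
        ExternalStateExpander.expand(
            pipeline,
            stateSource, // previously persisted state to restore
            Instant.now(), // stateWriteInstant
            now -> now.plus(Duration.standardMinutes(5)), // nextFlushInstantFn, illustrative interval
            stateSink); // receives the name-keyed StateValues collected after traversal
    expanded.run().waitUntilFinish();
  }
}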

private static void validatePipeline(Pipeline pipeline) {
@@ -184,25 +208,22 @@ public void leavePipeline(Pipeline pipeline) {}
private static PTransformOverride statefulParMultiDoOverride(
PCollection<KV<String, StateValue>> inputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {
UnaryFunction<Instant, Instant> nextFlushInstantFn) {

return PTransformOverride.of(
application -> application.getTransform() instanceof ParDo.MultiOutput,
parMultiDoReplacementFactory(inputs, stateWriteInstant, nextFlushInstantFn, stateSink));
parMultiDoReplacementFactory(inputs, stateWriteInstant, nextFlushInstantFn));
}

private static PTransformOverrideFactory parMultiDoReplacementFactory(
PCollection<KV<String, StateValue>> inputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {
UnaryFunction<Instant, Instant> nextFlushInstantFn) {

return new PTransformOverrideFactory() {
@Override
public PTransformReplacement getReplacementTransform(AppliedPTransform transform) {
return replaceParMultiDo(
transform, inputs, stateWriteInstant, nextFlushInstantFn, stateSink);
return replaceParMultiDo(transform, inputs, stateWriteInstant, nextFlushInstantFn);
}

@SuppressWarnings("unchecked")
@@ -218,8 +239,7 @@ private static PTransformReplacement<PInput, POutput> replaceParMultiDo(
AppliedPTransform<PInput, POutput, ?> transform,
PCollection<KV<String, StateValue>> inputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {
UnaryFunction<Instant, Instant> nextFlushInstantFn) {

ParDo.MultiOutput<PInput, POutput> rawTransform =
(ParDo.MultiOutput<PInput, POutput>) (PTransform) transform.getTransform();
@@ -246,8 +266,7 @@ private static PTransformReplacement<PInput, POutput> replaceParMultiDo(
.filter(t -> !t.equals(mainOutputTag))
.collect(Collectors.toList())),
stateWriteInstant,
nextFlushInstantFn,
stateSink));
nextFlushInstantFn));
}

@SuppressWarnings("unchecked")
@@ -259,8 +278,7 @@ PTransform<PCollection<InputT>, PCollectionTuple> transformedParDo(
TupleTag<OutputT> mainOutputTag,
TupleTagList otherOutputs,
Instant stateWriteInstant,
UnaryFunction<Instant, Instant> nextFlushInstantFn,
PTransform<PCollection<KV<String, StateValue>>, PDone> stateSink) {
UnaryFunction<Instant, Instant> nextFlushInstantFn) {

return new PTransform<>() {
@Override
@@ -304,13 +322,11 @@ public PCollectionTuple expand(PCollection<InputT> input) {
mainOutputTag,
stateWriteInstant,
nextFlushInstantFn))
.withOutputTags(mainOutputTag, otherOutputs.and(stateValueTupleTag)));
PCollection<StateValue> stateValuePCollection = tuple.get(stateValueTupleTag);
stateValuePCollection.apply(WithKeys.of(transformName)).apply(stateSink);
.withOutputTags(mainOutputTag, otherOutputs.and(STATE_TUPLE_TAG)));
PCollectionTuple res = PCollectionTuple.empty(input.getPipeline());
for (Entry<TupleTag<Object>, PCollection<Object>> e :
(Set<Entry<TupleTag<Object>, PCollection<Object>>>) (Set) tuple.getAll().entrySet()) {
if (!e.getKey().equals(stateValueTupleTag)) {
if (!e.getKey().equals(STATE_TUPLE_TAG)) {
res = res.and(e.getKey(), e.getValue());
}
}
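
The rebuilt PCollectionTuple above hands every output except STATE_TUPLE_TAG back to the surrounding pipeline; the state output is picked up later by the traversal in expand(). For context, a generic multi-output ParDo sketch showing the same side-output mechanism (standard Beam usage, not code from this commit; method and variable names are invented):

import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;

class SideOutputSketch {
  // Emits each element on the main tag and a diagnostic string on a side tag, then
  // returns only the side output - analogous to how the expander routes StateValue
  // records to STATE_TUPLE_TAG and later extracts them from the tuple.
  static PCollection<String> diagnostics(PCollection<Integer> numbers) {
    final TupleTag<Integer> mainTag = new TupleTag<>() {};
    final TupleTag<String> sideTag = new TupleTag<>() {};
    PCollectionTuple tuple =
        numbers.apply(
            ParDo.of(
                    new DoFn<Integer, Integer>() {
                      @ProcessElement
                      public void process(ProcessContext ctx) {
                        ctx.output(ctx.element());
                        ctx.output(sideTag, "saw " + ctx.element());
                      }
                    })
                .withOutputTags(mainTag, TupleTagList.of(sideTag)));
    return tuple.get(sideTag);
  }
}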
@@ -333,7 +349,8 @@ DoFn<InputT, OutputT> transformedDoFn(
(Class<? extends DoFn<KV<K, V>, OutputT>>) doFn.getClass();

ClassLoadingStrategy<ClassLoader> strategy = ByteBuddyUtils.getClassLoadingStrategy(doFnClass);
final String className = doFnClass.getName() + "$Expanded";
final String className =
doFnClass.getName() + "$Expanded" + Objects.hash(stateWriteInstant, nextFlushInstantFn);
final ClassLoader classLoader = ExternalStateExpander.class.getClassLoader();
try {
@SuppressWarnings("unchecked")
@@ -481,7 +498,7 @@ Builder<DoFn<InputT, OutputT>> addProcessingMethods(
inputType,
keyCoder,
mainTag,
stateValueTupleTag,
STATE_TUPLE_TAG,
outputType,
nextFlushInstantFn,
builder);
@@ -820,6 +837,7 @@ public void intercept(

private static class TimerFlushInterceptor<K, V> {

private final DoFn<KV<K, V>, ?> doFn;
private final LinkedHashMap<String, BiFunction<Object, byte[], Iterable<StateValue>>>
stateReaders;
private final Method processElementMethod;
@@ -836,6 +854,7 @@ private static class TimerFlushInterceptor<K, V> {
TupleTag<StateValue> stateTag,
UnaryFunction<Instant, Instant> nextFlushInstantFn) {

this.doFn = doFn;
this.stateReaders = getStateReaders(doFn);
this.processElementMethod = processElementMethod;
this.expander = expander;
@@ -850,20 +869,25 @@ public void intercept(@This DoFn<KV<V, StateOrInput<V>>, ?> doFn, @AllArguments
@SuppressWarnings("unchecked")
K key = (K) args[args.length - 5];
@SuppressWarnings("unchecked")
ValueState<Instant> lastFlushState = (ValueState<Instant>) args[args.length - 3];
ValueState<Instant> nextFlushState = (ValueState<Instant>) args[args.length - 3];
Timer flushTimer = (Timer) args[args.length - 4];
Instant nextFlush = nextFlushInstantFn.apply(now);
Instant lastFlush = lastFlushState.read();
Instant lastFlush = nextFlushState.read();
if (nextFlush != null && nextFlush.isBefore(BoundedWindow.TIMESTAMP_MAX_VALUE)) {
flushTimer.set(nextFlush);
lastFlushState.write(nextFlush);
nextFlushState.write(nextFlush);
}
if (lastFlush == null) {
// FIXME: flush the buffer to processFn
BagState<TimestampedValue<KV<K, V>>> bufState =
(BagState<TimestampedValue<KV<K, V>>>) args[args.length - 2];
bufState.read().forEach(kv -> System.err.println(" *** flushing state " + kv));
} else {
BagState<TimestampedValue<KV<K, V>>> bufState =
(BagState<TimestampedValue<KV<K, V>>>) args[args.length - 2];
bufState
.read()
.forEach(
kv -> {
Object[] processArgs = expander.getProcessElementArgs(kv.getValue(), args);
ExceptionUtils.unchecked(() -> processElementMethod.invoke(this.doFn, processArgs));
});
bufState.clear();
if (lastFlush != null) {
MultiOutputReceiver outputReceiver = (MultiOutputReceiver) args[args.length - 1];
OutputReceiver<StateValue> output = outputReceiver.get(stateTag);
byte[] keyBytes =
@@ -883,4 +907,6 @@ static Generic bagStateFromInputType(ParameterizedType inputType) {
Generic.Builder.parameterizedType(TimestampedValue.class, inputType).build())
.build();
}

private static class StateTupleTag extends TupleTag<StateValue> {}
}
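
The new StateTupleTag is a marker subclass: the traversal in expand() recognizes state outputs by the tag's runtime type rather than by tag id. A standalone sketch of that marker-tag pattern (AuditTag and TagScanSketch are invented names, not part of this commit):

import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TupleTag;

class TagScanSketch {
  // Marker tag: any ParDo output registered under AuditTag can be found later purely
  // via instanceof, the same trick used with StateTupleTag above.
  static class AuditTag extends TupleTag<String> {}

  static void printAuditOutputs(Pipeline pipeline) {
    pipeline.traverseTopologically(
        new Pipeline.PipelineVisitor.Defaults() {
          @Override
          public void visitPrimitiveTransform(TransformHierarchy.Node node) {
            if (node.getTransform() instanceof ParDo.MultiOutput<?, ?>) {
              for (Map.Entry<TupleTag<?>, PCollection<?>> e : node.getOutputs().entrySet()) {
                if (e.getKey() instanceof AuditTag) {
                  System.out.println(node.getFullName() + " -> " + e.getValue().getName());
                }
              }
            }
          }
        });
  }
}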
@@ -136,13 +136,6 @@ private static UnaryFunction<Object[], Boolean> createProcessFn(
boolean shouldBuffer = nextFlush == null /* we have not finished reading state */
// FIXME: || nextFlush.isBefore(/* timestamp */)
;
System.err.println(
" *** process "
+ elem
+ ", "
+ shouldBuffer
+ " , at "
+ flushTimer.getCurrentRelativeTime());
if (shouldBuffer) {
// store to state
@SuppressWarnings("unchecked")

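Together, the interceptor and createProcessFn changes implement buffer-until-flush: elements arriving before state restore has finished are parked in a BagState, and when the flush timer fires they are replayed through the original @ProcessElement and the buffer is cleared. A simplified, standalone sketch of that pattern in plain Beam terms (class, state, and timer names invented; the real interceptor replays via reflection instead of re-emitting):

import org.apache.beam.sdk.state.BagState;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.state.Timer;
import org.apache.beam.sdk.state.TimerSpec;
import org.apache.beam.sdk.state.TimerSpecs;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.values.KV;
import org.joda.time.Duration;

class BufferUntilFlushFn extends DoFn<KV<String, String>, String> {

  @StateId("buffer")
  private final StateSpec<BagState<String>> bufferSpec = StateSpecs.bag();

  @TimerId("flush")
  private final TimerSpec flushSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME);

  @ProcessElement
  public void process(
      @Element KV<String, String> element,
      @StateId("buffer") BagState<String> buffer,
      @TimerId("flush") Timer flush) {
    // Park the element; this corresponds to the shouldBuffer branch in createProcessFn.
    buffer.add(element.getValue());
    flush.offset(Duration.standardMinutes(1)).setRelative();
  }

  @OnTimer("flush")
  public void onFlush(@StateId("buffer") BagState<String> buffer, OutputReceiver<String> out) {
    // Replay everything buffered so far, then clear the buffer - the interceptor instead
    // invokes the original @ProcessElement for each buffered element.
    buffer.read().forEach(out::output);
    buffer.clear();
  }
}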