From c079a760176556d444f94bb0a65a0ddcbd276cf5 Mon Sep 17 00:00:00 2001 From: Jan Lukavsky Date: Tue, 26 Nov 2024 09:43:28 +0100 Subject: [PATCH] [proxima-beam-core] debugging external state expander --- .github/workflows/build.yml | 2 +- .../beam/core/ProximaPipelineOptions.java | 54 +++++++ .../beam/core/direct/io/BatchLogRead.java | 22 +-- .../cz/o2/proxima/beam/util/RunnerUtils.java | 18 ++- .../beam/util/state/ExcludeExternal.java | 33 +++++ .../beam/util/state/ExpandContext.java | 127 ++++++++++------- .../util/state/ExternalStateExpander.java | 4 +- .../beam/util/state/FieldExtractor.java | 133 ++++++++++++++++++ .../state/FlushTimerParameterExpander.java | 42 ++++-- .../beam/util/state/MethodCallUtils.java | 49 +++++++ .../util/state/OnWindowParameterExpander.java | 8 +- .../cz/o2/proxima/beam/util/state/TypeId.java | 4 + .../beam/core/direct/io/BatchLogReadTest.java | 3 +- .../util/state/ExternalStateExpanderTest.java | 113 ++++++++++++++- .../beam/util/state/FieldExtractorTest.java | 83 +++++++++++ .../beam/tools/groovy/BeamStreamProvider.java | 7 +- 16 files changed, 600 insertions(+), 102 deletions(-) create mode 100644 beam/core/src/main/java/cz/o2/proxima/beam/core/ProximaPipelineOptions.java create mode 100644 beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExcludeExternal.java create mode 100644 beam/core/src/main/java/cz/o2/proxima/beam/util/state/FieldExtractor.java create mode 100644 beam/core/src/test/java/cz/o2/proxima/beam/util/state/FieldExtractorTest.java diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 85feb0447..20b25c03d 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -32,7 +32,7 @@ jobs: with: fetch-depth: 0 - name: Cache maven repository - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: | ~/.m2/repository diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/core/ProximaPipelineOptions.java 
b/beam/core/src/main/java/cz/o2/proxima/beam/core/ProximaPipelineOptions.java new file mode 100644 index 000000000..28f8c6df9 --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/core/ProximaPipelineOptions.java @@ -0,0 +1,54 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.core; + +import com.google.auto.service.AutoService; +import java.util.Collections; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsRegistrar; +import org.checkerframework.checker.nullness.qual.Nullable; + +public interface ProximaPipelineOptions extends PipelineOptions { + + @AutoService(PipelineOptionsRegistrar.class) + class Factory implements PipelineOptionsRegistrar { + + @Override + public Iterable> getPipelineOptions() { + return Collections.singletonList(ProximaPipelineOptions.class); + } + } + + /** + * {@code true} to preserve UDF jar on exit. + */ + @Default.Boolean(true) + boolean getPreserveUDFJar(); + + void setPreserveUDFJar(boolean delete); + + /** Set directory where to store generated UDF jar(s). */ + @Nullable String getUdfJarDirPath(); + + void setUdfJarDirPath(String path); + + /** Set delay for {@code BatchLogRead} in ms. 
*/ + @Default.Long(0L) + long getStartBatchReadDelayMs(); + + void setStartBatchReadDelayMs(long readDelayMs); +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/core/direct/io/BatchLogRead.java b/beam/core/src/main/java/cz/o2/proxima/beam/core/direct/io/BatchLogRead.java index fbbc92d25..71185939b 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/core/direct/io/BatchLogRead.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/core/direct/io/BatchLogRead.java @@ -15,7 +15,7 @@ */ package cz.o2.proxima.beam.core.direct.io; -import com.google.auto.service.AutoService; +import cz.o2.proxima.beam.core.ProximaPipelineOptions; import cz.o2.proxima.beam.core.direct.io.BatchRestrictionTracker.PartitionList; import cz.o2.proxima.core.repository.AttributeDescriptor; import cz.o2.proxima.core.repository.Repository; @@ -42,9 +42,6 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.coders.SerializableCoder; -import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.options.PipelineOptionsRegistrar; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.DoFn.BoundedPerElement; import org.apache.beam.sdk.transforms.Impulse; @@ -64,21 +61,6 @@ @Slf4j public class BatchLogRead extends PTransform> { - public interface BatchLogReadPipelineOptions extends PipelineOptions { - @Default.Long(0L) - long getStartBatchReadDelayMs(); - - void setStartBatchReadDelayMs(long readDelayMs); - } - - @AutoService(PipelineOptionsRegistrar.class) - public static class BatchLogReadOptionsFactory implements PipelineOptionsRegistrar { - @Override - public Iterable> getPipelineOptions() { - return Collections.singletonList(BatchLogReadPipelineOptions.class); - } - } - /** * Create the {@link BatchLogRead} transform that reads from {@link BatchLogReader} in batch * manner. 
@@ -351,7 +333,7 @@ public PCollection expand(PBegin input) { input .getPipeline() .getOptions() - .as(BatchLogReadPipelineOptions.class) + .as(ProximaPipelineOptions.class) .getStartBatchReadDelayMs(); return input .apply(Impulse.create()) diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/RunnerUtils.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/RunnerUtils.java index b97f6e345..600939590 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/RunnerUtils.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/RunnerUtils.java @@ -15,6 +15,8 @@ */ package cz.o2.proxima.beam.util; +import com.google.common.base.Strings; +import cz.o2.proxima.beam.core.ProximaPipelineOptions; import cz.o2.proxima.core.annotations.Internal; import cz.o2.proxima.core.util.ExceptionUtils; import java.io.ByteArrayInputStream; @@ -26,6 +28,7 @@ import java.net.URLClassLoader; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; @@ -87,11 +90,18 @@ public static void injectJarIntoContextClassLoader(Collection paths) { * @param classes map of class to bytecode * @return generated {@link File} */ - public static File createJarFromDynamicClasses(Map, byte[]> classes) - throws IOException { - Path tempFile = Files.createTempFile("proxima-beam-dynamic", ".jar"); + public static File createJarFromDynamicClasses( + ProximaPipelineOptions opts, Map, byte[]> classes) throws IOException { + + final Path tempFile = + Strings.isNullOrEmpty(opts.getUdfJarDirPath()) + ? 
Files.createTempFile("proxima-beam-dynamic", ".jar") + : Files.createTempFile( + Paths.get(opts.getUdfJarDirPath()), "proxima-beam-dynamic", ".jar"); File out = tempFile.toFile(); - out.deleteOnExit(); + if (!opts.getPreserveUDFJar()) { + out.deleteOnExit(); + } try (JarOutputStream output = new JarOutputStream(new FileOutputStream(out))) { long now = System.currentTimeMillis(); for (Map.Entry, byte[]> e : classes.entrySet()) { diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExcludeExternal.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExcludeExternal.java new file mode 100644 index 000000000..153ebf56d --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExcludeExternal.java @@ -0,0 +1,33 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import cz.o2.proxima.core.annotations.Experimental; +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation for {@link org.apache.beam.sdk.transforms.DoFn DoFns} that should be excluded from the + * external state expansion. 
+ */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ElementType.TYPE, ElementType.METHOD}) +@Documented +@Experimental +public @interface ExcludeExternal {} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExpandContext.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExpandContext.java index 4169a36f4..8487d7a5c 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExpandContext.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExpandContext.java @@ -19,6 +19,7 @@ import static cz.o2.proxima.beam.util.state.MethodCallUtils.getWrapperInputType; import cz.o2.proxima.beam.util.state.MethodCallUtils.VoidMethodInvoker; +import cz.o2.proxima.core.annotations.Internal; import cz.o2.proxima.core.functional.BiFunction; import cz.o2.proxima.core.functional.UnaryFunction; import cz.o2.proxima.core.util.ExceptionUtils; @@ -96,6 +97,7 @@ import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.ParDo.MultiOutput; import org.apache.beam.sdk.transforms.WithKeys; import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; @@ -122,7 +124,8 @@ import org.joda.time.Instant; @Slf4j -class ExpandContext { +@Internal +public class ExpandContext { static final String EXPANDER_BUF_STATE_SPEC = "expanderBufStateSpec"; static final String EXPANDER_BUF_STATE_NAME = "_expanderBuf"; @@ -140,7 +143,7 @@ class ExpandContext { private static class StateTupleTag extends TupleTag {} @Getter - class DoFnExpandContext { + public class DoFnExpandContext { private final DoFn, ?> doFn; private final TupleTag mainTag; @@ -164,6 +167,10 @@ class DoFnExpandContext { private final OnWindowExpirationInterceptor onWindowExpirationInterceptor; private final ProcessElementInterceptor processElementInterceptor; + private final StateSpec>>> bufState; + private final 
StateSpec> flushState = StateSpecs.value(); + private final TimerSpec flushTimer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + @SuppressWarnings("unchecked") DoFnExpandContext(DoFn, ?> doFn, KvCoder inputCoder, TupleTag mainTag) throws InvocationTargetException, @@ -224,6 +231,7 @@ class DoFnExpandContext { processElementInvoker, onWindowInvoker, onWindowExpander); this.processElementInterceptor = new ProcessElementInterceptor<>(processElementExpander, processElementInvoker); + this.bufState = StateSpecs.bag(TimestampedValueCoder.of(inputCoder)); } } @@ -318,10 +326,27 @@ private PTransformOverride statefulParMultiDoOverride( PCollection> inputs) { return PTransformOverride.of( - application -> application.getTransform() instanceof ParDo.MultiOutput, + application -> { + if (application.getTransform() instanceof ParDo.MultiOutput) { + ParDo.MultiOutput transform = (MultiOutput) application.getTransform(); + if (DoFnSignatures.isStateful(transform.getFn())) { + return shouldExpand(transform.getFn().getClass()); + } + } + return false; + }, parMultiDoReplacementFactory(inputs)); } + private boolean shouldExpand(Class cls) { + if (Arrays.stream(cls.getDeclaredMethods()) + .filter(m -> m.getAnnotation(DoFn.ProcessElement.class) != null) + .anyMatch(m -> m.getAnnotation(ExcludeExternal.class) != null)) { + return false; + } + return cls.getAnnotation(ExcludeExternal.class) == null; + } + @SuppressWarnings({"unchecked", "rawtypes"}) private PTransformOverrideFactory parMultiDoReplacementFactory( PCollection> inputs) { @@ -348,9 +373,6 @@ private PTransformReplacement replaceParMultiDo( (ParDo.MultiOutput) (PTransform) transform.getTransform(); DoFn, ?> doFn = (DoFn) rawTransform.getFn(); PInput pMainInput = getMainInput(transform); - if (!DoFnSignatures.isStateful(doFn)) { - return PTransformReplacement.of(pMainInput, (PTransform) transform.getTransform()); - } String transformName = asStableTransformName(transform.getFullName()); PCollection transformInputs = inputs @@ 
-459,27 +481,13 @@ private >> DoFn trans builder = builder .defineConstructor(Visibility.PUBLIC) - .withParameters( - doFnClass, - context.getFlushTimerInterceptor().getClass(), - context.getOnWindowExpirationInterceptor().getClass(), - context.getProcessElementInterceptor().getClass()) + .withParameters(doFnClass, context.getClass()) .intercept( addStateAndTimerValues( doFn, - inputCoder, MethodCall.invoke( ExceptionUtils.uncheckedFactory(() -> DoFn.class.getConstructor())) - .andThen(FieldAccessor.ofField(DELEGATE_FIELD_NAME).setsArgumentAt(0)) - .andThen( - FieldAccessor.ofField(FLUSH_TIMER_INTERCEPTOR_FIELD_NAME) - .setsArgumentAt(1)) - .andThen( - FieldAccessor.ofField(ON_WINDOW_INTERCEPTOR_FIELD_NAME) - .setsArgumentAt(2)) - .andThen( - FieldAccessor.ofField(PROCESS_ELEMENT_INTERCEPTOR_FIELD_NAME) - .setsArgumentAt(3)))); + .andThen(FieldAccessor.ofField(DELEGATE_FIELD_NAME).setsArgumentAt(0)))); builder = addProcessingMethods(context, builder); builder = implementDoFnProvider(builder); @@ -504,19 +512,10 @@ private >, V, K> DoFn newIn try { Constructor> ctor = - cls.getDeclaredConstructor( - context.getDoFnClass(), - context.getFlushTimerInterceptor().getClass(), - context.getOnWindowExpirationInterceptor().getClass(), - context.getProcessElementInterceptor().getClass()); + cls.getDeclaredConstructor(context.getDoFnClass(), context.getClass()); @SuppressWarnings("unchecked") DoFn instance = - (DoFn) - ctor.newInstance( - context.getDoFn(), - context.getFlushTimerInterceptor(), - context.getOnWindowExpirationInterceptor(), - context.getProcessElementInterceptor()); + (DoFn) ctor.newInstance(context.getDoFn(), context); return instance; } catch (Exception ex) { throw new IllegalStateException(String.format("Cannot instantiate class %s", cls), ex); @@ -549,30 +548,50 @@ private >, V, K> DoFn newIn } private static , OutputT> Implementation addStateAndTimerValues( - DoFn doFn, Coder> inputCoder, Composable delegate) { + DoFn doFn, Composable delegate) { List> 
acceptable = Arrays.asList(StateId.class, TimerId.class); @SuppressWarnings("unchecked") Class> doFnClass = (Class>) doFn.getClass(); - for (Field f : doFnClass.getDeclaredFields()) { - if (!Modifier.isStatic(f.getModifiers()) - && acceptable.stream().anyMatch(a -> f.getAnnotation(a) != null)) { - f.setAccessible(true); - Object value = ExceptionUtils.uncheckedFactory(() -> f.get(doFn)); - delegate = delegate.andThen(FieldAccessor.ofField(f.getName()).setsValue(value)); - } + try { + return delegate + .andThen( + MethodCall.invoke( + DoFnExpandContext.class.getDeclaredMethod("getFlushTimerInterceptor")) + .onArgument(1) + .setsField(m -> m.getName().equals(FLUSH_TIMER_INTERCEPTOR_FIELD_NAME))) + .andThen( + MethodCall.invoke( + DoFnExpandContext.class.getDeclaredMethod("getOnWindowExpirationInterceptor")) + .onArgument(1) + .setsField(m -> m.getName().equals(ON_WINDOW_INTERCEPTOR_FIELD_NAME))) + .andThen( + MethodCall.invoke( + DoFnExpandContext.class.getDeclaredMethod("getProcessElementInterceptor")) + .onArgument(1) + .setsField(m -> m.getName().equals(PROCESS_ELEMENT_INTERCEPTOR_FIELD_NAME))) + .andThen( + MethodCall.invoke(DoFnExpandContext.class.getDeclaredMethod("getBufState")) + .onArgument(1) + .setsField(m -> m.getName().equals(EXPANDER_BUF_STATE_SPEC))) + .andThen( + MethodCall.invoke(DoFnExpandContext.class.getDeclaredMethod("getFlushState")) + .onArgument(1) + .setsField(m -> m.getName().equals(EXPANDER_FLUSH_STATE_SPEC))) + .andThen( + MethodCall.invoke(DoFnExpandContext.class.getDeclaredMethod("getFlushTimer")) + .onArgument(1) + .setsField(m -> m.getName().equals(EXPANDER_TIMER_SPEC))) + .andThen( + new FieldExtractor( + doFnClass, + f -> + !Modifier.isStatic(f.getModifiers()) + && acceptable.stream().anyMatch(a -> f.getAnnotation(a) != null))); + } catch (Exception ex) { + throw new IllegalStateException(ex); } - delegate = - delegate - .andThen( - FieldAccessor.ofField(EXPANDER_BUF_STATE_SPEC) - 
.setsValue(StateSpecs.bag(TimestampedValueCoder.of(inputCoder)))) - .andThen(FieldAccessor.ofField(EXPANDER_FLUSH_STATE_SPEC).setsValue(StateSpecs.value())) - .andThen( - FieldAccessor.ofField(EXPANDER_TIMER_SPEC) - .setsValue(TimerSpecs.timer(TimeDomain.EVENT_TIME))); - return delegate; } @SuppressWarnings({"unchecked", "rawtypes"}) @@ -719,8 +738,10 @@ private static ParameterizedType getParameterizedDoFn( methodDefinition = methodDefinition.annotateParameter(i, paramAnnotation); } } - return methodDefinition.annotateMethod( - AnnotationDescription.Builder.ofType(annotation).build()); + for (Annotation a : method.getDeclaredAnnotations()) { + methodDefinition = methodDefinition.annotateMethod(a); + } + return methodDefinition; } return builder; } diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java index 35199e595..f17e5005d 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/ExternalStateExpander.java @@ -15,6 +15,7 @@ */ package cz.o2.proxima.beam.util.state; +import cz.o2.proxima.beam.core.ProximaPipelineOptions; import cz.o2.proxima.beam.util.RunnerUtils; import cz.o2.proxima.core.functional.UnaryFunction; import java.io.File; @@ -52,7 +53,8 @@ public static Pipeline expand( ExpandContext context = new ExpandContext(inputs, stateWriteInstant, nextFlushInstantFn, stateSink); Pipeline expanded = context.expand(pipeline); - File dynamicJar = RunnerUtils.createJarFromDynamicClasses(context.getGeneratedClasses()); + ProximaPipelineOptions opts = expanded.getOptions().as(ProximaPipelineOptions.class); + File dynamicJar = RunnerUtils.createJarFromDynamicClasses(opts, context.getGeneratedClasses()); RunnerUtils.registerToPipeline( expanded.getOptions(), expanded.getOptions().getRunner().getSimpleName(), diff --git 
a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FieldExtractor.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FieldExtractor.java new file mode 100644 index 000000000..4b20c2e2b --- /dev/null +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FieldExtractor.java @@ -0,0 +1,133 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import java.lang.reflect.Field; +import java.lang.reflect.Modifier; +import java.util.function.Predicate; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.dynamic.scaffold.InstrumentedType; +import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.implementation.bytecode.ByteCodeAppender; +import net.bytebuddy.jar.asm.MethodVisitor; +import net.bytebuddy.jar.asm.Opcodes; + +public class FieldExtractor implements Implementation, ByteCodeAppender { + + private final Class sourceClass; + private final Predicate includePredicate; + + public FieldExtractor(Class sourceClass) { + this(sourceClass, f -> !Modifier.isStatic(f.getModifiers())); + } + + public FieldExtractor(Class sourceClass, Predicate includePredicate) { + this.sourceClass = sourceClass; + this.includePredicate = includePredicate; + } + + @Override + public Size apply( + MethodVisitor methodVisitor, + Context implementationContext, + 
MethodDescription instrumentedMethod) { + + String targetClass = implementationContext.getInstrumentedType().getInternalName(); + + // For each field in SourceClass, copy the value to the generated class + for (Field field : sourceClass.getDeclaredFields()) { + if (includePredicate.test(field)) { + field.setAccessible(true); + + // Load `this` onto the stack for setting the field + methodVisitor.visitVarInsn(Opcodes.ALOAD, 0); + + // Use reflection to get the Field object from SourceClass + // Step 1: Load the class name of source class onto the stack + methodVisitor.visitLdcInsn(TypeDescription.ForLoadedType.of(sourceClass).getName()); + + // Step 2: Invoke Class.forName(String) to get the Class object of SourceClass + methodVisitor.visitMethodInsn( + Opcodes.INVOKESTATIC, + "java/lang/Class", + "forName", + "(Ljava/lang/String;)Ljava/lang/Class;", + false); + + // Step 3: Load the field name as a constant onto the stack + methodVisitor.visitLdcInsn(field.getName()); + + // Step 4: Invoke getDeclaredField(String) on the Class object to get the Field object + methodVisitor.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + "java/lang/Class", + "getDeclaredField", + "(Ljava/lang/String;)Ljava/lang/reflect/Field;", + false); + + // Step 5: Duplicate the Field object on the stack (so we have two copies of it) + methodVisitor.visitInsn(Opcodes.DUP); + + // Step 6: Set the Field accessible (iconst_1 for true) + methodVisitor.visitInsn(Opcodes.ICONST_1); + methodVisitor.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, "java/lang/reflect/Field", "setAccessible", "(Z)V", false); + + // Step 7: Load the source instance to get the field value + methodVisitor.visitVarInsn(Opcodes.ALOAD, 1); + + // Step 8: Invoke Field.get(Object) to retrieve the field value from the source instance + methodVisitor.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + "java/lang/reflect/Field", + "get", + "(Ljava/lang/Object;)Ljava/lang/Object;", + false); + + // Step 9: Cast the value to the correct field type if 
necessary + if (!field.getType().isPrimitive()) { + // Use CHECKCAST to cast the returned Object to the correct type + methodVisitor.visitTypeInsn( + Opcodes.CHECKCAST, + TypeDescription.ForLoadedType.of(field.getType()).getInternalName()); + } + + // Step 10: Set the value to the corresponding field in the generated class + methodVisitor.visitFieldInsn( + Opcodes.PUTFIELD, + targetClass, + field.getName(), + TypeDescription.ForLoadedType.of(field.getType()).getDescriptor()); + } + } + + // Add the return statement to end the constructor properly + methodVisitor.visitInsn(Opcodes.RETURN); + + return new ByteCodeAppender.Size(4, instrumentedMethod.getStackSize()); + } + + @Override + public ByteCodeAppender appender(Target target) { + return this; + } + + @Override + public InstrumentedType prepare(InstrumentedType instrumentedType) { + return instrumentedType; + } +} diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java index 54a9a2a6b..ac679a9d3 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/FlushTimerParameterExpander.java @@ -40,6 +40,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; +import org.apache.flink.shaded.curator5.com.google.common.collect.Streams; import org.joda.time.Instant; interface FlushTimerParameterExpander extends Serializable { @@ -59,23 +60,12 @@ static FlushTimerParameterExpander of( createWrapperArgs(DoFn doFn, ParameterizedType inputType) { List> states = - Arrays.stream(doFn.getClass().getDeclaredFields()) - .filter(f -> f.getAnnotation(DoFn.StateId.class) != null) - .map( - f -> { - Preconditions.checkArgument( - f.getGenericType() instanceof ParameterizedType, - "Field %s has invalid type %s", - 
f.getName(), - f.getGenericType()); - return Pair.of( - (Annotation) f.getAnnotation(DoFn.StateId.class), - ((ParameterizedType) f.getGenericType()).getActualTypeArguments()[0]); - }) - .collect(Collectors.toList()); + collectFieldsWithAnnotation(doFn, DoFn.StateId.class, true); + List> timers = + collectFieldsWithAnnotation(doFn, DoFn.TimerId.class, false); List> types = - states.stream() + Streams.concat(states.stream(), timers.stream()) .map( p -> Pair.of( @@ -122,6 +112,28 @@ static FlushTimerParameterExpander of( return res; } + private static List> collectFieldsWithAnnotation( + DoFn doFn, Class annotation, boolean isState) { + + return Arrays.stream(doFn.getClass().getDeclaredFields()) + .filter(f -> f.getAnnotation(annotation) != null) + .map( + f -> { + if (isState) { + Preconditions.checkArgument( + f.getGenericType() instanceof ParameterizedType, + "Field %s has invalid type %s", + f.getName(), + f.getGenericType()); + return Pair.of( + (Annotation) f.getAnnotation(annotation), + ((ParameterizedType) f.getGenericType()).getActualTypeArguments()[0]); + } + return Pair.of((Annotation) f.getAnnotation(annotation), (Type) Timer.class); + }) + .collect(Collectors.toList()); + } + /** * Get arguments that must be declared by wrapper's call for both {@code @}ProcessElement and * {@code @}OnWindowExpiration be callable. 
diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java index 8d2463da7..ee6cc02c1 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/MethodCallUtils.java @@ -74,6 +74,7 @@ import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.StateBinder; import org.apache.beam.sdk.state.StateSpec; +import org.apache.beam.sdk.state.Timer; import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.state.WatermarkHoldState; import org.apache.beam.sdk.transforms.Combine.CombineFn; @@ -89,6 +90,7 @@ import org.apache.beam.sdk.values.TimestampedValue; import org.apache.beam.sdk.values.TupleTag; import org.checkerframework.checker.nullness.qual.Nullable; +import org.joda.time.Duration; import org.joda.time.Instant; @Slf4j @@ -204,6 +206,8 @@ private static void addUnknownWrapperArgument( res.add( (args, elem) -> singleOutput((MultiOutputReceiver) args[wrapperPos], elem, mainTag)); } + } else if (typeId.isTimer()) { + res.add((args, elem) -> asTimer(elem)); } else { throw new IllegalStateException( String.format( @@ -211,6 +215,51 @@ private static void addUnknownWrapperArgument( } } + private static Timer asTimer(TimestampedValue> elem) { + return new Timer() { + + @Override + public void set(Instant absoluteTime) { + // nop + } + + @Override + public void setRelative() { + // nop + } + + @Override + public void clear() { + // nop + } + + @Override + public Timer offset(Duration offset) { + return this; + } + + @Override + public Timer align(Duration period) { + return this; + } + + @Override + public Timer withOutputTimestamp(Instant outputTime) { + return this; + } + + @Override + public Timer withNoOutputTimestamp() { + return this; + } + + @Override + public Instant getCurrentRelativeTime() { + return elem.getTimestamp(); + } + }; + } + private 
static DoFn.MultiOutputReceiver remapTimestampIfNeeded( MultiOutputReceiver parent, @Nullable TimestampedValue> elem) { diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java index 86530808f..f7c5dc036 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/OnWindowParameterExpander.java @@ -18,18 +18,17 @@ import static cz.o2.proxima.beam.util.state.ExpandContext.bagStateFromInputType; import static cz.o2.proxima.beam.util.state.MethodCallUtils.*; +import com.google.common.collect.Streams; import cz.o2.proxima.core.functional.BiFunction; import cz.o2.proxima.core.util.Pair; import cz.o2.proxima.internal.com.google.common.base.MoreObjects; import cz.o2.proxima.internal.com.google.common.base.Preconditions; -import cz.o2.proxima.internal.com.google.common.collect.Sets; import java.io.Serializable; import java.lang.annotation.Annotation; import java.lang.reflect.Method; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.util.ArrayList; -import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; import java.util.Optional; @@ -75,7 +74,10 @@ static List> createWrapperArgList( LinkedHashMap> processArgs, LinkedHashMap> onWindowArgs) { - Set union = new HashSet<>(Sets.union(processArgs.keySet(), onWindowArgs.keySet())); + Set union = + Streams.concat(processArgs.keySet().stream(), onWindowArgs.keySet().stream()) + .filter(t -> !t.isTimer()) + .collect(Collectors.toSet()); // @Element is not supported by @OnWindowExpiration union.remove(TypeId.of(AnnotationDescription.Builder.ofType(DoFn.Element.class).build())); return union.stream() diff --git a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java index 
2b8610638..003a72a09 100644 --- a/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java +++ b/beam/core/src/main/java/cz/o2/proxima/beam/util/state/TypeId.java @@ -113,4 +113,8 @@ public boolean isOutput(Type outputType) { public boolean isMultiOutput() { return equals(MULTI_OUTPUT_TYPE); } + + public boolean isTimer() { + return stringId.contains(DoFn.TimerId.class.getName()); + } } diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/core/direct/io/BatchLogReadTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/core/direct/io/BatchLogReadTest.java index 634f0b137..a55c69682 100644 --- a/beam/core/src/test/java/cz/o2/proxima/beam/core/direct/io/BatchLogReadTest.java +++ b/beam/core/src/test/java/cz/o2/proxima/beam/core/direct/io/BatchLogReadTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.*; import static org.mockito.Mockito.mock; +import cz.o2.proxima.beam.core.ProximaPipelineOptions; import cz.o2.proxima.beam.core.direct.io.BatchLogRead.BatchLogReadFn; import cz.o2.proxima.beam.core.direct.io.BatchRestrictionTracker.PartitionList; import cz.o2.proxima.core.repository.AttributeDescriptor; @@ -119,7 +120,7 @@ public void testBatchLogReadWithDelay() { List input = createInput(numElements); ListBatchReader reader = ListBatchReader.of(direct.getContext(), input); PipelineOptions opts = PipelineOptionsFactory.create(); - opts.as(BatchLogRead.BatchLogReadPipelineOptions.class).setStartBatchReadDelayMs(1000L); + opts.as(ProximaPipelineOptions.class).setStartBatchReadDelayMs(1000L); Pipeline p = createPipeline(opts); testReadingFromBatchLogMany( p, diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java index 0ef3e66d3..8223337be 100644 --- a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java +++ b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/ExternalStateExpanderTest.java 
@@ -16,7 +16,7 @@ package cz.o2.proxima.beam.util.state; import static com.mongodb.internal.connection.tlschannel.util.Util.assertTrue; -import static org.junit.Assert.assertEquals; +import static org.junit.Assert.*; import com.google.common.base.MoreObjects; import com.google.common.base.Preconditions; @@ -32,6 +32,7 @@ import org.apache.beam.runners.direct.DirectRunner; import org.apache.beam.runners.flink.FlinkRunner; import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.Pipeline.PipelineVisitor.Defaults; import org.apache.beam.sdk.PipelineRunner; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -39,8 +40,13 @@ import org.apache.beam.sdk.coders.VarLongCoder; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.runners.TransformHierarchy; import org.apache.beam.sdk.state.StateSpec; import org.apache.beam.sdk.state.StateSpecs; +import org.apache.beam.sdk.state.TimeDomain; +import org.apache.beam.sdk.state.Timer; +import org.apache.beam.sdk.state.TimerSpec; +import org.apache.beam.sdk.state.TimerSpecs; import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestStream; @@ -49,8 +55,10 @@ import org.apache.beam.sdk.transforms.MapElements; import org.apache.beam.sdk.transforms.PTransform; import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.ParDo.MultiOutput; import org.apache.beam.sdk.transforms.Reify; import org.apache.beam.sdk.transforms.WithKeys; +import org.apache.beam.sdk.transforms.reflect.DoFnSignatures; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.CoderUtils; import org.apache.beam.sdk.values.KV; @@ -98,6 +106,33 @@ public void testSimpleExpand() throws IOException { expanded.run(); } + @Test + public void testWithTimer() throws IOException { + Pipeline pipeline = 
createPipeline(); + PCollection inputs = + pipeline.apply( + Create.timestamped( + TimestampedValue.of("1", new Instant(1)), + TimestampedValue.of("2", new Instant(1)))); + PCollection> withKeys = + inputs.apply( + WithKeys.of(e -> Integer.parseInt(e) % 2) + .withKeyType(TypeDescriptors.integers())); + PCollection count = withKeys.apply(ParDo.of(getTimerFn())); + PAssert.that(count) + .containsInAnyOrder( + BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis(), + BoundedWindow.TIMESTAMP_MAX_VALUE.getMillis()); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + new Instant(0), + ign -> BoundedWindow.TIMESTAMP_MAX_VALUE, + dummy()); + expanded.run(); + } + @Test public void testSimpleExpandMultiOutput() throws IOException { Pipeline pipeline = createPipeline(); @@ -260,6 +295,20 @@ public void testStateWithElementEarly() throws IOException { (long) CoderUtils.decodeFromByteArray( VarLongCoder.of(), first.getValue().getValue().getValue())); + + // check that we did expand + expanded.traverseTopologically( + new Defaults() { + @Override + public void visitPrimitiveTransform(TransformHierarchy.Node node) { + if (node.getTransform() instanceof ParDo.MultiOutput) { + ParDo.MultiOutput transform = (MultiOutput) node.getTransform(); + assertTrue( + !DoFnSignatures.isStateful(transform.getFn()) + || transform.getFn().getClass().getName().contains("$Expanded")); + } + } + }); } @Test @@ -272,6 +321,48 @@ public void testBufferedTimestampInjectToMultiOutput() throws IOException { testTimestampInject(true); } + @Test + public void testExpansionExclusion() throws IOException { + Pipeline pipeline = createPipeline(); + PCollection input = pipeline.apply(Create.of("1", "2", "3")); + input + .apply(WithKeys.of(1)) + .apply( + "process", + ParDo.of( + new DoFn, Long>() { + @TimerId("timer") + private final TimerSpec timer = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @ExcludeExternal + @ProcessElement + 
public void process(@Element KV in) {} + + @OnTimer("timer") + public void timer(@TimerId("timer") Timer timer) {} + })); + Instant now = Instant.now(); + Pipeline expanded = + ExternalStateExpander.expand( + pipeline, + Create.empty(KvCoder.of(StringUtf8Coder.of(), StateValue.coder())), + now, + current -> BoundedWindow.TIMESTAMP_MAX_VALUE, + dummy()); + + // check that we didn't expand + expanded.traverseTopologically( + new Defaults() { + @Override + public void visitPrimitiveTransform(TransformHierarchy.Node node) { + if (node.getTransform() instanceof ParDo.MultiOutput) { + ParDo.MultiOutput transform = (MultiOutput) node.getTransform(); + assertFalse(transform.getFn().getClass().getName().contains("$Expanded")); + } + } + }); + } + private void testTimestampInject(boolean multiOutput) throws IOException { Pipeline pipeline = createPipeline(); Instant now = new Instant(0); @@ -429,4 +520,24 @@ public PDone expand(PCollection> input) { } }; } + + private static DoFn, Long> getTimerFn() { + return new DoFn, Long>() { + @TimerId("timer") + private final TimerSpec timerSpec = TimerSpecs.timer(TimeDomain.EVENT_TIME); + + @RequiresStableInput + @ProcessElement + public void process( + @Element KV elem, + @TimerId("timer") Timer timer, + OutputReceiver output) { + + output.output(timer.getCurrentRelativeTime().getMillis()); + } + + @OnTimer("timer") + public void timer(@TimerId("timer") Timer timer, OutputReceiver output) {} + }; + } } diff --git a/beam/core/src/test/java/cz/o2/proxima/beam/util/state/FieldExtractorTest.java b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/FieldExtractorTest.java new file mode 100644 index 000000000..909b7294f --- /dev/null +++ b/beam/core/src/test/java/cz/o2/proxima/beam/util/state/FieldExtractorTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2017-2024 O2 Czech Republic, a.s. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package cz.o2.proxima.beam.util.state; + +import static org.junit.Assert.assertEquals; + +import cz.o2.proxima.core.util.ExceptionUtils; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import net.bytebuddy.ByteBuddy; +import net.bytebuddy.description.modifier.FieldManifestation; +import net.bytebuddy.description.modifier.Visibility; +import net.bytebuddy.dynamic.loading.ClassLoadingStrategy; +import net.bytebuddy.implementation.Implementation; +import net.bytebuddy.implementation.MethodCall; +import org.junit.Test; + +public class FieldExtractorTest { + + // Define the source class which we will use to provide field values + public static class SourceClass { + private final String name; + private final Integer value; + + public SourceClass(String name, int value) { + this.name = name; + this.value = value; + } + } + + @Test + public void test() throws Exception { + + // Instantiate the source object whose fields will be copied + SourceClass sourceInstance = new SourceClass("TestName", 42); + + // Use Byte Buddy to generate a dynamic class + Class dynamicClass; + dynamicClass = + new ByteBuddy() + .subclass(Object.class) + .name(SourceClass.class.getPackageName() + ".GeneratedClass") + + // Define fields to match the source instance fields + .defineField("name", String.class, Visibility.PRIVATE, FieldManifestation.FINAL) + .defineField("value", Integer.class, Visibility.PRIVATE, FieldManifestation.FINAL) + + // Define a constructor that takes an instance of SourceClass + .defineConstructor(Visibility.PUBLIC) + 
.withParameters(SourceClass.class, Integer.class /* to be ignored */) + .intercept( + MethodCall.invoke( + ExceptionUtils.uncheckedFactory(() -> Object.class.getConstructor())) + .andThen(new Implementation.Simple(new FieldExtractor(SourceClass.class)))) + .make() + .load(FieldExtractorTest.class.getClassLoader(), ClassLoadingStrategy.Default.WRAPPER) + .getLoaded(); + + // Instantiate the generated class using the constructor that takes SourceClass as parameter + Constructor constructor = dynamicClass.getConstructor(SourceClass.class, Integer.class); + Object generatedInstance = constructor.newInstance(sourceInstance, 0); + + // Verify that fields are copied correctly + for (Field field : dynamicClass.getDeclaredFields()) { + field.setAccessible(true); + Field source = SourceClass.class.getDeclaredField(field.getName()); + assertEquals(source.get(sourceInstance), field.get(generatedInstance)); + } + } +} diff --git a/beam/tools/src/main/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProvider.java b/beam/tools/src/main/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProvider.java index d3c0f9dd4..a69a5b919 100644 --- a/beam/tools/src/main/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProvider.java +++ b/beam/tools/src/main/java/cz/o2/proxima/beam/tools/groovy/BeamStreamProvider.java @@ -22,6 +22,7 @@ import com.google.api.client.util.Lists; import com.google.auto.service.AutoService; import cz.o2.proxima.beam.core.BeamDataOperator; +import cz.o2.proxima.beam.core.ProximaPipelineOptions; import cz.o2.proxima.core.functional.Factory; import cz.o2.proxima.core.functional.UnaryFunction; import cz.o2.proxima.core.repository.AttributeDescriptor; @@ -251,7 +252,7 @@ protected UnaryFunction getCreatePipelineFromOpts() { void createUdfJarAndRegisterToPipeline(PipelineOptions opts) { String runnerName = opts.getRunner().getSimpleName(); try { - File path = createJarFromUdfs(); + File path = createJarFromUdfs(opts.as(ProximaPipelineOptions.class)); if (path != null) { 
log.info("Created jar {} with generated classes.", path); List files = new ArrayList<>(Collections.singletonList(path)); @@ -269,7 +270,7 @@ private Collection getAddedJars() { .orElse(Collections.emptySet()); } - private @Nullable File createJarFromUdfs() throws IOException { + private @Nullable File createJarFromUdfs(ProximaPipelineOptions opts) throws IOException { ToolsClassLoader loader = getToolsClassLoader(); Map, byte[]> codeMap = Optional.ofNullable(loader) @@ -281,7 +282,7 @@ private Collection getAddedJars() { .collect(Collectors.toMap(Pair::getFirst, Pair::getSecond))) .orElse(Collections.emptyMap()); log.info("Building jar from classes {} retrieved from {}", codeMap.keySet(), loader); - return createJarFromDynamicClasses(codeMap); + return createJarFromDynamicClasses(opts, codeMap); } private @Nullable ToolsClassLoader getToolsClassLoader() {