From 59e20f20de73ced59ae2c782612fa7554fc1fced Mon Sep 17 00:00:00 2001 From: Benoit Moriceau Date: Fri, 1 Jul 2022 17:37:32 -0700 Subject: [PATCH] Bmoric/state aggregator (#14364) * Update state.state type * Add state aggregator * Test and format * PR comments * Move to its own package * Update airbyte-workers/src/test/java/io/airbyte/workers/internal/state_aggregator/StateAggregatorTest.java Co-authored-by: Lake Mossman * format * Update airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/DefaultStateAggregator.java Co-authored-by: Lake Mossman * format Co-authored-by: Lake Mossman --- .../internal/AirbyteMessageTracker.java | 11 +- .../DefaultStateAggregator.java | 55 ++++++ .../SingleStateAggregator.java | 32 ++++ .../state_aggregator/StateAggregator.java | 16 ++ .../StreamStateAggregator.java | 31 ++++ .../internal/AirbyteMessageTrackerTest.java | 11 +- .../state_aggregator/StateAggregatorTest.java | 162 ++++++++++++++++++ 7 files changed, 313 insertions(+), 5 deletions(-) create mode 100644 airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/DefaultStateAggregator.java create mode 100644 airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/SingleStateAggregator.java create mode 100644 airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/StateAggregator.java create mode 100644 airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/StreamStateAggregator.java create mode 100644 airbyte-workers/src/test/java/io/airbyte/workers/internal/state_aggregator/StateAggregatorTest.java diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java b/airbyte-workers/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java index b8ceb1bef528..0f99e0edeca1 100644 --- a/airbyte-workers/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java +++ b/airbyte-workers/src/main/java/io/airbyte/workers/internal/AirbyteMessageTracker.java @@ -18,6 +18,8 @@ import io.airbyte.protocol.models.AirbyteStateMessage; import io.airbyte.protocol.models.AirbyteTraceMessage; import io.airbyte.workers.helper.FailureHelper; +import io.airbyte.workers.internal.state_aggregator.DefaultStateAggregator; +import io.airbyte.workers.internal.state_aggregator.StateAggregator; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -44,6 +46,7 @@ public class AirbyteMessageTracker implements MessageTracker { private final StateDeltaTracker stateDeltaTracker; private final List destinationErrorTraceMessages; private final List sourceErrorTraceMessages; + private final StateAggregator stateAggregator; private short nextStreamIndex; @@ -59,11 +62,11 @@ private enum ConnectorType { } public AirbyteMessageTracker() { - this(new StateDeltaTracker(STATE_DELTA_TRACKER_MEMORY_LIMIT_BYTES)); + this(new StateDeltaTracker(STATE_DELTA_TRACKER_MEMORY_LIMIT_BYTES), new DefaultStateAggregator()); } @VisibleForTesting - protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker) { + protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker, final StateAggregator stateAggregator) { this.sourceOutputState = new AtomicReference<>(); this.destinationOutputState = new AtomicReference<>(); this.totalEmittedStateMessages = new AtomicLong(0L); @@ -77,6 +80,7 @@ protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker) { this.unreliableCommittedCounts = false; this.destinationErrorTraceMessages = new ArrayList<>(); this.sourceErrorTraceMessages = new ArrayList<>(); + this.stateAggregator = stateAggregator; } @Override @@ -144,7 +148,8 @@ private void handleSourceEmittedState(final AirbyteStateMessage stateMessage) { * committed in the {@link StateDeltaTracker}. Also record this state as the last committed state. */ private void handleDestinationEmittedState(final AirbyteStateMessage stateMessage) { - destinationOutputState.set(new State().withState(stateMessage.getData())); + stateAggregator.ingest(stateMessage); + destinationOutputState.set(stateAggregator.getAggregated()); try { if (!unreliableCommittedCounts) { stateDeltaTracker.commitStateHash(getStateHashCode(stateMessage)); diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/DefaultStateAggregator.java b/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/DefaultStateAggregator.java new file mode 100644 index 000000000000..a076a006530f --- /dev/null +++ b/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/DefaultStateAggregator.java @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2022 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.workers.internal.state_aggregator; + +import com.google.common.base.Preconditions; +import io.airbyte.config.State; +import io.airbyte.protocol.models.AirbyteStateMessage; +import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType; + +public class DefaultStateAggregator implements StateAggregator { + + private AirbyteStateType stateType = null; + private final StateAggregator streamStateAggregator = new StreamStateAggregator(); + private final StateAggregator singleStateAggregator = new SingleStateAggregator(); + + @Override + public void ingest(final AirbyteStateMessage stateMessage) { + checkTypeOrSetType(stateMessage.getType()); + + getStateAggregator().ingest(stateMessage); + } + + @Override + public State getAggregated() { + return getStateAggregator().getAggregated(); + } + + /** + * Return the state aggregator that match the state type. + */ + private StateAggregator getStateAggregator() { + return switch (stateType) { + case STREAM -> streamStateAggregator; + case GLOBAL, LEGACY -> singleStateAggregator; + }; + } + + /** + * We can not have 2 different state types given to the same instance of this class. This method set + * the type if it is not. If the state type doesn't exist in the message, it is set to LEGACY + */ + private void checkTypeOrSetType(AirbyteStateType inputStateType) { + if (inputStateType == null) { + inputStateType = AirbyteStateType.LEGACY; + } + if (this.stateType == null) { + this.stateType = inputStateType; + } + Preconditions.checkArgument(this.stateType == inputStateType, + "Input state type " + inputStateType + " does not match the aggregator's current state type " + this.stateType); + } + +} diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/SingleStateAggregator.java b/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/SingleStateAggregator.java new file mode 100644 index 000000000000..09106a08c1e2 --- /dev/null +++ b/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/SingleStateAggregator.java @@ -0,0 +1,32 @@ +/* + * Copyright (c) 2022 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.workers.internal.state_aggregator; + +import io.airbyte.commons.json.Jsons; +import io.airbyte.config.State; +import io.airbyte.protocol.models.AirbyteStateMessage; +import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType; +import java.util.List; + +class SingleStateAggregator implements StateAggregator { + + AirbyteStateMessage state; + + @Override + public void ingest(final AirbyteStateMessage stateMessage) { + state = stateMessage; + } + + @Override + public State getAggregated() { + if (state.getType() == null || state.getType() == AirbyteStateType.LEGACY) { + return new State().withState(state.getData()); + } else { + return new State() + .withState(Jsons.jsonNode(List.of(state))); + } + } + +} diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/StateAggregator.java b/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/StateAggregator.java new file mode 100644 index 000000000000..97c02e7b0a90 --- /dev/null +++ b/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/StateAggregator.java @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2022 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.workers.internal.state_aggregator; + +import io.airbyte.config.State; +import io.airbyte.protocol.models.AirbyteStateMessage; + +public interface StateAggregator { + + void ingest(AirbyteStateMessage stateMessage); + + State getAggregated(); + +} diff --git a/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/StreamStateAggregator.java b/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/StreamStateAggregator.java new file mode 100644 index 000000000000..d55563efe0ec --- /dev/null +++ b/airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/StreamStateAggregator.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2022 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.workers.internal.state_aggregator; + +import io.airbyte.commons.json.Jsons; +import io.airbyte.config.State; +import io.airbyte.protocol.models.AirbyteStateMessage; +import io.airbyte.protocol.models.StreamDescriptor; +import java.util.HashMap; +import java.util.Map; + +class StreamStateAggregator implements StateAggregator { + + Map aggregatedState = new HashMap<>(); + + @Override + public void ingest(final AirbyteStateMessage stateMessage) { + aggregatedState.put(stateMessage.getStream().getStreamDescriptor(), stateMessage); + } + + @Override + public State getAggregated() { + + return new State() + .withState( + Jsons.jsonNode(aggregatedState.values())); + } + +} diff --git a/airbyte-workers/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java index 75e9ccdf38b5..b54e1892b7d6 100644 --- a/airbyte-workers/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java +++ b/airbyte-workers/src/test/java/io/airbyte/workers/internal/AirbyteMessageTrackerTest.java @@ -13,6 +13,7 @@ import io.airbyte.protocol.models.AirbyteMessage; import io.airbyte.workers.helper.FailureHelper; import io.airbyte.workers.internal.StateDeltaTracker.StateDeltaTrackerException; +import io.airbyte.workers.internal.state_aggregator.StateAggregator; import java.util.HashMap; import java.util.Map; import org.junit.jupiter.api.BeforeEach; @@ -34,9 +35,12 @@ class AirbyteMessageTrackerTest { @Mock private StateDeltaTracker mStateDeltaTracker; + @Mock + private StateAggregator mStateAggregator; + @BeforeEach public void setup() { - this.messageTracker = new AirbyteMessageTracker(mStateDeltaTracker); + this.messageTracker = new AirbyteMessageTracker(mStateDeltaTracker, mStateAggregator); } @Test @@ -65,6 +69,9 @@ public void testRetainsLatestSourceAndDestinationState() { final AirbyteMessage s2 = AirbyteMessageUtils.createStateMessage(s2Value); final AirbyteMessage s3 = AirbyteMessageUtils.createStateMessage(s3Value); + final State expectedState = new State().withState(Jsons.jsonNode(s2Value)); + Mockito.when(mStateAggregator.getAggregated()).thenReturn(expectedState); + messageTracker.acceptFromSource(s1); messageTracker.acceptFromSource(s2); messageTracker.acceptFromSource(s3); @@ -75,7 +82,7 @@ public void testRetainsLatestSourceAndDestinationState() { assertEquals(new State().withState(Jsons.jsonNode(s3Value)), messageTracker.getSourceOutputState().get()); assertTrue(messageTracker.getDestinationOutputState().isPresent()); - assertEquals(new State().withState(Jsons.jsonNode(s2Value)), messageTracker.getDestinationOutputState().get()); + assertEquals(expectedState, messageTracker.getDestinationOutputState().get()); } @Test diff --git a/airbyte-workers/src/test/java/io/airbyte/workers/internal/state_aggregator/StateAggregatorTest.java b/airbyte-workers/src/test/java/io/airbyte/workers/internal/state_aggregator/StateAggregatorTest.java new file mode 100644 index 000000000000..13df955ff2da --- /dev/null +++ b/airbyte-workers/src/test/java/io/airbyte/workers/internal/state_aggregator/StateAggregatorTest.java @@ -0,0 +1,162 @@ +/* + * Copyright (c) 2022 Airbyte, Inc., all rights reserved. + */ + +package io.airbyte.workers.internal.state_aggregator; + +import static io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType.GLOBAL; +import static io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType.LEGACY; +import static io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType.STREAM; + +import com.google.common.collect.Lists; +import io.airbyte.commons.json.Jsons; +import io.airbyte.config.State; +import io.airbyte.protocol.models.AirbyteGlobalState; +import io.airbyte.protocol.models.AirbyteStateMessage; +import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType; +import io.airbyte.protocol.models.AirbyteStreamState; +import io.airbyte.protocol.models.StreamDescriptor; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Stream; +import org.assertj.core.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +public class StateAggregatorTest { + + StateAggregator stateAggregator; + + @BeforeEach + public void init() { + stateAggregator = new DefaultStateAggregator(); + } + + @ParameterizedTest + @EnumSource(AirbyteStateType.class) + public void testCantMixType(final AirbyteStateType stateType) { + final Stream allTypes = Arrays.stream(AirbyteStateType.values()); + + stateAggregator.ingest(getEmptyMessage(stateType)); + + final List differentTypes = allTypes.filter(type -> type != stateType).toList(); + differentTypes.forEach(differentType -> Assertions.assertThatThrownBy(() -> stateAggregator.ingest(getEmptyMessage(differentType)))); + } + + @Test + public void testCantMixNullType() { + final List allIncompatibleTypes = Lists.newArrayList(GLOBAL, STREAM); + + stateAggregator.ingest(getEmptyMessage(null)); + + allIncompatibleTypes.forEach(differentType -> Assertions.assertThatThrownBy(() -> stateAggregator.ingest(getEmptyMessage(differentType)))); + + stateAggregator.ingest(getEmptyMessage(LEGACY)); + } + + @Test + public void testNullState() { + final AirbyteStateMessage state1 = getNullMessage(1); + final AirbyteStateMessage state2 = getNullMessage(2); + + stateAggregator.ingest(state1); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(state1.getData())); + + stateAggregator.ingest(state2); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(state2.getData())); + } + + @Test + public void testLegacyState() { + final AirbyteStateMessage state1 = getLegacyMessage(1); + final AirbyteStateMessage state2 = getLegacyMessage(2); + + stateAggregator.ingest(state1); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(state1.getData())); + + stateAggregator.ingest(state2); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(state2.getData())); + } + + @Test + public void testGlobalState() { + final AirbyteStateMessage state1 = getGlobalMessage(1); + final AirbyteStateMessage state2 = getGlobalMessage(2); + + stateAggregator.ingest(state1); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(Jsons.jsonNode(List.of(state1)))); + + stateAggregator.ingest(state2); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(Jsons.jsonNode(List.of(state2)))); + } + + @Test + public void testStreamState() { + final AirbyteStateMessage state1 = getStreamMessage("a", 1); + final AirbyteStateMessage state2 = getStreamMessage("b", 2); + final AirbyteStateMessage state3 = getStreamMessage("b", 3); + + stateAggregator.ingest(state1); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(Jsons.jsonNode(List.of(state1)))); + + stateAggregator.ingest(state2); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(Jsons.jsonNode(List.of(state2, state1)))); + + stateAggregator.ingest(state3); + Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State() + .withState(Jsons.jsonNode(List.of(state3, state1)))); + } + + private AirbyteStateMessage getNullMessage(final int stateValue) { + return new AirbyteStateMessage().withData(Jsons.jsonNode(stateValue)); + } + + private AirbyteStateMessage getLegacyMessage(final int stateValue) { + return new AirbyteStateMessage().withType(LEGACY).withData(Jsons.jsonNode(stateValue)); + } + + private AirbyteStateMessage getGlobalMessage(final int stateValue) { + return new AirbyteStateMessage().withType(GLOBAL) + .withGlobal(new AirbyteGlobalState() + .withStreamStates( + List.of( + new AirbyteStreamState() + .withStreamDescriptor( + new StreamDescriptor() + .withName("test")) + .withStreamState(Jsons.jsonNode(stateValue))))); + } + + private AirbyteStateMessage getStreamMessage(final String streamName, final int stateValue) { + return new AirbyteStateMessage().withType(STREAM) + .withStream( + new AirbyteStreamState() + .withStreamDescriptor( + new StreamDescriptor() + .withName(streamName)) + .withStreamState(Jsons.jsonNode(stateValue))); + } + + private AirbyteStateMessage getEmptyMessage(final AirbyteStateType stateType) { + if (stateType == STREAM) { + return new AirbyteStateMessage() + .withType(STREAM) + .withStream( + new AirbyteStreamState() + .withStreamDescriptor(new StreamDescriptor())); + } + + return new AirbyteStateMessage().withType(stateType); + } + +}