Skip to content

Commit

Permalink
Bmoric/state aggregator (#14364)
Browse files Browse the repository at this point in the history
* Update state.state type

* Add state aggregator

* Test and format

* PR comments

* Move to its own package

* Update airbyte-workers/src/test/java/io/airbyte/workers/internal/state_aggregator/StateAggregatorTest.java

Co-authored-by: Lake Mossman <lake@airbyte.io>

* format

* Update airbyte-workers/src/main/java/io/airbyte/workers/internal/state_aggregator/DefaultStateAggregator.java

Co-authored-by: Lake Mossman <lake@airbyte.io>

* format

Co-authored-by: Lake Mossman <lake@airbyte.io>
  • Loading branch information
benmoriceau and lmossman authored Jul 2, 2022
1 parent ca272c3 commit 59e20f2
Show file tree
Hide file tree
Showing 7 changed files with 313 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteTraceMessage;
import io.airbyte.workers.helper.FailureHelper;
import io.airbyte.workers.internal.state_aggregator.DefaultStateAggregator;
import io.airbyte.workers.internal.state_aggregator.StateAggregator;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
Expand All @@ -44,6 +46,7 @@ public class AirbyteMessageTracker implements MessageTracker {
private final StateDeltaTracker stateDeltaTracker;
private final List<AirbyteTraceMessage> destinationErrorTraceMessages;
private final List<AirbyteTraceMessage> sourceErrorTraceMessages;
private final StateAggregator stateAggregator;

private short nextStreamIndex;

Expand All @@ -59,11 +62,11 @@ private enum ConnectorType {
}

public AirbyteMessageTracker() {
this(new StateDeltaTracker(STATE_DELTA_TRACKER_MEMORY_LIMIT_BYTES));
this(new StateDeltaTracker(STATE_DELTA_TRACKER_MEMORY_LIMIT_BYTES), new DefaultStateAggregator());
}

@VisibleForTesting
protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker) {
protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker, final StateAggregator stateAggregator) {
this.sourceOutputState = new AtomicReference<>();
this.destinationOutputState = new AtomicReference<>();
this.totalEmittedStateMessages = new AtomicLong(0L);
Expand All @@ -77,6 +80,7 @@ protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker) {
this.unreliableCommittedCounts = false;
this.destinationErrorTraceMessages = new ArrayList<>();
this.sourceErrorTraceMessages = new ArrayList<>();
this.stateAggregator = stateAggregator;
}

@Override
Expand Down Expand Up @@ -144,7 +148,8 @@ private void handleSourceEmittedState(final AirbyteStateMessage stateMessage) {
* committed in the {@link StateDeltaTracker}. Also record this state as the last committed state.
*/
private void handleDestinationEmittedState(final AirbyteStateMessage stateMessage) {
destinationOutputState.set(new State().withState(stateMessage.getData()));
stateAggregator.ingest(stateMessage);
destinationOutputState.set(stateAggregator.getAggregated());
try {
if (!unreliableCommittedCounts) {
stateDeltaTracker.commitStateHash(getStateHashCode(stateMessage));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.workers.internal.state_aggregator;

import com.google.common.base.Preconditions;
import io.airbyte.config.State;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType;

/**
 * {@link StateAggregator} that dispatches to a type-specific aggregator.
 *
 * <p>
 * The state type is fixed by the first message ingested (a message without a type is treated as
 * {@code LEGACY}). {@code STREAM} messages are routed to a {@link StreamStateAggregator};
 * {@code GLOBAL} and {@code LEGACY} messages are routed to a {@link SingleStateAggregator}. Once
 * the type is fixed, ingesting a message of a different type is rejected.
 */
public class DefaultStateAggregator implements StateAggregator {

  /** Type of the states handled by this instance; {@code null} until the first message is ingested. */
  private AirbyteStateType stateType = null;
  private final StateAggregator streamStateAggregator = new StreamStateAggregator();
  private final StateAggregator singleStateAggregator = new SingleStateAggregator();

  /**
   * Record a state message in the aggregator matching its type.
   *
   * @param stateMessage the state message to record
   * @throws IllegalArgumentException if the message type differs from the type of the messages
   *         previously ingested by this instance
   */
  @Override
  public void ingest(final AirbyteStateMessage stateMessage) {
    checkTypeOrSetType(stateMessage.getType());

    getStateAggregator().ingest(stateMessage);
  }

  /**
   * Return the aggregated state built by the type-specific aggregator.
   *
   * @return the aggregated state
   * @throws IllegalStateException if no state message has been ingested yet
   */
  @Override
  public State getAggregated() {
    return getStateAggregator().getAggregated();
  }

  /**
   * Return the state aggregator that matches the current state type.
   *
   * @throws IllegalStateException if the state type has not been set yet
   */
  private StateAggregator getStateAggregator() {
    // Switching on a null enum throws a bare NullPointerException; fail with an explicit message
    // instead when nothing has been ingested yet.
    Preconditions.checkState(stateType != null,
        "The state type is unknown because no state message has been ingested yet");
    return switch (stateType) {
      case STREAM -> streamStateAggregator;
      case GLOBAL, LEGACY -> singleStateAggregator;
    };
  }

  /**
   * Two different state types cannot be given to the same instance of this class. This method sets
   * the type if it has not been set yet. If the state type is missing from the message, it is
   * treated as LEGACY.
   *
   * @param inputStateType the type carried by the incoming message, possibly {@code null}
   * @throws IllegalArgumentException if the (defaulted) input type differs from the current type
   */
  private void checkTypeOrSetType(final AirbyteStateType inputStateType) {
    final AirbyteStateType effectiveType = inputStateType == null ? AirbyteStateType.LEGACY : inputStateType;
    if (this.stateType == null) {
      this.stateType = effectiveType;
    }
    // Template form: the message is only built when the check fails.
    Preconditions.checkArgument(this.stateType == effectiveType,
        "Input state type %s does not match the aggregator's current state type %s", effectiveType, this.stateType);
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.workers.internal.state_aggregator;

import io.airbyte.commons.json.Jsons;
import io.airbyte.config.State;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType;
import java.util.List;

/**
 * {@link StateAggregator} that only remembers the most recently ingested state message.
 *
 * <p>
 * A message with no type or with the {@code LEGACY} type is exposed through its raw data; any other
 * message is exposed as a one-element JSON list containing the message itself.
 */
class SingleStateAggregator implements StateAggregator {

  /** Last message ingested; each call to {@link #ingest} overwrites it. */
  AirbyteStateMessage state;

  /**
   * Replace the retained state with the given message.
   */
  @Override
  public void ingest(final AirbyteStateMessage stateMessage) {
    this.state = stateMessage;
  }

  /**
   * Build the {@link State} for the last ingested message.
   */
  @Override
  public State getAggregated() {
    final AirbyteStateType type = state.getType();
    final boolean isLegacy = type == null || type == AirbyteStateType.LEGACY;

    if (!isLegacy) {
      // Typed (non-legacy) state: wrap the whole message in a single-element list.
      return new State().withState(Jsons.jsonNode(List.of(state)));
    }
    // Legacy state: only the data blob is kept.
    return new State().withState(state.getData());
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.workers.internal.state_aggregator;

import io.airbyte.config.State;
import io.airbyte.protocol.models.AirbyteStateMessage;

/**
 * Accumulates {@link AirbyteStateMessage}s and exposes the result as a single {@link State}.
 */
public interface StateAggregator {

/**
 * Record a state message so that it is reflected in later calls to {@link #getAggregated()}.
 *
 * @param stateMessage the state message to record
 */
void ingest(AirbyteStateMessage stateMessage);

/**
 * Return the state built from the messages ingested so far.
 *
 * @return the aggregated state
 */
State getAggregated();

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.workers.internal.state_aggregator;

import io.airbyte.commons.json.Jsons;
import io.airbyte.config.State;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.StreamDescriptor;
import java.util.HashMap;
import java.util.Map;

/**
 * {@link StateAggregator} that keeps the latest state message of each stream, keyed by the
 * stream's {@link StreamDescriptor}.
 */
class StreamStateAggregator implements StateAggregator {

  /** Latest state message seen for each stream descriptor. */
  Map<StreamDescriptor, AirbyteStateMessage> aggregatedState = new HashMap<>();

  /**
   * Store the message under its stream descriptor, replacing any earlier message for that stream.
   */
  @Override
  public void ingest(final AirbyteStateMessage stateMessage) {
    final StreamDescriptor descriptor = stateMessage.getStream().getStreamDescriptor();
    aggregatedState.put(descriptor, stateMessage);
  }

  /**
   * Build a {@link State} whose payload is the JSON form of all retained per-stream messages.
   */
  @Override
  public State getAggregated() {
    return new State().withState(Jsons.jsonNode(aggregatedState.values()));
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.workers.helper.FailureHelper;
import io.airbyte.workers.internal.StateDeltaTracker.StateDeltaTrackerException;
import io.airbyte.workers.internal.state_aggregator.StateAggregator;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.BeforeEach;
Expand All @@ -34,9 +35,12 @@ class AirbyteMessageTrackerTest {
@Mock
private StateDeltaTracker mStateDeltaTracker;

@Mock
private StateAggregator mStateAggregator;

@BeforeEach
public void setup() {
this.messageTracker = new AirbyteMessageTracker(mStateDeltaTracker);
this.messageTracker = new AirbyteMessageTracker(mStateDeltaTracker, mStateAggregator);
}

@Test
Expand Down Expand Up @@ -65,6 +69,9 @@ public void testRetainsLatestSourceAndDestinationState() {
final AirbyteMessage s2 = AirbyteMessageUtils.createStateMessage(s2Value);
final AirbyteMessage s3 = AirbyteMessageUtils.createStateMessage(s3Value);

final State expectedState = new State().withState(Jsons.jsonNode(s2Value));
Mockito.when(mStateAggregator.getAggregated()).thenReturn(expectedState);

messageTracker.acceptFromSource(s1);
messageTracker.acceptFromSource(s2);
messageTracker.acceptFromSource(s3);
Expand All @@ -75,7 +82,7 @@ public void testRetainsLatestSourceAndDestinationState() {
assertEquals(new State().withState(Jsons.jsonNode(s3Value)), messageTracker.getSourceOutputState().get());

assertTrue(messageTracker.getDestinationOutputState().isPresent());
assertEquals(new State().withState(Jsons.jsonNode(s2Value)), messageTracker.getDestinationOutputState().get());
assertEquals(expectedState, messageTracker.getDestinationOutputState().get());
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.workers.internal.state_aggregator;

import static io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType.GLOBAL;
import static io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType.LEGACY;
import static io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType.STREAM;

import com.google.common.collect.Lists;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.State;
import io.airbyte.protocol.models.AirbyteGlobalState;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.AirbyteStateMessage.AirbyteStateType;
import io.airbyte.protocol.models.AirbyteStreamState;
import io.airbyte.protocol.models.StreamDescriptor;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;

/**
 * Tests for {@link DefaultStateAggregator}: type-mixing is rejected, and each state type is
 * aggregated as expected.
 */
public class StateAggregatorTest {

  StateAggregator stateAggregator;

  @BeforeEach
  public void init() {
    stateAggregator = new DefaultStateAggregator();
  }

  /**
   * Once a message of a given type has been ingested, every other type must be rejected.
   */
  @ParameterizedTest
  @EnumSource(AirbyteStateType.class)
  public void testCantMixType(final AirbyteStateType stateType) {
    final Stream<AirbyteStateType> allTypes = Arrays.stream(AirbyteStateType.values());

    stateAggregator.ingest(getEmptyMessage(stateType));

    final List<AirbyteStateType> differentTypes = allTypes.filter(type -> type != stateType).toList();
    // Pin the exception type so that an unrelated failure (e.g. a NullPointerException caused by the
    // empty test message) cannot make this test pass.
    differentTypes.forEach(differentType -> Assertions.assertThatThrownBy(() -> stateAggregator.ingest(getEmptyMessage(differentType)))
        .isInstanceOf(IllegalArgumentException.class));
  }

  /**
   * A message without a type is treated as LEGACY: GLOBAL and STREAM are then rejected while LEGACY
   * is still accepted.
   */
  @Test
  public void testCantMixNullType() {
    final List<AirbyteStateType> allIncompatibleTypes = Lists.newArrayList(GLOBAL, STREAM);

    stateAggregator.ingest(getEmptyMessage(null));

    allIncompatibleTypes.forEach(differentType -> Assertions.assertThatThrownBy(() -> stateAggregator.ingest(getEmptyMessage(differentType)))
        .isInstanceOf(IllegalArgumentException.class));

    // Must not throw: a null type and LEGACY are compatible.
    stateAggregator.ingest(getEmptyMessage(LEGACY));
  }

  /**
   * Messages without a type expose the data of the last ingested message.
   */
  @Test
  public void testNullState() {
    final AirbyteStateMessage state1 = getNullMessage(1);
    final AirbyteStateMessage state2 = getNullMessage(2);

    stateAggregator.ingest(state1);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(state1.getData()));

    stateAggregator.ingest(state2);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(state2.getData()));
  }

  /**
   * LEGACY messages expose the data of the last ingested message.
   */
  @Test
  public void testLegacyState() {
    final AirbyteStateMessage state1 = getLegacyMessage(1);
    final AirbyteStateMessage state2 = getLegacyMessage(2);

    stateAggregator.ingest(state1);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(state1.getData()));

    stateAggregator.ingest(state2);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(state2.getData()));
  }

  /**
   * GLOBAL messages are exposed as a one-element list holding the last ingested message.
   */
  @Test
  public void testGlobalState() {
    final AirbyteStateMessage state1 = getGlobalMessage(1);
    final AirbyteStateMessage state2 = getGlobalMessage(2);

    stateAggregator.ingest(state1);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(Jsons.jsonNode(List.of(state1))));

    stateAggregator.ingest(state2);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(Jsons.jsonNode(List.of(state2))));
  }

  /**
   * STREAM messages are aggregated per stream: a newer message for a stream replaces the older one.
   *
   * <p>
   * NOTE(review): the expected lists below are order-sensitive and list stream "b" before stream
   * "a", i.e. not in insertion order. This appears to rely on the iteration order of the map used by
   * the stream aggregator for these two descriptors - confirm before changing that map or the
   * stream names.
   */
  @Test
  public void testStreamState() {
    final AirbyteStateMessage state1 = getStreamMessage("a", 1);
    final AirbyteStateMessage state2 = getStreamMessage("b", 2);
    final AirbyteStateMessage state3 = getStreamMessage("b", 3);

    stateAggregator.ingest(state1);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(Jsons.jsonNode(List.of(state1))));

    stateAggregator.ingest(state2);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(Jsons.jsonNode(List.of(state2, state1))));

    stateAggregator.ingest(state3);
    Assertions.assertThat(stateAggregator.getAggregated()).isEqualTo(new State()
        .withState(Jsons.jsonNode(List.of(state3, state1))));
  }

  /** Build a message with no type whose data is the given value. */
  private AirbyteStateMessage getNullMessage(final int stateValue) {
    return new AirbyteStateMessage().withData(Jsons.jsonNode(stateValue));
  }

  /** Build a LEGACY message whose data is the given value. */
  private AirbyteStateMessage getLegacyMessage(final int stateValue) {
    return new AirbyteStateMessage().withType(LEGACY).withData(Jsons.jsonNode(stateValue));
  }

  /** Build a GLOBAL message holding a single stream state named "test" with the given value. */
  private AirbyteStateMessage getGlobalMessage(final int stateValue) {
    return new AirbyteStateMessage().withType(GLOBAL)
        .withGlobal(new AirbyteGlobalState()
            .withStreamStates(
                List.of(
                    new AirbyteStreamState()
                        .withStreamDescriptor(
                            new StreamDescriptor()
                                .withName("test"))
                        .withStreamState(Jsons.jsonNode(stateValue)))));
  }

  /** Build a STREAM message for the named stream with the given value. */
  private AirbyteStateMessage getStreamMessage(final String streamName, final int stateValue) {
    return new AirbyteStateMessage().withType(STREAM)
        .withStream(
            new AirbyteStreamState()
                .withStreamDescriptor(
                    new StreamDescriptor()
                        .withName(streamName))
                .withStreamState(Jsons.jsonNode(stateValue)));
  }

  /**
   * Build a message of the given type with no payload. STREAM messages still get a stream with an
   * empty descriptor, because the stream aggregator dereferences it on ingestion.
   */
  private AirbyteStateMessage getEmptyMessage(final AirbyteStateType stateType) {
    if (stateType == STREAM) {
      return new AirbyteStateMessage()
          .withType(STREAM)
          .withStream(
              new AirbyteStreamState()
                  .withStreamDescriptor(new StreamDescriptor()));
    }

    return new AirbyteStateMessage().withType(stateType);
  }

}

0 comments on commit 59e20f2

Please sign in to comment.