Skip to content

Commit

Permalink
Skip validation when previous state is empty due to reset (#20585)
Browse files Browse the repository at this point in the history
* Skip validation when previous state is empty due to reset

* Handle null state object

* Fix formatting

* Fix logic

* Fix method name
  • Loading branch information
jdpgrailsdev authored Jan 3, 2023
1 parent 4897e29 commit c3987a9
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,8 @@ public boolean persist(final UUID connectionId, final StandardSyncOutput syncOut
AirbyteApiClient.retryWithJitter(
() -> airbyteApiClient.getStateApi().getState(new ConnectionIdRequestBody().connectionId(connectionId)),
"get state");
if (featureFlags.needStateValidation() && previousState != null) {
final StateType newStateType = maybeStateWrapper.get().getStateType();
final StateType prevStateType = convertClientStateTypeToInternal(previousState.getStateType());

if (isMigration(newStateType, prevStateType) && newStateType == StateType.STREAM) {
validateStreamStates(maybeStateWrapper.get(), configuredCatalog);
}
}
validate(configuredCatalog, maybeStateWrapper, previousState);

AirbyteApiClient.retryWithJitter(
() -> {
Expand All @@ -85,6 +79,42 @@ public boolean persist(final UUID connectionId, final StandardSyncOutput syncOut
}
}

/**
 * Checks whether the new state may be persisted on top of the previously saved state.
 *
 * @param configuredCatalog The configured catalog of streams for the connection.
 * @param newState The new state.
 * @param previousState The previous state; may be {@code null} or empty.
 */
private void validate(final ConfiguredAirbyteCatalog configuredCatalog,
                      final Optional<StateWrapper> newState,
                      final ConnectionState previousState) {
  // With validation disabled, or with a missing/empty previous state, there is nothing that could
  // be lost, so the update is allowed without further checks.
  if (!featureFlags.needStateValidation() || isStateEmpty(previousState)) {
    return;
  }

  final StateWrapper incomingState = newState.get();
  final StateType incomingType = incomingState.getStateType();
  final StateType storedType = convertClientStateTypeToInternal(previousState.getStateType());

  // Only a migration whose target is per-stream state requires the stream states to be validated,
  // so that state is not lost as part of the migration from legacy -> per stream.
  final boolean migratingToPerStream = isMigration(incomingType, storedType) && incomingType == StateType.STREAM;
  if (migratingToPerStream) {
    validateStreamStates(incomingState, configuredCatalog);
  }
}

/**
 * Determines whether a connection state carries no state data.
 *
 * @param connectionState The connection state to inspect; may be {@code null}.
 * @return {@code true} if the connection state is {@code null}, its state is {@code null}, or its
 *         state is empty; {@code false} otherwise.
 */
private boolean isStateEmpty(final ConnectionState connectionState) {
  if (connectionState == null || connectionState.getState() == null) {
    return true;
  }
  return connectionState.getState().isEmpty();
}

@VisibleForTesting
void validateStreamStates(final StateWrapper state, final ConfiguredAirbyteCatalog configuredCatalog) {
final List<StreamDescriptor> stateStreamDescriptors =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,19 @@

package io.airbyte.workers.temporal.sync;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.when;

import com.fasterxml.jackson.databind.JsonNode;
import io.airbyte.api.client.AirbyteApiClient;
import io.airbyte.api.client.generated.StateApi;
import io.airbyte.api.client.invoker.generated.ApiException;
import io.airbyte.api.client.model.generated.ConnectionIdRequestBody;
import io.airbyte.api.client.model.generated.ConnectionState;
import io.airbyte.api.client.model.generated.ConnectionStateCreateOrUpdate;
import io.airbyte.api.client.model.generated.ConnectionStateType;
import io.airbyte.commons.features.FeatureFlags;
import io.airbyte.commons.json.Jsons;
import io.airbyte.config.StandardSyncOutput;
Expand Down Expand Up @@ -42,6 +48,10 @@
class PersistStateActivityTest {

// Connection identifier passed to persist() throughout the tests.
private static final UUID CONNECTION_ID = UUID.randomUUID();
// Stream names (and the namespace used with stream "a") for the configured catalogs built in the tests.
private static final String STREAM_A = "a";
private static final String STREAM_A_NAMESPACE = "a1";
private static final String STREAM_B = "b";
private static final String STREAM_C = "c";

@Mock
AirbyteApiClient airbyteApiClient;
Expand Down Expand Up @@ -78,7 +88,7 @@ void testPersistEmpty() {

@Test
void testPersist() throws ApiException {
Mockito.when(featureFlags.useStreamCapableState()).thenReturn(true);
when(featureFlags.useStreamCapableState()).thenReturn(true);

final JsonNode jsonState = Jsons.jsonNode(Map.ofEntries(
Map.entry("some", "state")));
Expand All @@ -88,7 +98,7 @@ void testPersist() throws ApiException {
persistStateActivity.persist(CONNECTION_ID, new StandardSyncOutput().withState(state), new ConfiguredAirbyteCatalog());

// The ser/der of the state into a state wrapper is tested in StateMessageHelperTest
Mockito.verify(stateApi).createOrUpdateState(Mockito.any(ConnectionStateCreateOrUpdate.class));
Mockito.verify(stateApi).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class));
}

// For per-stream state, we expect there to be state for each stream within the configured catalog
Expand All @@ -97,8 +107,9 @@ void testPersist() throws ApiException {
// catalog has a state message when migrating from Legacy to Per-Stream
@Test
void testPersistWithValidMissingStateDuringMigration() throws ApiException {
final ConfiguredAirbyteStream stream = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("a").withNamespace("a1"));
final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("b"));
final ConfiguredAirbyteStream stream =
new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE));
final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B));

final AirbyteStateMessage stateMessage1 = new AirbyteStateMessage()
.withType(AirbyteStateType.STREAM)
Expand All @@ -110,19 +121,20 @@ void testPersistWithValidMissingStateDuringMigration() throws ApiException {

final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2));
final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state);
Mockito.when(featureFlags.useStreamCapableState()).thenReturn(true);
when(featureFlags.useStreamCapableState()).thenReturn(true);

mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.STREAM), Mockito.any(StateType.class))).thenReturn(true);
mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.STREAM), any(StateType.class))).thenReturn(true);
persistStateActivity.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog);
Mockito.verify(stateApi).createOrUpdateState(Mockito.any(ConnectionStateCreateOrUpdate.class));
Mockito.verify(stateApi).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class));
}

@Test
void testPersistWithValidStateDuringMigration() throws ApiException {
final ConfiguredAirbyteStream stream = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("a").withNamespace("a1"));
final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("b"));
final ConfiguredAirbyteStream stream =
new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE));
final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B));
final ConfiguredAirbyteStream stream3 =
new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName("c")).withSyncMode(SyncMode.FULL_REFRESH);
new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_C)).withSyncMode(SyncMode.FULL_REFRESH);

final AirbyteStateMessage stateMessage1 = new AirbyteStateMessage()
.withType(AirbyteStateType.STREAM)
Expand All @@ -138,30 +150,149 @@ void testPersistWithValidStateDuringMigration() throws ApiException {

final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2, stream3));
final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state);
Mockito.when(featureFlags.useStreamCapableState()).thenReturn(true);
mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.STREAM), Mockito.any(StateType.class))).thenReturn(true);
when(featureFlags.useStreamCapableState()).thenReturn(true);
mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.STREAM), any(StateType.class))).thenReturn(true);
persistStateActivity.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog);
Mockito.verify(stateApi).createOrUpdateState(Mockito.any(ConnectionStateCreateOrUpdate.class));
Mockito.verify(stateApi).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class));
}

// Global stream states do not need to be validated during the migration to per-stream state
@Test
void testPersistWithGlobalStateDuringMigration() throws ApiException {
  final ConfiguredAirbyteStream stream =
      new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE));
  final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B));

  final AirbyteStateMessage stateMessage = new AirbyteStateMessage().withType(AirbyteStateType.GLOBAL);
  final JsonNode jsonState = Jsons.jsonNode(List.of(stateMessage));
  final State state = new State().withState(jsonState);

  final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2));
  final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state);
  when(featureFlags.useStreamCapableState()).thenReturn(true);
  mockedStateMessageHelper.when(() -> StateMessageHelper.isMigration(Mockito.eq(StateType.GLOBAL), any(StateType.class))).thenReturn(true);

  // The spy must exist before persist() runs and persist() must be invoked through it; a spy that
  // is created afterwards has recorded no interactions, so a "never called" verification on it
  // could not fail.
  final PersistStateActivityImpl persistStateSpy = spy(persistStateActivity);
  persistStateSpy.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog);

  Mockito.verify(persistStateSpy, Mockito.never()).validateStreamStates(any(), any());
  Mockito.verify(stateApi).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class));
}

@Test
void testPersistWithPerStreamStateDuringMigrationFromEmptyLegacyState() throws ApiException {
  /*
   * This test covers a scenario where a reset is executed before any successful syncs for a
   * connection. When this occurs, an empty, legacy state is stored for the connection.
   */
  final ConnectionState emptyLegacyState = mock(ConnectionState.class);
  Mockito.lenient().when(emptyLegacyState.getStateType()).thenReturn(ConnectionStateType.LEGACY);
  Mockito.lenient().when(emptyLegacyState.getState()).thenReturn(Jsons.emptyObject());

  verifyPerStreamStateIsPersistedOver(emptyLegacyState);
}

@Test
void testPersistWithPerStreamStateDuringMigrationFromNullLegacyState() throws ApiException {
  // A legacy connection state object exists, but it holds no state at all.
  final ConnectionState nullLegacyState = mock(ConnectionState.class);
  Mockito.lenient().when(nullLegacyState.getStateType()).thenReturn(ConnectionStateType.LEGACY);
  Mockito.lenient().when(nullLegacyState.getState()).thenReturn(null);

  verifyPerStreamStateIsPersistedOver(nullLegacyState);
}

@Test
void testPersistWithPerStreamStateDuringMigrationWithNoPreviousState() throws ApiException {
  // The state API returns no connection state for the connection.
  verifyPerStreamStateIsPersistedOver(null);
}

/**
 * Persists a per-stream state for a three-stream catalog against a dedicated API client whose state
 * API returns {@code previousState}, and verifies that the new state is written.
 *
 * @param previousState The connection state returned by the mocked state API; may be {@code null}.
 * @throws ApiException declared by the mocked state API calls.
 */
private void verifyPerStreamStateIsPersistedOver(final ConnectionState previousState) throws ApiException {
  final ConfiguredAirbyteStream stream =
      new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_A).withNamespace(STREAM_A_NAMESPACE));
  final ConfiguredAirbyteStream stream2 = new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_B));
  final ConfiguredAirbyteStream stream3 =
      new ConfiguredAirbyteStream().withStream(new AirbyteStream().withName(STREAM_C)).withSyncMode(SyncMode.FULL_REFRESH);

  // State messages are provided for streams A and B only; stream C (full refresh) has none.
  final AirbyteStateMessage stateMessage1 = new AirbyteStateMessage()
      .withType(AirbyteStateType.STREAM)
      .withStream(
          new AirbyteStreamState().withStreamDescriptor(CatalogHelpers.extractDescriptor(stream))
              .withStreamState(Jsons.emptyObject()));
  final AirbyteStateMessage stateMessage2 = new AirbyteStateMessage()
      .withType(AirbyteStateType.STREAM)
      .withStream(
          new AirbyteStreamState().withStreamDescriptor(CatalogHelpers.extractDescriptor(stream2)));
  final JsonNode jsonState = Jsons.jsonNode(List.of(stateMessage1, stateMessage2));
  final State state = new State().withState(jsonState);

  final AirbyteApiClient airbyteApiClient1 = mock(AirbyteApiClient.class);
  final StateApi stateApi1 = mock(StateApi.class);
  when(stateApi1.getState(any(ConnectionIdRequestBody.class))).thenReturn(previousState);
  Mockito.lenient().when(airbyteApiClient1.getStateApi()).thenReturn(stateApi1);

  final ConfiguredAirbyteCatalog migrationConfiguredCatalog = new ConfiguredAirbyteCatalog().withStreams(List.of(stream, stream2, stream3));
  final StandardSyncOutput syncOutput = new StandardSyncOutput().withState(state);
  when(featureFlags.useStreamCapableState()).thenReturn(true);
  // Enable state validation explicitly so that the previous state is actually inspected; an
  // unstubbed boolean method on a mock returns false, in which case persist would succeed without
  // ever looking at the previous state.
  when(featureFlags.needStateValidation()).thenReturn(true);

  final PersistStateActivityImpl persistStateActivity1 = new PersistStateActivityImpl(airbyteApiClient1, featureFlags);

  persistStateActivity1.persist(CONNECTION_ID, syncOutput, migrationConfiguredCatalog);

  Mockito.verify(stateApi1).createOrUpdateState(any(ConnectionStateCreateOrUpdate.class));
}

}

0 comments on commit c3987a9

Please sign in to comment.