Skip to content

Commit

Permalink
Refactor state manager creation
Browse files Browse the repository at this point in the history
  • Loading branch information
jdpgrailsdev committed Jun 15, 2022
1 parent 39b6691 commit 86891b9
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -411,34 +411,24 @@ private static AirbyteStream addCdcMetadataColumns(final AirbyteStream stream) {
return stream;
}

// TODO This is a temporary override so that the Postgres source can take advantage of per-stream
// state.
// TODO This is a temporary override so that the Postgres source can take advantage of per-stream state
@Override
protected List<AirbyteStateMessage> deserializeState(final JsonNode stateJson, final JsonNode config) {
if (stateJson == null) {
if (supportedStateTypeSupplier(config).get() == AirbyteStateType.GLOBAL) {
final AirbyteGlobalState globalState = new AirbyteGlobalState()
.withSharedState(Jsons.jsonNode(new CdcState()))
.withStreamStates(List.of());
return List.of(new AirbyteStateMessage().withStateType(AirbyteStateType.GLOBAL).withGlobal(globalState));
} else {
return List.of(new AirbyteStateMessage()
.withStateType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState()));
}
protected List<AirbyteStateMessage> generateEmptyInitialState(final JsonNode config) {
if (getSupportedStateType(config) == AirbyteStateType.GLOBAL) {
final AirbyteGlobalState globalState = new AirbyteGlobalState()
.withSharedState(Jsons.jsonNode(new CdcState()))
.withStreamStates(List.of());
return List.of(new AirbyteStateMessage().withStateType(AirbyteStateType.GLOBAL).withGlobal(globalState));
} else {
try {
return Jsons.object(stateJson, new AirbyteStateMessageListTypeReference());
} catch (final IllegalArgumentException e) {
LOGGER.warn("Defaulting to legacy state object...");
return List.of(new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY).withData(stateJson));
}
return List.of(new AirbyteStateMessage()
.withStateType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState()));
}
}

@Override
protected Supplier<AirbyteStateType> supportedStateTypeSupplier(final JsonNode config) {
return () -> isCdc(config) ? AirbyteStateType.GLOBAL : AirbyteStateType.STREAM;
protected AirbyteStateType getSupportedStateType(final JsonNode config) {
return isCdc(config) ? AirbyteStateType.GLOBAL : AirbyteStateType.STREAM;
}

public static void main(final String[] args) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
Expand Down Expand Up @@ -110,7 +109,7 @@ public AutoCloseableIterator<AirbyteMessage> read(final JsonNode config,
final JsonNode state)
throws Exception {
final StateManager stateManager =
StateManagerFactory.createStateManager(deserializeState(state, config), catalog, supportedStateTypeSupplier(config));
StateManagerFactory.createStateManager(getSupportedStateType(config), deserializeInitialState(state, config), catalog);
final Instant emittedAt = Instant.now();

final Database database = createDatabaseInternal(config);
Expand Down Expand Up @@ -517,33 +516,41 @@ private Database createDatabaseInternal(final JsonNode sourceConfig) throws Exce
/**
* Deserializes the state represented as JSON into an object representation.
*
* @param stateJson The state as JSON.
* @param initialStateJson The state as JSON.
* @param config The connector configuration.
* @return The deserialized object representation of the state.
*/
protected List<AirbyteStateMessage> deserializeState(final JsonNode stateJson, final JsonNode config) {
if (stateJson == null) {
// For backwards compatibility with existing connectors
return List.of(new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(new DbState())));
protected List<AirbyteStateMessage> deserializeInitialState(final JsonNode initialStateJson, final JsonNode config) {
if (initialStateJson == null) {
return generateEmptyInitialState(config);
} else {
try {
return Jsons.object(stateJson, new AirbyteStateMessageListTypeReference());
return Jsons.object(initialStateJson, new AirbyteStateMessageListTypeReference());
} catch (final IllegalArgumentException e) {
LOGGER.warn("Defaulting to legacy state object...");
return List.of(new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY).withData(stateJson));
return List.of(new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY).withData(initialStateJson));
}
}
}

/**
* Generates a {@link Supplier} that can be used to determine which state manager should be selected
* for use by this connector.
* Generates an empty, initial state for use by the connector.
* @param config The connector configuration.
* @return The empty, initial state.
*/
protected List<AirbyteStateMessage> generateEmptyInitialState(final JsonNode config) {
// For backwards compatibility with existing connectors
return List.of(new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(new DbState())));
}

/**
* Returns the {@link AirbyteStateType} supported by this connector.
*
* @param config The connector configuration.
* @return A {@link Supplier}.
* @return A {@link AirbyteStateType} representing the state supported by this connector.
*/
protected Supplier<AirbyteStateType> supportedStateTypeSupplier(final JsonNode config) {
return () -> AirbyteStateType.LEGACY;
protected AirbyteStateType getSupportedStateType(final JsonNode config) {
return AirbyteStateType.LEGACY;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import java.util.ArrayList;
import java.util.List;
import java.util.function.Supplier;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -31,19 +30,18 @@ private StateManagerFactory() {}
* Creates a {@link StateManager} based on the provided state object and catalog. This method will handle the
* conversion of the provided state to match the requested state manager based on the provided {@link AirbyteStateType}.
*
* @param supportedStateType The type of state supported by the connector.
* @param initialState The deserialized initial state that will be provided to the selected {@link StateManager}.
* @param catalog The {@link ConfiguredAirbyteCatalog} for the connector that will utilize the state
* manager.
* @param stateTypeSupplier {@link Supplier} that provides the {@link AirbyteStateType} that will be
* used to select the correct state manager.
* @return A newly created {@link StateManager} implementation based on the provided state.
*/
public static StateManager createStateManager(final List<AirbyteStateMessage> initialState,
final ConfiguredAirbyteCatalog catalog,
final Supplier<AirbyteStateType> stateTypeSupplier) {
public static StateManager createStateManager(final AirbyteStateType supportedStateType,
final List<AirbyteStateMessage> initialState,
final ConfiguredAirbyteCatalog catalog) {
if (initialState != null && !initialState.isEmpty()) {
final AirbyteStateMessage airbyteStateMessage = initialState.get(0);
switch (stateTypeSupplier.get()) {
switch (supportedStateType) {
case LEGACY:
LOGGER.info("Legacy state manager selected to manage state object with type {}.", airbyteStateMessage.getStateType());
return new LegacyStateManager(Jsons.object(airbyteStateMessage.getData(), DbState.class), catalog);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.protocol.models.StreamDescriptor;
import java.util.List;
import java.util.function.Supplier;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

Expand All @@ -30,38 +29,32 @@ public class StateManagerFactoryTest {
private static final String NAMESPACE = "namespace";
private static final String NAME = "name";

private static final Supplier<AirbyteStateType> GLOBAL_STATE_TYPE = () -> AirbyteStateType.GLOBAL;

private static final Supplier<AirbyteStateType> LEGACY_STATE_TYPE = () -> AirbyteStateType.LEGACY;

private static final Supplier<AirbyteStateType> STREAM_STATE_TYPE = () -> AirbyteStateType.STREAM;

@Test
void testNullOrEmptyState() {
final ConfiguredAirbyteCatalog catalog = mock(ConfiguredAirbyteCatalog.class);

Assertions.assertThrows(IllegalArgumentException.class, () -> {
StateManagerFactory.createStateManager(null, catalog, GLOBAL_STATE_TYPE);
StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, null, catalog);
});

Assertions.assertThrows(IllegalArgumentException.class, () -> {
StateManagerFactory.createStateManager(List.of(), catalog, GLOBAL_STATE_TYPE);
StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(), catalog);
});

Assertions.assertThrows(IllegalArgumentException.class, () -> {
StateManagerFactory.createStateManager(null, catalog, LEGACY_STATE_TYPE);
StateManagerFactory.createStateManager(AirbyteStateType.LEGACY,null, catalog);
});

Assertions.assertThrows(IllegalArgumentException.class, () -> {
StateManagerFactory.createStateManager(List.of(), catalog, LEGACY_STATE_TYPE);
StateManagerFactory.createStateManager(AirbyteStateType.LEGACY, List.of(), catalog);
});

Assertions.assertThrows(IllegalArgumentException.class, () -> {
StateManagerFactory.createStateManager(null, catalog, STREAM_STATE_TYPE);
StateManagerFactory.createStateManager(AirbyteStateType.STREAM, null, catalog);
});

Assertions.assertThrows(IllegalArgumentException.class, () -> {
StateManagerFactory.createStateManager(List.of(), catalog, STREAM_STATE_TYPE);
StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(), catalog);
});
}

Expand All @@ -71,7 +64,7 @@ void testLegacyStateManagerCreationFromAirbyteStateMessage() {
final AirbyteStateMessage airbyteStateMessage = mock(AirbyteStateMessage.class);
when(airbyteStateMessage.getData()).thenReturn(Jsons.jsonNode(new DbState()));

final StateManager stateManager = StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, LEGACY_STATE_TYPE);
final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.LEGACY, List.of(airbyteStateMessage), catalog);

Assertions.assertNotNull(stateManager);
Assertions.assertEquals(LegacyStateManager.class, stateManager.getClass());
Expand All @@ -86,7 +79,7 @@ void testGlobalStateManagerCreation() {
.withStreamState(Jsons.jsonNode(new DbStreamState()))));
final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage().withStateType(AirbyteStateType.GLOBAL).withGlobal(globalState);

final StateManager stateManager = StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, GLOBAL_STATE_TYPE);
final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog);

Assertions.assertNotNull(stateManager);
Assertions.assertEquals(GlobalStateManager.class, stateManager.getClass());
Expand All @@ -102,7 +95,7 @@ void testGlobalStateManagerCreationFromLegacyState() {
final AirbyteStateMessage airbyteStateMessage =
new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState));

final StateManager stateManager = StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, GLOBAL_STATE_TYPE);
final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog);

Assertions.assertNotNull(stateManager);
Assertions.assertEquals(GlobalStateManager.class, stateManager.getClass());
Expand All @@ -116,7 +109,7 @@ void testGlobalStateManagerCreationFromStreamState() {
NAMESPACE)).withStreamState(Jsons.jsonNode(new DbStreamState())));

Assertions.assertThrows(IllegalArgumentException.class,
() -> StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, GLOBAL_STATE_TYPE));
() -> StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog));
}

@Test
Expand All @@ -129,7 +122,7 @@ void testGlobalStateManagerCreationWithLegacyDataPresent() {
final AirbyteStateMessage airbyteStateMessage =
new AirbyteStateMessage().withStateType(AirbyteStateType.GLOBAL).withGlobal(globalState).withData(Jsons.jsonNode(new DbState()));

final StateManager stateManager = StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, GLOBAL_STATE_TYPE);
final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.GLOBAL, List.of(airbyteStateMessage), catalog);

Assertions.assertNotNull(stateManager);
Assertions.assertEquals(GlobalStateManager.class, stateManager.getClass());
Expand All @@ -142,7 +135,7 @@ void testStreamStateManagerCreation() {
.withStream(new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withName(NAME).withNamespace(
NAMESPACE)).withStreamState(Jsons.jsonNode(new DbStreamState())));

final StateManager stateManager = StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, STREAM_STATE_TYPE);
final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog);

Assertions.assertNotNull(stateManager);
Assertions.assertEquals(StreamStateManager.class, stateManager.getClass());
Expand All @@ -158,7 +151,7 @@ void testStreamStateManagerCreationFromLegacy() {
final AirbyteStateMessage airbyteStateMessage =
new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(dbState));

final StateManager stateManager = StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, STREAM_STATE_TYPE);
final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog);

Assertions.assertNotNull(stateManager);
Assertions.assertEquals(StreamStateManager.class, stateManager.getClass());
Expand All @@ -173,7 +166,7 @@ void testStreamStateManagerCreationFromGlobal() {
.withStreamState(Jsons.jsonNode(new DbStreamState()))));
final AirbyteStateMessage airbyteStateMessage = new AirbyteStateMessage().withStateType(AirbyteStateType.GLOBAL).withGlobal(globalState);

Assertions.assertThrows(IllegalArgumentException.class, () -> StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, STREAM_STATE_TYPE));
Assertions.assertThrows(IllegalArgumentException.class, () -> StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog));
}

@Test
Expand All @@ -184,7 +177,7 @@ void testStreamStateManagerCreationWithLegacyDataPresent() {
NAMESPACE)).withStreamState(Jsons.jsonNode(new DbStreamState())))
.withData(Jsons.jsonNode(new DbState()));

final StateManager stateManager = StateManagerFactory.createStateManager(List.of(airbyteStateMessage), catalog, STREAM_STATE_TYPE);
final StateManager stateManager = StateManagerFactory.createStateManager(AirbyteStateType.STREAM, List.of(airbyteStateMessage), catalog);

Assertions.assertNotNull(stateManager);
Assertions.assertEquals(StreamStateManager.class, stateManager.getClass());
Expand Down

0 comments on commit 86891b9

Please sign in to comment.