Skip to content

Commit

Permalink
Adjust to protocol changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jdpgrailsdev committed Jun 14, 2022
1 parent 649effc commit 6f75fd2
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -354,12 +354,15 @@ void testDiscoverWithMultipleSchemas() throws Exception {
final AirbyteCatalog actual = source.discover(config);

final AirbyteCatalog expected = getCatalog(getDefaultNamespace());
expected.getStreams().add(CatalogHelpers
final List<AirbyteStream> catalogStreams = new ArrayList<>();
catalogStreams.addAll(expected.getStreams());
catalogStreams.add(CatalogHelpers
.createAirbyteStream(TABLE_NAME,
SCHEMA_NAME2,
Field.of(COL_ID, JsonSchemaType.STRING),
Field.of(COL_NAME, JsonSchemaType.STRING))
.withSupportedSyncModes(List.of(SyncMode.FULL_REFRESH, SyncMode.INCREMENTAL)));
expected.setStreams(catalogStreams);
// sort streams by name so that we are comparing lists with the same order.
final Comparator<AirbyteStream> schemaTableCompare = Comparator.comparing(stream -> stream.getNamespace() + "." + stream.getName());
expected.getStreams().sort(schemaTableCompare);
Expand Down Expand Up @@ -661,9 +664,7 @@ protected List<AirbyteMessage> getExpectedAirbyteMessagesSecondSync(final String
.withStreamNamespace(namespace)
.withCursorField(List.of(COL_ID))
.withCursor("5");
expectedMessages.add(new AirbyteMessage()
.withType(Type.STATE)
.withState(Jsons.object(createState(List.of(state)), AirbyteStateMessage.class)));
expectedMessages.addAll(createExpectedTestMessages(List.of(state)));
return expectedMessages;
}

Expand Down Expand Up @@ -734,9 +735,9 @@ void testReadMultipleTablesIncrementally() throws Exception {
.withCursor("3"));

final List<AirbyteMessage> expectedMessagesFirstSync = new ArrayList<>(getTestMessages());
expectedMessagesFirstSync.add(createExpectedTestMessage(expectedStateStreams1));
expectedMessagesFirstSync.addAll(createExpectedTestMessages(expectedStateStreams1));
expectedMessagesFirstSync.addAll(secondStreamExpectedMessages);
expectedMessagesFirstSync.add(createExpectedTestMessage(expectedStateStreams2));
expectedMessagesFirstSync.addAll(createExpectedTestMessages(expectedStateStreams2));

setEmittedAtToNull(actualMessagesFirstSync);

Expand Down Expand Up @@ -803,7 +804,7 @@ private void incrementalCursorCheck(
.withCursor(initialCursorValue);

final List<AirbyteMessage> actualMessages = MoreIterators
.toList(source.read(config, configuredCatalog, createState(List.of(dbStreamState))));
.toList(source.read(config, configuredCatalog, Jsons.jsonNode(createState(List.of(dbStreamState)))));

setEmittedAtToNull(actualMessages);

Expand All @@ -814,7 +815,7 @@ private void incrementalCursorCheck(
.withCursorField(List.of(cursorField))
.withCursor(endCursorValue));
final List<AirbyteMessage> expectedMessages = new ArrayList<>(expectedRecordMessages);
expectedMessages.add(createExpectedTestMessage(expectedStreams));
expectedMessages.addAll(createExpectedTestMessages(expectedStreams));

assertEquals(actualMessages.size(), expectedMessages.size());
assertEquals(actualMessages, expectedMessages);
Expand Down Expand Up @@ -883,10 +884,30 @@ protected List<AirbyteMessage> getTestMessages() {
COL_UPDATED_AT, "2006-10-19T00:00:00Z")))));
}

protected AirbyteMessage createExpectedTestMessage(final List<DbStreamState> states) {
return new AirbyteMessage()
.withType(Type.STATE)
.withState(Jsons.object(createState(states), AirbyteStateMessage.class));
protected List<AirbyteMessage> createExpectedTestMessages(final List<DbStreamState> states) {
return supportsPerStream()
? states.stream()
.map(s -> new AirbyteMessage().withType(Type.STATE)
.withState(new AirbyteStateMessage().withStateType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState()
.withStreamDescriptor(new StreamDescriptor().withNamespace(s.getStreamNamespace()).withName(s.getStreamName()))
.withStreamState(Jsons.jsonNode(s)))))
.collect(
Collectors.toList())
: List.of(new AirbyteMessage().withType(Type.STATE).withState(new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY)
.withData(Jsons.jsonNode(new DbState().withCdc(false).withStreams(states)))));
}

protected List<AirbyteStateMessage> createState(final List<DbStreamState> states) {
return supportsPerStream()
? states.stream()
.map(s -> new AirbyteStateMessage().withStateType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState()
.withStreamDescriptor(new StreamDescriptor().withNamespace(s.getStreamNamespace()).withName(s.getStreamName()))
.withStreamState(Jsons.jsonNode(s))))
.collect(
Collectors.toList())
: List.of(new AirbyteStateMessage().withStateType(AirbyteStateType.LEGACY).withData(Jsons.jsonNode(new DbState().withStreams(states))));
}

protected ConfiguredAirbyteStream createTableWithSpaces() throws SQLException {
Expand Down Expand Up @@ -1011,28 +1032,6 @@ protected JsonNode createEmptyState(final String streamName, final String stream
}
}

/**
* Creates state with the provided stream(s).
*
* @param streams A list of streams.
* @return A {@link JsonNode} representation of the state with the provided stream state.
*/
protected JsonNode createState(final List<DbStreamState> streams) {
if (supportsPerStream()) {
final List<AirbyteStateMessage> messages = streams.stream()
.map(s -> new AirbyteStateMessage().withStateType(AirbyteStateType.STREAM)
.withStream(new AirbyteStreamState()
.withStreamDescriptor(new StreamDescriptor().withName(s.getStreamName()).withNamespace(s.getStreamNamespace()))
.withStreamState(Jsons.jsonNode(s))))
.collect(Collectors.toList());
return Jsons.jsonNode(messages);
} else {
final DbState dbState = new DbState()
.withStreams(streams.stream().collect(Collectors.toList()));
return Jsons.jsonNode(dbState);
}
}

/**
* Extracts the state component from the provided {@link AirbyteMessage} based on the value returned
* by {@link #supportsPerStream()}.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import io.airbyte.protocol.models.AirbyteCatalog;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteRecordMessage;
import io.airbyte.protocol.models.AirbyteStateMessage;
import io.airbyte.protocol.models.CatalogHelpers;
import io.airbyte.protocol.models.ConfiguredAirbyteStream;
import io.airbyte.protocol.models.ConnectorSpecification;
Expand Down Expand Up @@ -435,9 +434,7 @@ protected List<AirbyteMessage> getExpectedAirbyteMessagesSecondSync(final String
.withStreamNamespace(namespace)
.withCursorField(ImmutableList.of(COL_ID))
.withCursor("5");
expectedMessages.add(new AirbyteMessage()
.withType(AirbyteMessage.Type.STATE)
.withState(Jsons.object(createState(List.of(state)), AirbyteStateMessage.class)));
expectedMessages.addAll(createExpectedTestMessages(List.of(state)));
return expectedMessages;
}

Expand Down

0 comments on commit 6f75fd2

Please sign in to comment.