Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add message translation to GetSpec #18130

Merged
merged 12 commits into from
Oct 21, 2022
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class AirbyteMessageMigrator {

private final List<AirbyteMessageMigration<?, ?>> migrationsToRegister;
private final SortedMap<String, AirbyteMessageMigration<?, ?>> migrations = new TreeMap<>();
private String mostRecentVersion = "";
private String mostRecentMajorVersion = "";

public AirbyteMessageMigrator(List<AirbyteMessageMigration<?, ?>> migrations) {
migrationsToRegister = migrations;
Expand All @@ -47,7 +47,7 @@ public void initialize() {
* required migrations
*/
public <PreviousVersion, CurrentVersion> PreviousVersion downgrade(final CurrentVersion message, final Version target) {
if (target.getMajorVersion().equals(mostRecentVersion)) {
if (target.getMajorVersion().equals(mostRecentMajorVersion)) {
return (PreviousVersion) message;
}

Expand All @@ -64,7 +64,7 @@ public <PreviousVersion, CurrentVersion> PreviousVersion downgrade(final Current
* migrations
*/
public <PreviousVersion, CurrentVersion> CurrentVersion upgrade(final PreviousVersion message, final Version source) {
if (source.getMajorVersion().equals(mostRecentVersion)) {
if (source.getMajorVersion().equals(mostRecentMajorVersion)) {
return (CurrentVersion) message;
}

Expand All @@ -75,6 +75,10 @@ public <PreviousVersion, CurrentVersion> CurrentVersion upgrade(final PreviousVe
return (CurrentVersion) result;
}

/**
 * Exposes the most recent major version known to this migrator as a {@link Version}.
 *
 * Only the major component is tracked, so the two remaining components are always "0".
 *
 * @return a Version built from mostRecentMajorVersion, "0" and "0"
 */
public Version getMostRecentVersion() {
  final String zero = "0";
  return new Version(mostRecentMajorVersion, zero, zero);
}

private Collection<AirbyteMessageMigration<?, ?>> selectMigrations(final Version version) {
final Collection<AirbyteMessageMigration<?, ?>> results = migrations.tailMap(version.getMajorVersion()).values();
if (results.isEmpty()) {
Expand Down Expand Up @@ -107,8 +111,8 @@ void registerMigration(final AirbyteMessageMigration<?, ?> migration) {
final String key = migration.getPreviousVersion().getMajorVersion();
if (!migrations.containsKey(key)) {
migrations.put(key, migration);
if (migration.getCurrentVersion().getMajorVersion().compareTo(mostRecentVersion) > 0) {
mostRecentVersion = migration.getCurrentVersion().getMajorVersion();
if (migration.getCurrentVersion().getMajorVersion().compareTo(mostRecentMajorVersion) > 0) {
mostRecentMajorVersion = migration.getCurrentVersion().getMajorVersion();
}
} else {
throw new RuntimeException("Trying to register a duplicated migration " + migration.getClass().getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,8 @@ public <T> AirbyteMessageVersionedMigrator<T> getVersionedMigrator(final Version
return new AirbyteMessageVersionedMigrator<>(this.migrator, version);
}

/**
 * Returns the most recent version reported by the underlying migrator.
 *
 * @return whatever {@code migrator.getMostRecentVersion()} returns
 */
public Version getMostRecentVersion() {
  final Version mostRecent = this.migrator.getMostRecentVersion();
  return mostRecent;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.airbyte.workers.internal;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
import io.airbyte.commons.json.Jsons;
import io.airbyte.commons.logging.MdcScope;
import io.airbyte.commons.protocol.AirbyteMessageSerDeProvider;
Expand All @@ -13,7 +14,12 @@
import io.airbyte.commons.protocol.serde.AirbyteMessageDeserializer;
import io.airbyte.commons.version.Version;
import io.airbyte.protocol.models.AirbyteMessage;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Stream;
import lombok.SneakyThrows;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -26,10 +32,23 @@
public class VersionedAirbyteStreamFactory<T> extends DefaultAirbyteStreamFactory {

private static final Logger LOGGER = LoggerFactory.getLogger(VersionedAirbyteStreamFactory.class);
private static final Version fallbackVersion = new Version("0.2.0");

private final AirbyteMessageDeserializer<T> deserializer;
private final AirbyteMessageVersionedMigrator<T> migrator;
private final Version protocolVersion;
// Buffer size to use when detecting the protocol version.
// Given that BufferedReader::reset fails once we have read past the marked read-ahead limit, this
// buffer has to be big enough to contain our longest spec and whatever messages get emitted before
// the SPEC.
private static final int BUFFER_READ_AHEAD_LIMIT = 32000;
private static final int MESSAGES_LOOK_AHEAD_FOR_DETECTION = 10;
private static final String TYPE_FIELD_NAME = "type";

private final AirbyteMessageSerDeProvider serDeProvider;
private final AirbyteMessageVersionedMigratorFactory migratorFactory;
private AirbyteMessageDeserializer<T> deserializer;
private AirbyteMessageVersionedMigrator<T> migrator;
private Version protocolVersion;

private boolean shouldDetectVersion = false;

public VersionedAirbyteStreamFactory(final AirbyteMessageSerDeProvider serDeProvider,
final AirbyteMessageVersionedMigratorFactory migratorFactory,
Expand All @@ -43,6 +62,90 @@ public VersionedAirbyteStreamFactory(final AirbyteMessageSerDeProvider serDeProv
final MdcScope.Builder containerLogMdcBuilder) {
// TODO AirbyteProtocolPredicate needs to be updated to be protocol version aware
super(new AirbyteProtocolPredicate(), LOGGER, containerLogMdcBuilder);
Preconditions.checkNotNull(protocolVersion);
this.serDeProvider = serDeProvider;
this.migratorFactory = migratorFactory;
this.initializeForProtocolVersion(protocolVersion);
}

/**
* Create the AirbyteMessage stream.
*
* If detectVersion is set to true, it will decide which protocol version to use from the content of
* the stream rather than the one passed from the constructor.
*/
@SneakyThrows
@Override
public Stream<AirbyteMessage> create(final BufferedReader bufferedReader) {
if (shouldDetectVersion) {
final Optional<Version> versionMaybe = detectVersion(bufferedReader);
if (versionMaybe.isPresent()) {
logger.info("Detected Protocol Version {}", versionMaybe.get().serialize());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: why use serialized here instead of toString?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It comes from our implementation: serialize returns a traditional "<major>.<minor>.<patch>" string, while toString is a pretty print of the class.

initializeForProtocolVersion(versionMaybe.get());
} else {
// No version found, use the default as a fallback
logger.info("Unable to detect Protocol Version, assuming protocol version {}", fallbackVersion.serialize());
initializeForProtocolVersion(fallbackVersion);
}
}
return super.create(bufferedReader);
}

/**
* Attempt to detect the version by scanning the stream
*
* Using the BufferedReader reset/mark feature to get a look-ahead. We will attempt to find the
* first SPEC message and decide on a protocol version from this message.
*
* @param bufferedReader the stream to read
* @return The Version if found
* @throws IOException if reading from or resetting the buffered reader fails
*/
private Optional<Version> detectVersion(final BufferedReader bufferedReader) throws IOException {
// The buffer size needs to be big enough to contain everything we need for the detection.
// Otherwise, the reset will fail.
bufferedReader.mark(BUFFER_READ_AHEAD_LIMIT);
try {
// Cap detection to the first 10 messages. When doing the protocol detection, we expect the SPEC
// message to show up early in the stream. Ideally it should be first message however we do not
// enforce this constraint currently so connectors may send LOG messages before.
for (int i = 0; i < MESSAGES_LOOK_AHEAD_FOR_DETECTION; ++i) {
final String line = bufferedReader.readLine();
final Optional<JsonNode> jsonOpt = Jsons.tryDeserialize(line);
if (jsonOpt.isPresent()) {
final JsonNode json = jsonOpt.get();
if (isSpecMessage(json)) {
final JsonNode protocolVersionNode = json.at("/spec/protocol_version");
bufferedReader.reset();
return Optional.ofNullable(protocolVersionNode).filter(Predicate.not(JsonNode::isMissingNode)).map(node -> new Version(node.asText()));
}
}
}
bufferedReader.reset();
return Optional.empty();
} catch (IOException e) {
logger.warn(
"Protocol version detection failed, it is likely than the connector sent more than {}B without an complete SPEC message." +
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the B after {} a typo, or is it meant to represent bytes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is meant for bytes.

" A SPEC message that is too long could be the root cause here.",
BUFFER_READ_AHEAD_LIMIT);
throw e;
}
}

/**
 * Tells whether a parsed JSON message is a SPEC message.
 *
 * @param json the parsed message to inspect
 * @return true when the node has a TYPE_FIELD_NAME field whose text equals "spec", ignoring case
 */
private boolean isSpecMessage(final JsonNode json) {
  if (!json.has(TYPE_FIELD_NAME)) {
    return false;
  }
  final String messageType = json.get(TYPE_FIELD_NAME).asText();
  return "spec".equalsIgnoreCase(messageType);
}

/**
 * Turns protocol version detection on or off for subsequent calls to create.
 *
 * @param detectVersion true to detect the protocol version from the stream content
 * @return the value that was just stored
 */
public boolean setDetectVersion(final boolean detectVersion) {
  this.shouldDetectVersion = detectVersion;
  return detectVersion;
}

public VersionedAirbyteStreamFactory<T> withDetectVersion(final boolean detectVersion) {
setDetectVersion(detectVersion);
return this;
}

final protected void initializeForProtocolVersion(final Version protocolVersion) {
this.deserializer = (AirbyteMessageDeserializer<T>) serDeProvider.getDeserializer(protocolVersion).orElseThrow();
this.migrator = migratorFactory.getVersionedMigrator(protocolVersion);
this.protocolVersion = protocolVersion;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.workers.internal;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;

import io.airbyte.commons.protocol.AirbyteMessageMigrator;
import io.airbyte.commons.protocol.AirbyteMessageSerDeProvider;
import io.airbyte.commons.protocol.AirbyteMessageVersionedMigratorFactory;
import io.airbyte.commons.protocol.migrations.AirbyteMessageMigrationV0;
import io.airbyte.commons.protocol.serde.AirbyteMessageV0Deserializer;
import io.airbyte.commons.protocol.serde.AirbyteMessageV0Serializer;
import io.airbyte.commons.version.Version;
import io.airbyte.protocol.models.AirbyteMessage;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.platform.commons.util.ClassLoaderUtils;

/**
 * Tests for {@link VersionedAirbyteStreamFactory}, covering stream creation with the version given
 * to the constructor and with protocol version detection enabled.
 */
class VersionedAirbyteStreamFactoryTest {

  AirbyteMessageSerDeProvider serDeProvider;
  AirbyteMessageVersionedMigratorFactory migratorFactory;

  // Version the factory is expected to use when detection finds no protocol version.
  final static Version defaultVersion = new Version("0.2.0");

  /**
   * Builds spied V0 serde and migrator factories so each test can verify which versions were
   * requested from them.
   */
  @BeforeEach
  void beforeEach() {
    serDeProvider = spy(new AirbyteMessageSerDeProvider(
        List.of(new AirbyteMessageV0Deserializer()),
        List.of(new AirbyteMessageV0Serializer())));
    serDeProvider.initialize();
    final AirbyteMessageMigrator migrator = new AirbyteMessageMigrator(
        List.of(new AirbyteMessageMigrationV0()));
    migrator.initialize();
    migratorFactory = spy(new AirbyteMessageVersionedMigratorFactory(migrator));
  }

  /** Without detection, the factory must be initialized with the constructor-provided version. */
  @Test
  void testCreate() {
    final Version initialVersion = new Version("0.1.2");
    final VersionedAirbyteStreamFactory<?> streamFactory = new VersionedAirbyteStreamFactory<>(serDeProvider, migratorFactory, initialVersion);

    final BufferedReader bufferedReader = new BufferedReader(new StringReader(""));
    streamFactory.create(bufferedReader);

    verify(serDeProvider).getDeserializer(initialVersion);
    verify(migratorFactory).getVersionedMigrator(initialVersion);
  }

  /**
   * With detection enabled, the deserializer must be requested for the initial version and then
   * for the version found in the stream (the fixture is expected to declare 0.5.9).
   */
  @Test
  void testCreateWithVersionDetection() {
    final Version initialVersion = new Version("0.0.0");
    final VersionedAirbyteStreamFactory<?> streamFactory = new VersionedAirbyteStreamFactory<>(serDeProvider, migratorFactory, initialVersion)
        .withDetectVersion(true);

    final BufferedReader bufferedReader =
        getBuffereredReader("version-detection/logs-with-version.jsonl");
    final Stream<AirbyteMessage> stream = streamFactory.create(bufferedReader);

    final long messageCount = stream.toList().size();
    verify(serDeProvider).getDeserializer(initialVersion);
    verify(serDeProvider).getDeserializer(new Version("0.5.9"));
    assertEquals(1, messageCount);
  }

  /**
   * With detection enabled and a fixture that carries no protocol version, the factory must fall
   * back to the default version.
   */
  @Test
  void testCreateWithVersionDetectionFallback() {
    final Version initialVersion = new Version("0.0.6");
    final VersionedAirbyteStreamFactory<?> streamFactory = new VersionedAirbyteStreamFactory<>(serDeProvider, migratorFactory, initialVersion)
        .withDetectVersion(true);

    final BufferedReader bufferedReader =
        getBuffereredReader("version-detection/logs-without-version.jsonl");
    final Stream<AirbyteMessage> stream = streamFactory.create(bufferedReader);

    final long messageCount = stream.toList().size();
    verify(serDeProvider).getDeserializer(initialVersion);
    verify(serDeProvider).getDeserializer(defaultVersion);
    assertEquals(1, messageCount);
  }

  /**
   * With detection enabled and a fixture that has no SPEC message, the factory must fall back to
   * the default version.
   */
  @Test
  void testCreateWithVersionDetectionWithoutSpecMessage() {
    final Version initialVersion = new Version("0.0.1");
    final VersionedAirbyteStreamFactory<?> streamFactory = new VersionedAirbyteStreamFactory<>(serDeProvider, migratorFactory, initialVersion)
        .withDetectVersion(true);

    final BufferedReader bufferedReader =
        getBuffereredReader("version-detection/logs-without-spec-message.jsonl");
    final Stream<AirbyteMessage> stream = streamFactory.create(bufferedReader);

    final long messageCount = stream.toList().size();
    verify(serDeProvider).getDeserializer(initialVersion);
    verify(serDeProvider).getDeserializer(defaultVersion);
    assertEquals(2, messageCount);
  }

  /**
   * Opens a classpath resource as a UTF-8 reader.
   *
   * @param resourceFile classpath-relative path of the fixture to open
   * @return a reader over the fixture content
   * @throws IllegalArgumentException if the resource cannot be found
   */
  BufferedReader getBuffereredReader(final String resourceFile) {
    final var resourceStream = ClassLoaderUtils.getDefaultClassLoader().getResourceAsStream(resourceFile);
    if (resourceStream == null) {
      // getResourceAsStream signals a missing resource with null; fail with the path instead of a bare NPE.
      throw new IllegalArgumentException("Test resource not found on the classpath: " + resourceFile);
    }
    // Read fixtures as UTF-8 explicitly so results do not depend on the platform default charset.
    return new BufferedReader(new InputStreamReader(resourceStream, StandardCharsets.UTF_8));
  }

}
Loading