Skip to content

Commit

Permalink
Add message translation to GetSpec (airbytehq#18130)
Browse files Browse the repository at this point in the history
* Update SpecActivityImpl to build a VersionedStreamFactory

* Enable Protocol Version detection from Stream for SPEC

* Print log before the action for better debugging

* Fix buffer size for protocol detection

* Improve detectVersion error handling

* Extract constant

* Rename attribute for clarity
  • Loading branch information
gosusnp authored and jhammarstedt committed Oct 31, 2022
1 parent 7fe6d70 commit da7db29
Show file tree
Hide file tree
Showing 9 changed files with 284 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class AirbyteMessageMigrator {

private final List<AirbyteMessageMigration<?, ?>> migrationsToRegister;
private final SortedMap<String, AirbyteMessageMigration<?, ?>> migrations = new TreeMap<>();
private String mostRecentVersion = "";
private String mostRecentMajorVersion = "";

public AirbyteMessageMigrator(List<AirbyteMessageMigration<?, ?>> migrations) {
migrationsToRegister = migrations;
Expand All @@ -47,7 +47,7 @@ public void initialize() {
* required migrations
*/
public <PreviousVersion, CurrentVersion> PreviousVersion downgrade(final CurrentVersion message, final Version target) {
if (target.getMajorVersion().equals(mostRecentVersion)) {
if (target.getMajorVersion().equals(mostRecentMajorVersion)) {
return (PreviousVersion) message;
}

Expand All @@ -64,7 +64,7 @@ public <PreviousVersion, CurrentVersion> PreviousVersion downgrade(final Current
* migrations
*/
public <PreviousVersion, CurrentVersion> CurrentVersion upgrade(final PreviousVersion message, final Version source) {
if (source.getMajorVersion().equals(mostRecentVersion)) {
if (source.getMajorVersion().equals(mostRecentMajorVersion)) {
return (CurrentVersion) message;
}

Expand All @@ -75,6 +75,10 @@ public <PreviousVersion, CurrentVersion> CurrentVersion upgrade(final PreviousVe
return (CurrentVersion) result;
}

/**
 * Builds a {@link Version} from the highest major version seen among the registered migrations.
 *
 * Only the major component is tracked by this class; the two remaining components (presumably
 * minor and patch -- confirm against the Version constructor) are always reported as "0".
 *
 * @return the most recent version known to this migrator
 */
public Version getMostRecentVersion() {
  final String major = mostRecentMajorVersion;
  final Version mostRecent = new Version(major, "0", "0");
  return mostRecent;
}

private Collection<AirbyteMessageMigration<?, ?>> selectMigrations(final Version version) {
final Collection<AirbyteMessageMigration<?, ?>> results = migrations.tailMap(version.getMajorVersion()).values();
if (results.isEmpty()) {
Expand Down Expand Up @@ -107,8 +111,8 @@ void registerMigration(final AirbyteMessageMigration<?, ?> migration) {
final String key = migration.getPreviousVersion().getMajorVersion();
if (!migrations.containsKey(key)) {
migrations.put(key, migration);
if (migration.getCurrentVersion().getMajorVersion().compareTo(mostRecentVersion) > 0) {
mostRecentVersion = migration.getCurrentVersion().getMajorVersion();
if (migration.getCurrentVersion().getMajorVersion().compareTo(mostRecentMajorVersion) > 0) {
mostRecentMajorVersion = migration.getCurrentVersion().getMajorVersion();
}
} else {
throw new RuntimeException("Trying to register a duplicated migration " + migration.getClass().getName());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,8 @@ public <T> AirbyteMessageVersionedMigrator<T> getVersionedMigrator(final Version
return new AirbyteMessageVersionedMigrator<>(this.migrator, version);
}

/**
 * Exposes the most recent version reported by the underlying migrator.
 *
 * @return the value of the wrapped migrator's {@code getMostRecentVersion()}
 */
public Version getMostRecentVersion() {
  final Version mostRecent = this.migrator.getMostRecentVersion();
  return mostRecent;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.airbyte.workers.internal;

import com.fasterxml.jackson.databind.JsonNode;
import com.google.common.base.Preconditions;
import io.airbyte.commons.json.Jsons;
import io.airbyte.commons.logging.MdcScope;
import io.airbyte.commons.protocol.AirbyteMessageSerDeProvider;
Expand All @@ -13,7 +14,12 @@
import io.airbyte.commons.protocol.serde.AirbyteMessageDeserializer;
import io.airbyte.commons.version.Version;
import io.airbyte.protocol.models.AirbyteMessage;
import java.io.BufferedReader;
import java.io.IOException;
import java.util.Optional;
import java.util.function.Predicate;
import java.util.stream.Stream;
import lombok.SneakyThrows;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -26,10 +32,23 @@
public class VersionedAirbyteStreamFactory<T> extends DefaultAirbyteStreamFactory {

private static final Logger LOGGER = LoggerFactory.getLogger(VersionedAirbyteStreamFactory.class);
private static final Version fallbackVersion = new Version("0.2.0");

private final AirbyteMessageDeserializer<T> deserializer;
private final AirbyteMessageVersionedMigrator<T> migrator;
private final Version protocolVersion;
// Read-ahead limit to use when detecting the protocol version.
// BufferedReader::reset may fail once more characters than this limit have been read since the
// mark, so the limit has to be big enough to contain our longest spec and whatever messages get
// emitted before the SPEC.
private static final int BUFFER_READ_AHEAD_LIMIT = 32000;
private static final int MESSAGES_LOOK_AHEAD_FOR_DETECTION = 10;
private static final String TYPE_FIELD_NAME = "type";

private final AirbyteMessageSerDeProvider serDeProvider;
private final AirbyteMessageVersionedMigratorFactory migratorFactory;
private AirbyteMessageDeserializer<T> deserializer;
private AirbyteMessageVersionedMigrator<T> migrator;
private Version protocolVersion;

private boolean shouldDetectVersion = false;

public VersionedAirbyteStreamFactory(final AirbyteMessageSerDeProvider serDeProvider,
final AirbyteMessageVersionedMigratorFactory migratorFactory,
Expand All @@ -43,6 +62,90 @@ public VersionedAirbyteStreamFactory(final AirbyteMessageSerDeProvider serDeProv
final MdcScope.Builder containerLogMdcBuilder) {
// TODO AirbyteProtocolPredicate needs to be updated to be protocol version aware
super(new AirbyteProtocolPredicate(), LOGGER, containerLogMdcBuilder);
Preconditions.checkNotNull(protocolVersion);
this.serDeProvider = serDeProvider;
this.migratorFactory = migratorFactory;
this.initializeForProtocolVersion(protocolVersion);
}

/**
 * Create the AirbyteMessage stream.
 *
 * When version detection is enabled, the protocol version is decided from the content of the
 * stream (falling back to {@code fallbackVersion} when none is found) and the factory is
 * re-initialized for it before the stream is created. Otherwise the version passed to the
 * constructor stays in effect.
 *
 * @param bufferedReader the reader to build the message stream from
 * @return the stream produced by the parent factory
 */
@SneakyThrows
@Override
public Stream<AirbyteMessage> create(final BufferedReader bufferedReader) {
  if (!shouldDetectVersion) {
    // Detection disabled: keep whatever version the factory is already initialized for.
    return super.create(bufferedReader);
  }

  final Optional<Version> detected = detectVersion(bufferedReader);
  final Version versionToUse;
  if (detected.isEmpty()) {
    // No version found, use the default as a fallback
    logger.info("Unable to detect Protocol Version, assuming protocol version {}", fallbackVersion.serialize());
    versionToUse = fallbackVersion;
  } else {
    versionToUse = detected.get();
    logger.info("Detected Protocol Version {}", versionToUse.serialize());
  }
  initializeForProtocolVersion(versionToUse);
  return super.create(bufferedReader);
}

/**
 * Attempt to detect the version by scanning the stream.
 *
 * Uses the BufferedReader mark/reset feature to get a look-ahead: the reader is marked, up to
 * {@code MESSAGES_LOOK_AHEAD_FOR_DETECTION} lines are read looking for the first SPEC message, and
 * the reader is reset to the mark before returning so that the caller still sees the full stream.
 * Scanning stops early when the end of the stream is reached.
 *
 * @param bufferedReader the stream to read
 * @return the Version read from the first SPEC message, or empty if no SPEC message was found in
 *         the look-ahead window or if it carries no (or a null) protocol_version
 * @throws IOException if reading fails or if the reader cannot be reset, which happens when more
 *         than {@code BUFFER_READ_AHEAD_LIMIT} characters were read during detection
 */
private Optional<Version> detectVersion(final BufferedReader bufferedReader) throws IOException {
  // The read-ahead limit needs to be big enough to contain everything we need for the detection.
  // Otherwise, the reset will fail.
  bufferedReader.mark(BUFFER_READ_AHEAD_LIMIT);
  try {
    Optional<Version> detectedVersion = Optional.empty();
    // Cap detection to the first messages. When doing the protocol detection, we expect the SPEC
    // message to show up early in the stream. Ideally it should be the first message, however we do
    // not enforce this constraint currently so connectors may send LOG messages before.
    for (int i = 0; i < MESSAGES_LOOK_AHEAD_FOR_DETECTION; ++i) {
      final String line = bufferedReader.readLine();
      if (line == null) {
        // End of stream: there is nothing left to scan, so stop instead of handing null to the
        // JSON parser for every remaining iteration.
        break;
      }
      final Optional<JsonNode> jsonOpt = Jsons.tryDeserialize(line);
      if (jsonOpt.isPresent() && isSpecMessage(jsonOpt.get())) {
        final JsonNode protocolVersionNode = jsonOpt.get().at("/spec/protocol_version");
        // Only the first SPEC message is considered. A missing field or an explicit JSON null both
        // mean "no version declared"; without the null check the text "null" would be parsed as a
        // version.
        detectedVersion = Optional.ofNullable(protocolVersionNode)
            .filter(Predicate.not(JsonNode::isMissingNode))
            .filter(Predicate.not(JsonNode::isNull))
            .map(node -> new Version(node.asText()));
        break;
      }
    }
    // Rewind so the caller consumes the stream from the marked position.
    bufferedReader.reset();
    return detectedVersion;
  } catch (IOException e) {
    logger.warn(
        "Protocol version detection failed, it is likely than the connector sent more than {}B without an complete SPEC message." +
            " A SPEC message that is too long could be the root cause here.",
        BUFFER_READ_AHEAD_LIMIT);
    throw e;
  }
}

/**
 * Tells whether a JSON message is a SPEC message.
 *
 * @param json the parsed message
 * @return true when the message has a {@code type} field whose text equals "spec", ignoring case
 */
private boolean isSpecMessage(final JsonNode json) {
  if (!json.has(TYPE_FIELD_NAME)) {
    return false;
  }
  final String messageType = json.get(TYPE_FIELD_NAME).asText();
  return "spec".equalsIgnoreCase(messageType);
}

/**
 * Enables or disables protocol version detection from the stream content.
 *
 * @param detectVersion true to detect the version when the stream is created
 * @return the value that was just stored
 */
public boolean setDetectVersion(final boolean detectVersion) {
  this.shouldDetectVersion = detectVersion;
  return this.shouldDetectVersion;
}

/**
 * Fluent variant of {@code setDetectVersion}.
 *
 * @param detectVersion true to detect the version when the stream is created
 * @return this factory, to allow call chaining
 */
public VersionedAirbyteStreamFactory<T> withDetectVersion(final boolean detectVersion) {
  this.setDetectVersion(detectVersion);
  return this;
}

final protected void initializeForProtocolVersion(final Version protocolVersion) {
this.deserializer = (AirbyteMessageDeserializer<T>) serDeProvider.getDeserializer(protocolVersion).orElseThrow();
this.migrator = migratorFactory.getVersionedMigrator(protocolVersion);
this.protocolVersion = protocolVersion;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.workers.internal;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;

import io.airbyte.commons.protocol.AirbyteMessageMigrator;
import io.airbyte.commons.protocol.AirbyteMessageSerDeProvider;
import io.airbyte.commons.protocol.AirbyteMessageVersionedMigratorFactory;
import io.airbyte.commons.protocol.migrations.AirbyteMessageMigrationV0;
import io.airbyte.commons.protocol.serde.AirbyteMessageV0Deserializer;
import io.airbyte.commons.protocol.serde.AirbyteMessageV0Serializer;
import io.airbyte.commons.version.Version;
import io.airbyte.protocol.models.AirbyteMessage;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;
import java.util.stream.Stream;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.platform.commons.util.ClassLoaderUtils;

/**
 * Tests for {@link VersionedAirbyteStreamFactory}, covering stream creation with the version given
 * to the constructor and with the version detected from the stream content.
 */
class VersionedAirbyteStreamFactoryTest {

  // Version the factory is expected to fall back to when detection finds nothing.
  static final Version defaultVersion = new Version("0.2.0");

  // Fixtures are decoded with a fixed charset so the tests do not depend on the platform default.
  private static final Charset RESOURCE_CHARSET = StandardCharsets.UTF_8;

  AirbyteMessageSerDeProvider serDeProvider;
  AirbyteMessageVersionedMigratorFactory migratorFactory;

  /**
   * Builds spied collaborators so each test can verify which versions were requested from them.
   */
  @BeforeEach
  void beforeEach() {
    serDeProvider = spy(new AirbyteMessageSerDeProvider(
        List.of(new AirbyteMessageV0Deserializer()),
        List.of(new AirbyteMessageV0Serializer())));
    serDeProvider.initialize();
    final AirbyteMessageMigrator migrator = new AirbyteMessageMigrator(
        List.of(new AirbyteMessageMigrationV0()));
    migrator.initialize();
    migratorFactory = spy(new AirbyteMessageVersionedMigratorFactory(migrator));
  }

  /**
   * Without detection, only the version passed to the constructor is used.
   */
  @Test
  void testCreate() {
    final Version initialVersion = new Version("0.1.2");
    final VersionedAirbyteStreamFactory<?> streamFactory = new VersionedAirbyteStreamFactory<>(serDeProvider, migratorFactory, initialVersion);

    final BufferedReader bufferedReader = new BufferedReader(new StringReader(""));
    streamFactory.create(bufferedReader);

    verify(serDeProvider).getDeserializer(initialVersion);
    verify(migratorFactory).getVersionedMigrator(initialVersion);
  }

  /**
   * With detection enabled, the deserializer is looked up once for the constructor version and once
   * for the version found in the fixture (0.5.9).
   */
  @Test
  void testCreateWithVersionDetection() {
    final Version initialVersion = new Version("0.0.0");
    final VersionedAirbyteStreamFactory<?> streamFactory = new VersionedAirbyteStreamFactory<>(serDeProvider, migratorFactory, initialVersion)
        .withDetectVersion(true);

    final BufferedReader bufferedReader =
        getBuffereredReader("version-detection/logs-with-version.jsonl");
    final Stream<AirbyteMessage> stream = streamFactory.create(bufferedReader);

    final long messageCount = stream.toList().size();
    verify(serDeProvider).getDeserializer(initialVersion);
    verify(serDeProvider).getDeserializer(new Version("0.5.9"));
    assertEquals(1, messageCount);
  }

  /**
   * With detection enabled and a fixture that declares no version, the default version is used.
   */
  @Test
  void testCreateWithVersionDetectionFallback() {
    final Version initialVersion = new Version("0.0.6");
    final VersionedAirbyteStreamFactory<?> streamFactory = new VersionedAirbyteStreamFactory<>(serDeProvider, migratorFactory, initialVersion)
        .withDetectVersion(true);

    final BufferedReader bufferedReader =
        getBuffereredReader("version-detection/logs-without-version.jsonl");
    final Stream<AirbyteMessage> stream = streamFactory.create(bufferedReader);

    final long messageCount = stream.toList().size();
    verify(serDeProvider).getDeserializer(initialVersion);
    verify(serDeProvider).getDeserializer(defaultVersion);
    assertEquals(1, messageCount);
  }

  /**
   * With detection enabled and a fixture without any SPEC message, the default version is used and
   * the messages of the stream are still all delivered.
   */
  @Test
  void testCreateWithVersionDetectionWithoutSpecMessage() {
    final Version initialVersion = new Version("0.0.1");
    final VersionedAirbyteStreamFactory<?> streamFactory = new VersionedAirbyteStreamFactory<>(serDeProvider, migratorFactory, initialVersion)
        .withDetectVersion(true);

    final BufferedReader bufferedReader =
        getBuffereredReader("version-detection/logs-without-spec-message.jsonl");
    final Stream<AirbyteMessage> stream = streamFactory.create(bufferedReader);

    final long messageCount = stream.toList().size();
    verify(serDeProvider).getDeserializer(initialVersion);
    verify(serDeProvider).getDeserializer(defaultVersion);
    assertEquals(2, messageCount);
  }

  /**
   * Opens a classpath resource as a BufferedReader.
   *
   * @param resourceFile path of the resource, relative to the classpath root
   * @return a reader over the resource content, decoded as UTF-8
   * @throws NullPointerException naming the resource when it cannot be found on the classpath
   */
  BufferedReader getBuffereredReader(final String resourceFile) {
    return new BufferedReader(
        new InputStreamReader(
            // getResourceAsStream returns null for a missing resource; fail with the resource name
            // rather than a message-less NPE from inside InputStreamReader.
            Objects.requireNonNull(
                ClassLoaderUtils.getDefaultClassLoader().getResourceAsStream(resourceFile),
                () -> "Test resource not found on the classpath: " + resourceFile),
            RESOURCE_CHARSET));
  }

}
Loading

0 comments on commit da7db29

Please sign in to comment.