Skip to content

Commit

Permalink
implement column filtering in the replication workflow (#20369)
Browse files Browse the repository at this point in the history
* implement column filtering in the replication workflow

* fixes to column selection in replication workflow

* add a basic acceptance test for column selection

* make CI acceptance tests run with new field selection flag enabled

* fix format

* readability improvements around columns selection tests and other small fixes
  • Loading branch information
mfsiega-airbyte authored Dec 13, 2022
1 parent f76833e commit 0fac8c8
Show file tree
Hide file tree
Showing 9 changed files with 253 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
import static io.airbyte.metrics.lib.ApmTraceConstants.WORKER_OPERATION_NAME;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import datadog.trace.api.Trace;
import io.airbyte.commons.io.LineGobbler;
import io.airbyte.config.FailureReason;
Expand All @@ -28,6 +30,7 @@
import io.airbyte.protocol.models.AirbyteMessage.Type;
import io.airbyte.protocol.models.AirbyteRecordMessage;
import io.airbyte.protocol.models.AirbyteStreamNameNamespacePair;
import io.airbyte.protocol.models.ConfiguredAirbyteCatalog;
import io.airbyte.workers.RecordSchemaValidator;
import io.airbyte.workers.WorkerMetricReporter;
import io.airbyte.workers.WorkerUtils;
Expand All @@ -41,6 +44,7 @@
import io.airbyte.workers.internal.book_keeping.MessageTracker;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -94,6 +98,7 @@ public class DefaultReplicationWorker implements ReplicationWorker {
private final AtomicBoolean hasFailed;
private final RecordSchemaValidator recordSchemaValidator;
private final WorkerMetricReporter metricReporter;
private final boolean fieldSelectionEnabled;

public DefaultReplicationWorker(final String jobId,
final int attempt,
Expand All @@ -102,7 +107,8 @@ public DefaultReplicationWorker(final String jobId,
final AirbyteDestination destination,
final MessageTracker messageTracker,
final RecordSchemaValidator recordSchemaValidator,
final WorkerMetricReporter metricReporter) {
final WorkerMetricReporter metricReporter,
final boolean fieldSelectionEnabled) {
this.jobId = jobId;
this.attempt = attempt;
this.source = source;
Expand All @@ -112,6 +118,7 @@ public DefaultReplicationWorker(final String jobId,
this.executors = Executors.newFixedThreadPool(2);
this.recordSchemaValidator = recordSchemaValidator;
this.metricReporter = metricReporter;
this.fieldSelectionEnabled = fieldSelectionEnabled;

this.cancelled = new AtomicBoolean(false);
this.hasFailed = new AtomicBoolean(false);
Expand Down Expand Up @@ -198,8 +205,18 @@ private void replicate(final Path jobRoot,
});

final CompletableFuture<?> readSrcAndWriteDstThread = CompletableFuture.runAsync(
readFromSrcAndWriteToDstRunnable(source, destination, cancelled, mapper, messageTracker, mdc, recordSchemaValidator, metricReporter,
timeTracker),
readFromSrcAndWriteToDstRunnable(
source,
destination,
sourceConfig.getCatalog(),
cancelled,
mapper,
messageTracker,
mdc,
recordSchemaValidator,
metricReporter,
timeTracker,
fieldSelectionEnabled),
executors)
.whenComplete((msg, ex) -> {
if (ex != null) {
Expand Down Expand Up @@ -279,18 +296,24 @@ private static Runnable readFromDstRunnable(final AirbyteDestination destination
@SuppressWarnings("PMD.AvoidInstanceofChecksInCatchClause")
private static Runnable readFromSrcAndWriteToDstRunnable(final AirbyteSource source,
final AirbyteDestination destination,
final ConfiguredAirbyteCatalog catalog,
final AtomicBoolean cancelled,
final AirbyteMapper mapper,
final MessageTracker messageTracker,
final Map<String, String> mdc,
final RecordSchemaValidator recordSchemaValidator,
final WorkerMetricReporter metricReporter,
final ThreadedTimeTracker timeHolder) {
final ThreadedTimeTracker timeHolder,
final boolean fieldSelectionEnabled) {
return () -> {
MDC.setContextMap(mdc);
LOGGER.info("Replication thread started.");
Long recordsRead = 0L;
final Map<AirbyteStreamNameNamespacePair, ImmutablePair<Set<String>, Integer>> validationErrors = new HashMap<>();
final Map<AirbyteStreamNameNamespacePair, List<String>> streamToSelectedFields = new HashMap<>();
if (fieldSelectionEnabled) {
populatedStreamToSelectedFields(catalog, streamToSelectedFields);
}
try {
while (!cancelled.get() && !source.isFinished()) {
final Optional<AirbyteMessage> messageOptional;
Expand All @@ -302,6 +325,9 @@ private static Runnable readFromSrcAndWriteToDstRunnable(final AirbyteSource sou

if (messageOptional.isPresent()) {
final AirbyteMessage airbyteMessage = messageOptional.get();
if (fieldSelectionEnabled) {
filterSelectedFields(streamToSelectedFields, airbyteMessage);
}
validateSchema(recordSchemaValidator, validationErrors, airbyteMessage);
final AirbyteMessage message = mapper.mapMessage(airbyteMessage);

Expand Down Expand Up @@ -549,6 +575,47 @@ private static void validateSchema(final RecordSchemaValidator recordSchemaValid
}
}

/**
* Generates a map from stream -> the explicit list of fields included for that stream, according to
* the configured catalog. Since the configured catalog only includes the selected fields, this lets
* us filter records to only the fields explicitly requested.
*
* @param catalog
* @param streamToSelectedFields
*/
private static void populatedStreamToSelectedFields(final ConfiguredAirbyteCatalog catalog,
final Map<AirbyteStreamNameNamespacePair, List<String>> streamToSelectedFields) {
for (final var s : catalog.getStreams()) {
final List<String> selectedFields = new ArrayList<>();
final JsonNode propertiesNode = s.getStream().getJsonSchema().findPath("properties");
if (propertiesNode.isObject()) {
propertiesNode.fieldNames().forEachRemaining((fieldName) -> selectedFields.add(fieldName));
} else {
throw new RuntimeException("No properties node in stream schema");
}
streamToSelectedFields.put(AirbyteStreamNameNamespacePair.fromConfiguredAirbyteSteam(s), selectedFields);
}
}

private static void filterSelectedFields(final Map<AirbyteStreamNameNamespacePair, List<String>> streamToSelectedFields,
final AirbyteMessage airbyteMessage) {
final AirbyteRecordMessage record = airbyteMessage.getRecord();

if (record == null) {
// This isn't a record message, so we don't need to do any filtering.
return;
}

final AirbyteStreamNameNamespacePair messageStream = AirbyteStreamNameNamespacePair.fromRecordMessage(record);
final List<String> selectedFields = streamToSelectedFields.getOrDefault(messageStream, Collections.emptyList());
final JsonNode data = record.getData();
if (data.isObject()) {
((ObjectNode) data).retain(selectedFields);
} else {
throw new RuntimeException(String.format("Unexpected data in record: %s", data.toString()));
}
}

@Trace(operationName = WORKER_OPERATION_NAME)
@Override
public void cancel() {
Expand Down
Loading

0 comments on commit 0fac8c8

Please sign in to comment.