Skip to content

Commit

Permalink
feat(Platform): update actor configuration when receiving control mes…
Browse files Browse the repository at this point in the history
…sages from connectors during sync (#19811)

* track latest config message

* pass new config as part of outputs

* persist new config

* persist config as the messages come through, dont set output

* clean up old implementation

* accept control messages for destinations

* get api client from micronaut

* mask instance-wide oauth params when updating configs

* defaultreplicationworker tests

* formatting

* tests for source/destination handlers

* rm todo

* refactor test a bit to fix pmd

* fix pmd

* fix test

* add PersistConfigHelperTest

* update message tracker comment

* fix pmd

* format

* move ApiClientBeanFactory to commons-worker, use in container-orchestrator

* pull out config updating to separate methods

* add jitter

* rename PersistConfigHelper -> UpdateConnectorConfigHelper, docs

* fix exception type

* fmt

* move message type check into runnable

* formatting

* pass api client env vars to container orchestrator

* pass micronaut envs to container orchestrator

* print stacktrace for debugging

* different api host for container orchestrator

* fix default env var

* format

* fix errors after merge

* set source and destination actor id as part of the sync input

* fix: get destination definition

* fix null ptr

* remove "actor" from naming

* fix missing change from rename

* revert ContainerOrchestratorConfigBeanFactory changes

* inject sourceapi/destinationapi directly rather than airbyteapiclient

* UpdateConnectorConfigHelper -> ConnectorConfigUpdater

* rm log

* fix test

* dont fail on config update error

* pass id, not full config to runnables/accept control message

* add new config required for api client

* add test file

* fix test compatibility

* mount data plane credentials secret to container orchestrator (#20724)

* mount data plane credentials secret to container orchestrator

* rm copy-pasta

* properly handle empty strings

* set env vars like before

* use the right config vars
  • Loading branch information
pedroslopez authored Jan 6, 2023
1 parent f11d7ff commit 2a38177
Show file tree
Hide file tree
Showing 36 changed files with 808 additions and 115 deletions.
1 change: 1 addition & 0 deletions airbyte-commons-worker/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ dependencies {
implementation libs.bundles.micronaut

implementation 'io.fabric8:kubernetes-client:5.12.2'
implementation 'com.auth0:java-jwt:3.19.2'
implementation libs.guava
implementation (libs.temporal.sdk) {
exclude module: 'guava'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ public record ContainerOrchestratorConfig(
KubernetesClient kubernetesClient,
String secretName,
String secretMountPath,
String dataPlaneCredsSecretName,
String dataPlaneCredsSecretMountPath,
String containerOrchestratorImage,
String containerOrchestratorImagePullPolicy,
String googleApplicationCredentials,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ public static void cancelProcess(final Process process) {
*/
public static WorkerSourceConfig syncToWorkerSourceConfig(final StandardSyncInput sync) {
return new WorkerSourceConfig()
.withSourceId(sync.getSourceId())
.withSourceConnectionConfiguration(sync.getSourceConfiguration())
.withCatalog(sync.getCatalog())
.withState(sync.getState());
Expand All @@ -102,6 +103,7 @@ public static WorkerSourceConfig syncToWorkerSourceConfig(final StandardSyncInpu
*/
public static WorkerDestinationConfig syncToWorkerDestinationConfig(final StandardSyncInput sync) {
return new WorkerDestinationConfig()
.withDestinationId(sync.getDestinationId())
.withDestinationConnectionConfiguration(sync.getDestinationConfiguration())
.withCatalog(sync.getCatalog())
.withState(sync.getState());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com.google.auth.oauth2.ServiceAccountCredentials;
import io.airbyte.api.client.AirbyteApiClient;
import io.airbyte.api.client.generated.ConnectionApi;
import io.airbyte.api.client.generated.DestinationApi;
import io.airbyte.api.client.generated.SourceApi;
import io.airbyte.api.client.generated.WorkspaceApi;
import io.airbyte.api.client.invoker.generated.ApiClient;
Expand Down Expand Up @@ -60,7 +61,7 @@ public ApiClient apiClient(@Value("${airbyte.internal.api.auth-header.name}") fi
}

@Singleton
public AirbyteApiClient airbyteApiClient(ApiClient apiClient) {
public AirbyteApiClient airbyteApiClient(final ApiClient apiClient) {
return new AirbyteApiClient(apiClient);
}

Expand All @@ -69,6 +70,11 @@ public SourceApi sourceApi(final ApiClient apiClient) {
return new SourceApi(apiClient);
}

@Singleton
public DestinationApi destinationApi(final ApiClient apiClient) {
return new DestinationApi(apiClient);
}

@Singleton
public ConnectionApi connectionApi(final ApiClient apiClient) {
return new ConnectionApi(apiClient);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.airbyte.config.WorkerDestinationConfig;
import io.airbyte.config.WorkerSourceConfig;
import io.airbyte.metrics.lib.ApmTraceUtils;
import io.airbyte.protocol.models.AirbyteControlMessage;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteMessage.Type;
import io.airbyte.protocol.models.AirbyteRecordMessage;
Expand All @@ -36,6 +37,7 @@
import io.airbyte.workers.WorkerUtils;
import io.airbyte.workers.exception.RecordSchemaValidationException;
import io.airbyte.workers.exception.WorkerException;
import io.airbyte.workers.helper.ConnectorConfigUpdater;
import io.airbyte.workers.helper.FailureHelper;
import io.airbyte.workers.helper.ThreadedTimeTracker;
import io.airbyte.workers.internal.AirbyteDestination;
Expand All @@ -50,6 +52,7 @@
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -98,6 +101,7 @@ public class DefaultReplicationWorker implements ReplicationWorker {
private final AtomicBoolean hasFailed;
private final RecordSchemaValidator recordSchemaValidator;
private final WorkerMetricReporter metricReporter;
private final ConnectorConfigUpdater connectorConfigUpdater;
private final boolean fieldSelectionEnabled;

public DefaultReplicationWorker(final String jobId,
Expand All @@ -108,6 +112,7 @@ public DefaultReplicationWorker(final String jobId,
final MessageTracker messageTracker,
final RecordSchemaValidator recordSchemaValidator,
final WorkerMetricReporter metricReporter,
final ConnectorConfigUpdater connectorConfigUpdater,
final boolean fieldSelectionEnabled) {
this.jobId = jobId;
this.attempt = attempt;
Expand All @@ -118,6 +123,7 @@ public DefaultReplicationWorker(final String jobId,
this.executors = Executors.newFixedThreadPool(2);
this.recordSchemaValidator = recordSchemaValidator;
this.metricReporter = metricReporter;
this.connectorConfigUpdater = connectorConfigUpdater;
this.fieldSelectionEnabled = fieldSelectionEnabled;

this.cancelled = new AtomicBoolean(false);
Expand Down Expand Up @@ -192,7 +198,7 @@ private void replicate(final Path jobRoot,
// note: `whenComplete` is used instead of `exceptionally` so that the original exception is still
// thrown
final CompletableFuture<?> readFromDstThread = CompletableFuture.runAsync(
readFromDstRunnable(destination, cancelled, messageTracker, mdc, timeTracker),
readFromDstRunnable(destination, cancelled, messageTracker, connectorConfigUpdater, mdc, timeTracker, destinationConfig.getDestinationId()),
executors)
.whenComplete((msg, ex) -> {
if (ex != null) {
Expand All @@ -213,10 +219,12 @@ private void replicate(final Path jobRoot,
cancelled,
mapper,
messageTracker,
connectorConfigUpdater,
mdc,
recordSchemaValidator,
metricReporter,
timeTracker,
sourceConfig.getSourceId(),
fieldSelectionEnabled),
executors)
.whenComplete((msg, ex) -> {
Expand Down Expand Up @@ -254,8 +262,10 @@ private void replicate(final Path jobRoot,
private static Runnable readFromDstRunnable(final AirbyteDestination destination,
final AtomicBoolean cancelled,
final MessageTracker messageTracker,
final ConnectorConfigUpdater connectorConfigUpdater,
final Map<String, String> mdc,
final ThreadedTimeTracker timeHolder) {
final ThreadedTimeTracker timeHolder,
final UUID destinationId) {
return () -> {
MDC.setContextMap(mdc);
LOGGER.info("Destination output thread started.");
Expand All @@ -268,8 +278,18 @@ private static Runnable readFromDstRunnable(final AirbyteDestination destination
throw new DestinationException("Destination process read attempt failed", e);
}
if (messageOptional.isPresent()) {
LOGGER.info("State in DefaultReplicationWorker from destination: {}", messageOptional.get());
messageTracker.acceptFromDestination(messageOptional.get());
final AirbyteMessage message = messageOptional.get();
LOGGER.info("State in DefaultReplicationWorker from destination: {}", message);

messageTracker.acceptFromDestination(message);

try {
if (message.getType() == Type.CONTROL) {
acceptDstControlMessage(destinationId, message.getControl(), connectorConfigUpdater);
}
} catch (final Exception e) {
LOGGER.error("Error updating destination configuration", e);
}
}
}
timeHolder.trackDestinationWriteEndTime();
Expand Down Expand Up @@ -301,10 +321,12 @@ private static Runnable readFromSrcAndWriteToDstRunnable(final AirbyteSource sou
final AtomicBoolean cancelled,
final AirbyteMapper mapper,
final MessageTracker messageTracker,
final ConnectorConfigUpdater connectorConfigUpdater,
final Map<String, String> mdc,
final RecordSchemaValidator recordSchemaValidator,
final WorkerMetricReporter metricReporter,
final ThreadedTimeTracker timeHolder,
final UUID sourceId,
final boolean fieldSelectionEnabled) {
return () -> {
MDC.setContextMap(mdc);
Expand Down Expand Up @@ -334,6 +356,14 @@ private static Runnable readFromSrcAndWriteToDstRunnable(final AirbyteSource sou

messageTracker.acceptFromSource(message);

try {
if (message.getType() == Type.CONTROL) {
acceptSrcControlMessage(sourceId, message.getControl(), connectorConfigUpdater);
}
} catch (final Exception e) {
LOGGER.error("Error updating source configuration", e);
}

try {
if (message.getType() == Type.RECORD || message.getType() == Type.STATE) {
destination.accept(message);
Expand Down Expand Up @@ -392,6 +422,22 @@ private static Runnable readFromSrcAndWriteToDstRunnable(final AirbyteSource sou
};
}

private static void acceptSrcControlMessage(final UUID sourceId,
final AirbyteControlMessage controlMessage,
final ConnectorConfigUpdater connectorConfigUpdater) {
if (controlMessage.getType() == AirbyteControlMessage.Type.CONNECTOR_CONFIG) {
connectorConfigUpdater.updateSource(sourceId, controlMessage.getConnectorConfig().getConfig());
}
}

private static void acceptDstControlMessage(final UUID destinationId,
final AirbyteControlMessage controlMessage,
final ConnectorConfigUpdater connectorConfigUpdater) {
if (controlMessage.getType() == AirbyteControlMessage.Type.CONNECTOR_CONFIG) {
connectorConfigUpdater.updateDestination(destinationId, controlMessage.getConnectorConfig().getConfig());
}
}

private ReplicationOutput getReplicationOutput(final StandardSyncInput syncInput,
final WorkerDestinationConfig destinationConfig,
final AtomicReference<FailureReason> replicationRunnableFailureRef,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Copyright (c) 2022 Airbyte, Inc., all rights reserved.
*/

package io.airbyte.workers.helper;

import com.google.common.hash.Hashing;
import io.airbyte.api.client.AirbyteApiClient;
import io.airbyte.api.client.generated.DestinationApi;
import io.airbyte.api.client.generated.SourceApi;
import io.airbyte.api.client.model.generated.DestinationIdRequestBody;
import io.airbyte.api.client.model.generated.DestinationRead;
import io.airbyte.api.client.model.generated.DestinationUpdate;
import io.airbyte.api.client.model.generated.SourceIdRequestBody;
import io.airbyte.api.client.model.generated.SourceRead;
import io.airbyte.api.client.model.generated.SourceUpdate;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.Config;
import java.nio.charset.StandardCharsets;
import java.util.UUID;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Helper class for workers to persist updates to Source/Destination configs emitted from
* AirbyteControlMessages.
*
* This is in order to support connectors updating configs when running commands, which is specially
* useful for migrating configuration to a new version or for enabling connectors that require
* single-use or short-lived OAuth tokens.
*/
public class ConnectorConfigUpdater {

private static final Logger LOGGER = LoggerFactory.getLogger(ConnectorConfigUpdater.class);

private final SourceApi sourceApi;
private final DestinationApi destinationApi;

public ConnectorConfigUpdater(final SourceApi sourceApi, final DestinationApi destinationApi) {
this.sourceApi = sourceApi;
this.destinationApi = destinationApi;
}

/**
* Updates the Source from a sync job ID with the provided Configuration. Secrets and OAuth
* parameters will be masked when saving.
*/
public void updateSource(final UUID sourceId, final Config config) {
final SourceRead source = AirbyteApiClient.retryWithJitter(
() -> sourceApi.getSource(new SourceIdRequestBody().sourceId(sourceId)),
"get source");

final SourceRead updatedSource = AirbyteApiClient.retryWithJitter(
() -> sourceApi
.updateSource(new SourceUpdate()
.sourceId(sourceId)
.name(source.getName())
.connectionConfiguration(Jsons.jsonNode(config.getAdditionalProperties()))),
"update source");

LOGGER.info("Persisted updated configuration for source {}. New config hash: {}.", sourceId,
Hashing.sha256().hashString(updatedSource.getConnectionConfiguration().asText(), StandardCharsets.UTF_8));

}

/**
* Updates the Destination from a sync job ID with the provided Configuration. Secrets and OAuth
* parameters will be masked when saving.
*/
public void updateDestination(final UUID destinationId, final Config config) {
final DestinationRead destination = AirbyteApiClient.retryWithJitter(
() -> destinationApi.getDestination(new DestinationIdRequestBody().destinationId(destinationId)),
"get destination");

final DestinationRead updatedDestination = AirbyteApiClient.retryWithJitter(
() -> destinationApi
.updateDestination(new DestinationUpdate()
.destinationId(destinationId)
.name(destination.getName())
.connectionConfiguration(Jsons.jsonNode(config.getAdditionalProperties()))),
"update destination");

LOGGER.info("Persisted updated configuration for destination {}. New config hash: {}.", destinationId,
Hashing.sha256().hashString(updatedDestination.getConnectionConfiguration().asText(), StandardCharsets.UTF_8));
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.io.OutputStreamWriter;
import java.nio.file.Path;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -86,8 +87,9 @@ public void start(final WorkerDestinationConfig destinationConfig, final Path jo

writer = messageWriterFactory.createWriter(new BufferedWriter(new OutputStreamWriter(destinationProcess.getOutputStream(), Charsets.UTF_8)));

final List<Type> acceptedMessageTypes = List.of(Type.STATE, Type.TRACE, Type.CONTROL);
messageIterator = streamFactory.create(IOs.newBufferedReader(destinationProcess.getInputStream()))
.filter(message -> message.getType() == Type.STATE || message.getType() == Type.TRACE)
.filter(message -> acceptedMessageTypes.contains(message.getType()))
.iterator();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -91,9 +92,10 @@ public void start(final WorkerSourceConfig sourceConfig, final Path jobRoot) thr

logInitialStateAsJSON(sourceConfig);

final List<Type> acceptedMessageTypes = List.of(Type.RECORD, Type.STATE, Type.TRACE, Type.CONTROL);
messageIterator = streamFactory.create(IOs.newBufferedReader(sourceProcess.getInputStream()))
.peek(message -> heartbeatMonitor.beat())
.filter(message -> message.getType() == Type.RECORD || message.getType() == Type.STATE || message.getType() == Type.TRACE)
.filter(message -> acceptedMessageTypes.contains(message.getType()))
.iterator();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,12 +252,8 @@ private void handleEmittedOrchestratorMessage(final AirbyteControlMessage contro
@SuppressWarnings("PMD") // until method is implemented
private void handleEmittedOrchestratorConnectorConfig(final AirbyteControlConnectorConfigMessage configMessage,
final ConnectorType connectorType) {
// TODO: Update config here
/**
* Pseudocode: for (key in configMessage.getConfig()) { validateIsReallyConfig(key);
* persistConfigChange(connectorType, key, configMessage.getConfig().get(key)); // nuance here for
* secret storage or not. May need to be async over API for replication orchestrator }
*/
// Config updates are being persisted as part of the DefaultReplicationWorker.
// In the future, we could add tracking of these kinds of messages here. Nothing to do for now.
}

/**
Expand Down
Loading

0 comments on commit 2a38177

Please sign in to comment.