Skip to content

Commit

Permalink
Progress Bar Estimate (#19814)
Browse files Browse the repository at this point in the history
Implement estimate message processing allowing the platform to hold on to estimate message counts in memory.

The estimate message is protocol message connectors can choose to emit to provide support for progress bar calculations. There are two kinds of estimates, per-Sync or per-Stream. Sources cannot emit both types in a single sync.

Per-stream estimates are what we usually expect. Per-sync estimates are for sources that cannot provide more granular estimates for whatever reasons e.g. CDC sources.

In a follow up PR, the platform will periodically save these messages through the save stats api.
  • Loading branch information
davinchia authored Nov 29, 2022
1 parent 0870187 commit a1b9db5
Show file tree
Hide file tree
Showing 8 changed files with 280 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Charsets;
import com.google.common.base.Preconditions;
import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.hash.HashFunction;
Expand All @@ -19,6 +20,7 @@
import io.airbyte.config.State;
import io.airbyte.protocol.models.AirbyteControlConnectorConfigMessage;
import io.airbyte.protocol.models.AirbyteControlMessage;
import io.airbyte.protocol.models.AirbyteEstimateTraceMessage;
import io.airbyte.protocol.models.AirbyteMessage;
import io.airbyte.protocol.models.AirbyteRecordMessage;
import io.airbyte.protocol.models.AirbyteStateMessage;
Expand Down Expand Up @@ -51,6 +53,7 @@ public class AirbyteMessageTracker implements MessageTracker {
private final Map<Short, Long> streamToRunningCount;
private final HashFunction hashFunction;
private final BiMap<AirbyteStreamNameNamespacePair, Short> nameNamespacePairToIndex;
private final Map<AirbyteStreamNameNamespacePair, StreamStats> nameNamespacePairToStreamStats;
private final Map<Short, Long> streamToTotalBytesEmitted;
private final Map<Short, Long> streamToTotalRecordsEmitted;
private final StateDeltaTracker stateDeltaTracker;
Expand All @@ -60,6 +63,11 @@ public class AirbyteMessageTracker implements MessageTracker {
private final StateAggregator stateAggregator;
private final boolean logConnectorMessages = new EnvVariableFeatureFlags().logConnectorMessages();

// These variables support SYNC level estimates and are meant for sources where stream level
// estimates are not possible e.g. CDC sources.
private Long totalRecordsEstimatedSync;
private Long totalBytesEstimatedSync;

private short nextStreamIndex;

/**
Expand All @@ -78,6 +86,11 @@ private enum ConnectorType {
DESTINATION
}

/**
* POJO for all per-stream stats.
*/
private record StreamStats(long estimatedBytes, long emittedBytes, long estimatedRecords, long emittedRecords) {}

public AirbyteMessageTracker() {
this(new StateDeltaTracker(STATE_DELTA_TRACKER_MEMORY_LIMIT_BYTES),
new DefaultStateAggregator(new EnvVariableFeatureFlags().useStreamCapableState()),
Expand All @@ -93,6 +106,7 @@ protected AirbyteMessageTracker(final StateDeltaTracker stateDeltaTracker,
this.streamToRunningCount = new HashMap<>();
this.nameNamespacePairToIndex = HashBiMap.create();
this.hashFunction = Hashing.murmur3_32_fixed();
this.nameNamespacePairToStreamStats = new HashMap<>();
this.streamToTotalBytesEmitted = new HashMap<>();
this.streamToTotalRecordsEmitted = new HashMap<>();
this.stateDeltaTracker = stateDeltaTracker;
Expand Down Expand Up @@ -252,7 +266,7 @@ private void handleEmittedOrchestratorConnectorConfig(final AirbyteControlConnec
*/
private void handleEmittedTrace(final AirbyteTraceMessage traceMessage, final ConnectorType connectorType) {
switch (traceMessage.getType()) {
case ESTIMATE -> handleEmittedEstimateTrace(traceMessage, connectorType);
case ESTIMATE -> handleEmittedEstimateTrace(traceMessage.getEstimate());
case ERROR -> handleEmittedErrorTrace(traceMessage, connectorType);
default -> log.warn("Invalid message type for trace message: {}", traceMessage);
}
Expand All @@ -266,8 +280,34 @@ private void handleEmittedErrorTrace(final AirbyteTraceMessage errorTraceMessage
}
}

@SuppressWarnings("PMD") // until method is implemented
private void handleEmittedEstimateTrace(final AirbyteTraceMessage estimateTraceMessage, final ConnectorType connectorType) {
/**
* There are several assumptions here:
* <p>
* - Assume the estimate is a whole number and not a sum i.e. each estimate replaces the previous
* estimate.
* <p>
* - Sources cannot emit both STREAM and SYNC estimates in a same sync. Error out if this happens.
*/
@SuppressWarnings("PMD.AvoidDuplicateLiterals")
private void handleEmittedEstimateTrace(final AirbyteEstimateTraceMessage estimate) {
switch (estimate.getType()) {
case STREAM -> {
Preconditions.checkArgument(totalBytesEstimatedSync == null, "STREAM and SYNC estimates should not be emitted in the same sync.");
Preconditions.checkArgument(totalRecordsEstimatedSync == null, "STREAM and SYNC estimates should not be emitted in the same sync.");

log.debug("Saving stream estimates for namespace: {}, stream: {}", estimate.getNamespace(), estimate.getName());
nameNamespacePairToStreamStats.put(
new AirbyteStreamNameNamespacePair(estimate.getName(), estimate.getNamespace()),
new StreamStats(estimate.getByteEstimate(), 0L, estimate.getRowEstimate(), 0L));
}
case SYNC -> {
Preconditions.checkArgument(nameNamespacePairToStreamStats.isEmpty(), "STREAM and SYNC estimates should not be emitted in the same sync.");

log.debug("Saving sync estimates");
totalBytesEstimatedSync = estimate.getByteEstimate();
totalRecordsEstimatedSync = estimate.getRowEstimate();
}
}

}

Expand Down Expand Up @@ -368,6 +408,17 @@ public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedRecords() {
entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue));
}

/**
* Swap out stream indices for stream names and return total records estimated by stream.
*/
@Override
public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEstimatedRecords() {
return nameNamespacePairToStreamStats.entrySet().stream().collect(
Collectors.toMap(
Entry::getKey,
entry -> entry.getValue().estimatedRecords()));
}

/**
* Swap out stream indices for stream names and return total bytes emitted by stream.
*/
Expand All @@ -377,6 +428,17 @@ public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedBytes() {
entry -> nameNamespacePairToIndex.inverse().get(entry.getKey()), Entry::getValue));
}

/**
* Swap out stream indices for stream names and return total bytes estimated by stream.
*/
@Override
public Map<AirbyteStreamNameNamespacePair, Long> getStreamToEstimatedBytes() {
return nameNamespacePairToStreamStats.entrySet().stream().collect(
Collectors.toMap(
Entry::getKey,
entry -> entry.getValue().estimatedBytes()));
}

/**
* Compute sum of emitted record counts across all streams.
*/
Expand All @@ -385,6 +447,20 @@ public long getTotalRecordsEmitted() {
return streamToTotalRecordsEmitted.values().stream().reduce(0L, Long::sum);
}

/**
* Compute sum of estimated record counts across all streams.
*/
@Override
public long getTotalRecordsEstimated() {
if (!nameNamespacePairToStreamStats.isEmpty()) {
return nameNamespacePairToStreamStats.values().stream()
.map(e -> e.estimatedRecords)
.reduce(0L, Long::sum);
}

return totalRecordsEstimatedSync;
}

/**
* Compute sum of emitted bytes across all streams.
*/
Expand All @@ -393,6 +469,20 @@ public long getTotalBytesEmitted() {
return streamToTotalBytesEmitted.values().stream().reduce(0L, Long::sum);
}

/**
* Compute sum of estimated bytes across all streams.
*/
@Override
public long getTotalBytesEstimated() {
if (!nameNamespacePairToStreamStats.isEmpty()) {
return nameNamespacePairToStreamStats.values().stream()
.map(e -> e.estimatedBytes)
.reduce(0L, Long::sum);
}

return totalBytesEstimatedSync;
}

/**
* Compute sum of committed record counts across all streams. If the delta tracker has exceeded its
* capacity, return empty because committed record counts cannot be reliably computed.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@ public interface MessageTracker {
*/
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedRecords();

/**
* Get the per-stream estimated record count provided by
* {@link io.airbyte.protocol.models.AirbyteEstimateTraceMessage}.
*
* @return returns a map of estimated record count by stream name.
*/
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEstimatedRecords();

/**
* Get the per-stream emitted byte count. This includes messages that were emitted by the source,
* but never committed by the destination.
Expand All @@ -74,6 +82,14 @@ public interface MessageTracker {
*/
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEmittedBytes();

/**
* Get the per-stream estimated byte count provided by
* {@link io.airbyte.protocol.models.AirbyteEstimateTraceMessage}.
*
* @return returns a map of estimated bytes by stream name.
*/
Map<AirbyteStreamNameNamespacePair, Long> getStreamToEstimatedBytes();

/**
* Get the overall emitted record count. This includes messages that were emitted by the source, but
* never committed by the destination.
Expand All @@ -82,6 +98,13 @@ public interface MessageTracker {
*/
long getTotalRecordsEmitted();

/**
* Get the overall estimated record count.
*
* @return returns the total count of estimated records across all streams.
*/
long getTotalRecordsEstimated();

/**
* Get the overall emitted bytes. This includes messages that were emitted by the source, but never
* committed by the destination.
Expand All @@ -90,6 +113,13 @@ public interface MessageTracker {
*/
long getTotalBytesEmitted();

/**
* Get the overall estimated bytes.
*
* @return returns the total count of estimated bytes across all streams.
*/
long getTotalBytesEstimated();

/**
* Get the overall committed record count.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.google.common.collect.ImmutableMap;
import io.airbyte.commons.json.Jsons;
import io.airbyte.protocol.models.AirbyteErrorTraceMessage;
import io.airbyte.protocol.models.AirbyteEstimateTraceMessage;
import io.airbyte.protocol.models.AirbyteGlobalState;
import io.airbyte.protocol.models.AirbyteLogMessage;
import io.airbyte.protocol.models.AirbyteMessage;
Expand Down Expand Up @@ -102,29 +103,60 @@ public static AirbyteStreamState createStreamState(final String streamName) {
return new AirbyteStreamState().withStreamDescriptor(new StreamDescriptor().withName(streamName));
}

public static AirbyteMessage createStreamEstimateMessage(final String name, final String namespace, final long byteEst, final long rowEst) {
return createEstimateMessage(AirbyteEstimateTraceMessage.Type.STREAM, name, namespace, byteEst, rowEst);
}

public static AirbyteMessage createSyncEstimateMessage(final long byteEst, final long rowEst) {
return createEstimateMessage(AirbyteEstimateTraceMessage.Type.SYNC, null, null, byteEst, rowEst);
}

public static AirbyteMessage createEstimateMessage(AirbyteEstimateTraceMessage.Type type,
final String name,
final String namespace,
final long byteEst,
final long rowEst) {
final var est = new AirbyteEstimateTraceMessage()
.withType(type)
.withByteEstimate(byteEst)
.withRowEstimate(rowEst);

if (name != null) {
est.withName(name);
}
if (namespace != null) {
est.withNamespace(namespace);
}

return new AirbyteMessage()
.withType(Type.TRACE)
.withTrace(new AirbyteTraceMessage().withType(AirbyteTraceMessage.Type.ESTIMATE)
.withEstimate(est));
}

public static AirbyteMessage createErrorMessage(final String message, final Double emittedAt) {
return new AirbyteMessage()
.withType(AirbyteMessage.Type.TRACE)
.withTrace(createErrorTraceMessage(message, emittedAt));
}

public static AirbyteTraceMessage createErrorTraceMessage(final String message, final Double emittedAt) {
return new AirbyteTraceMessage()
.withType(io.airbyte.protocol.models.AirbyteTraceMessage.Type.ERROR)
.withEmittedAt(emittedAt)
.withError(new AirbyteErrorTraceMessage().withMessage(message));
return createErrorTraceMessage(message, emittedAt, null);
}

public static AirbyteTraceMessage createErrorTraceMessage(final String message,
final Double emittedAt,
final AirbyteErrorTraceMessage.FailureType failureType) {
return new AirbyteTraceMessage()
final var msg = new AirbyteTraceMessage()
.withType(io.airbyte.protocol.models.AirbyteTraceMessage.Type.ERROR)
.withEmittedAt(emittedAt)
.withError(new AirbyteErrorTraceMessage().withMessage(message).withFailureType(failureType));
}
.withError(new AirbyteErrorTraceMessage().withMessage(message))
.withEmittedAt(emittedAt);

public static AirbyteMessage createTraceMessage(final String message, final Double emittedAt) {
return new AirbyteMessage()
.withType(AirbyteMessage.Type.TRACE)
.withTrace(new AirbyteTraceMessage()
.withType(AirbyteTraceMessage.Type.ERROR)
.withEmittedAt(emittedAt)
.withError(new AirbyteErrorTraceMessage().withMessage(message)));
if (failureType != null) {
msg.getError().withFailureType(failureType);
}

return msg;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ void testReplicationRunnableWorkerFailure() throws Exception {
@Test
void testOnlyStateAndRecordMessagesDeliveredToDestination() throws Exception {
final AirbyteMessage LOG_MESSAGE = AirbyteMessageUtils.createLogMessage(Level.INFO, "a log message");
final AirbyteMessage TRACE_MESSAGE = AirbyteMessageUtils.createTraceMessage("a trace message", 123456.0);
final AirbyteMessage TRACE_MESSAGE = AirbyteMessageUtils.createErrorMessage("a trace message", 123456.0);
when(mapper.mapMessage(LOG_MESSAGE)).thenReturn(LOG_MESSAGE);
when(mapper.mapMessage(TRACE_MESSAGE)).thenReturn(TRACE_MESSAGE);
when(source.isFinished()).thenReturn(false, false, false, false, true);
Expand Down
Loading

0 comments on commit a1b9db5

Please sign in to comment.