[YAML] Fix error handling for KafkaSchemaTransforms (#29261) (#29289)
Polber committed Nov 3, 2023
1 parent 71e9895 commit 0e8e54c
Showing 9 changed files with 148 additions and 90 deletions.
ErrorHandling.java
@@ -53,6 +53,12 @@ public static Schema errorSchema(Schema inputSchema) {
Schema.Field.of("error_message", Schema.FieldType.STRING));
}

public static Schema errorSchemaBytes() {
return Schema.of(
Schema.Field.of("failed_row", Schema.FieldType.BYTES),
Schema.Field.of("error_message", Schema.FieldType.STRING));
}

@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
@@ -62,4 +68,14 @@ public static Row errorRecord(Schema errorSchema, Row inputRow, Throwable th) {
.withFieldValue("error_message", th.getMessage())
.build();
}

@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
public static Row errorRecord(Schema errorSchema, byte[] inputBytes, Throwable th) {
return Row.withSchema(errorSchema)
.withFieldValue("failed_row", inputBytes)
.withFieldValue("error_message", th.getMessage())
.build();
}
}
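
For context, the two new helpers mirror the existing Row-based errorSchema/errorRecord pair for the case where the failed element is a raw byte payload rather than a Row. A minimal usage sketch follows (the payload, mapper, and exception are illustrative, not part of this commit):

Schema errorSchema = ErrorHandling.errorSchemaBytes();
byte[] payload = "not-a-valid-record".getBytes(StandardCharsets.UTF_8);
try {
  Row parsed = valueMapper.apply(payload); // assume some SerializableFunction<byte[], Row>
} catch (Exception e) {
  // Yields a Row with "failed_row" (BYTES) and "error_message" (STRING).
  Row errorRow = ErrorHandling.errorRecord(errorSchema, payload, e);
}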
KafkaReadSchemaTransformConfiguration.java
@@ -24,6 +24,7 @@
import org.apache.beam.sdk.schemas.AutoValueSchema;
import org.apache.beam.sdk.schemas.annotations.DefaultSchema;
import org.apache.beam.sdk.schemas.annotations.SchemaFieldDescription;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;

/**
@@ -105,6 +106,10 @@ public static Builder builder() {
/** Sets the topic from which to read. */
public abstract String getTopic();

@SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
@Nullable
public abstract ErrorHandling getErrorHandling();

/** Builder for the {@link KafkaReadSchemaTransformConfiguration}. */
@AutoValue.Builder
public abstract static class Builder {
@@ -127,6 +132,8 @@ public abstract static class Builder {
/** Sets the topic from which to read. */
public abstract Builder setTopic(String value);

public abstract Builder setErrorHandling(ErrorHandling errorHandling);

/** Builds a {@link KafkaReadSchemaTransformConfiguration} instance. */
public abstract KafkaReadSchemaTransformConfiguration build();
}
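
A hedged sketch of how a caller might populate the new field through the configuration builder. Only setTopic and setErrorHandling appear in this diff; the other setters, and an ErrorHandling builder with a setOutput(String) setter, are assumed for illustration:

KafkaReadSchemaTransformConfiguration config =
    KafkaReadSchemaTransformConfiguration.builder()
        .setBootstrapServers("localhost:9092") // illustrative values
        .setTopic("my_topic")
        .setFormat("RAW")
        .setErrorHandling(ErrorHandling.builder().setOutput("my_errors").build()) // assumed builder API
        .build();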
KafkaReadSchemaTransformProvider.java
@@ -43,10 +43,9 @@
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.schemas.utils.JsonUtils;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.FinishBundle;
import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
@@ -77,8 +76,6 @@ public class KafkaReadSchemaTransformProvider

public static final TupleTag<Row> OUTPUT_TAG = new TupleTag<Row>() {};
public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
public static final Schema ERROR_SCHEMA =
Schema.builder().addStringField("error").addNullableByteArrayField("row").build();

final Boolean isTest;
final Integer testTimeoutSecs;
@@ -98,6 +95,9 @@ protected Class<KafkaReadSchemaTransformConfiguration> configurationClass() {
return KafkaReadSchemaTransformConfiguration.class;
}

@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
@Override
protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) {
final String inputSchema = configuration.getSchema();
@@ -114,14 +114,32 @@ protected SchemaTransform from(KafkaReadSchemaTransformConfiguration configuration) {
consumerConfigs.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, autoOffsetReset);

String format = configuration.getFormat();

if (format != null && format.equals("RAW")) {
if (inputSchema != null) {
throw new IllegalArgumentException(
"To read from Kafka in RAW format, you can't provide a schema.");
boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());

if ((format != null && format.equals("RAW")) || (!Strings.isNullOrEmpty(inputSchema))) {
SerializableFunction<byte[], Row> valueMapper;
Schema beamSchema;
if (format != null && format.equals("RAW")) {
if (inputSchema != null) {
throw new IllegalArgumentException(
"To read from Kafka in RAW format, you can't provide a schema.");
}
beamSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
valueMapper = getRawBytesToRowFunction(beamSchema);
} else {
assert Strings.isNullOrEmpty(configuration.getConfluentSchemaRegistryUrl())
: "To read from Kafka, a schema must be provided directly or though Confluent "
+ "Schema Registry, but not both.";

beamSchema =
Objects.equals(format, "JSON")
? JsonUtils.beamSchemaFromJsonSchema(inputSchema)
: AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema));
valueMapper =
Objects.equals(format, "JSON")
? JsonUtils.getJsonBytesToRowFunction(beamSchema)
: AvroUtils.getAvroBytesToRowFunction(beamSchema);
}
Schema rawSchema = Schema.builder().addField("payload", Schema.FieldType.BYTES).build();
SerializableFunction<byte[], Row> valueMapper = getRawBytesToRowFunction(rawSchema);
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
@@ -138,59 +156,23 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
PCollection<byte[]> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

Schema errorSchema = ErrorHandling.errorSchemaBytes();
PCollectionTuple outputTuple =
kafkaValues.apply(
ParDo.of(new ErrorFn("Kafka-read-error-counter", valueMapper))
ParDo.of(
new ErrorFn(
"Kafka-read-error-counter", valueMapper, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

return PCollectionRowTuple.of(
"output",
outputTuple.get(OUTPUT_TAG).setRowSchema(rawSchema),
"errors",
outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA));
}
};
}
PCollectionRowTuple outputRows =
PCollectionRowTuple.of(
"output", outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema));

if (inputSchema != null && !inputSchema.isEmpty()) {
assert Strings.isNullOrEmpty(configuration.getConfluentSchemaRegistryUrl())
: "To read from Kafka, a schema must be provided directly or though Confluent "
+ "Schema Registry, but not both.";

final Schema beamSchema =
Objects.equals(format, "JSON")
? JsonUtils.beamSchemaFromJsonSchema(inputSchema)
: AvroUtils.toBeamSchema(new org.apache.avro.Schema.Parser().parse(inputSchema));
SerializableFunction<byte[], Row> valueMapper =
Objects.equals(format, "JSON")
? JsonUtils.getJsonBytesToRowFunction(beamSchema)
: AvroUtils.getAvroBytesToRowFunction(beamSchema);
return new SchemaTransform() {
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
KafkaIO.Read<byte[], byte[]> kafkaRead =
KafkaIO.readBytes()
.withConsumerConfigUpdates(consumerConfigs)
.withConsumerFactoryFn(new ConsumerFactoryWithGcsTrustStores())
.withTopic(configuration.getTopic())
.withBootstrapServers(configuration.getBootstrapServers());
if (isTest) {
kafkaRead = kafkaRead.withMaxReadTime(Duration.standardSeconds(testTimeoutSecs));
PCollection<Row> errorOutput = outputTuple.get(ERROR_TAG).setRowSchema(errorSchema);
if (handleErrors) {
outputRows = outputRows.and(configuration.getErrorHandling().getOutput(), errorOutput);
}

PCollection<byte[]> kafkaValues =
input.getPipeline().apply(kafkaRead.withoutMetadata()).apply(Values.create());

PCollectionTuple outputTuple =
kafkaValues.apply(
ParDo.of(new ErrorFn("Kafka-read-error-counter", valueMapper))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

return PCollectionRowTuple.of(
"output",
outputTuple.get(OUTPUT_TAG).setRowSchema(beamSchema),
"errors",
outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA));
return outputRows;
}
};
} else {
@@ -259,25 +241,38 @@ public List<String> outputCollectionNames() {
}

public static class ErrorFn extends DoFn<byte[], Row> {
private SerializableFunction<byte[], Row> valueMapper;
private Counter errorCounter;
private final SerializableFunction<byte[], Row> valueMapper;
private final Counter errorCounter;
private Long errorsInBundle = 0L;

public ErrorFn(String name, SerializableFunction<byte[], Row> valueMapper) {
private final boolean handleErrors;
private final Schema errorSchema;

public ErrorFn(
String name,
SerializableFunction<byte[], Row> valueMapper,
Schema errorSchema,
boolean handleErrors) {
this.errorCounter = Metrics.counter(KafkaReadSchemaTransformProvider.class, name);
this.valueMapper = valueMapper;
this.handleErrors = handleErrors;
this.errorSchema = errorSchema;
}

@ProcessElement
public void process(@DoFn.Element byte[] msg, MultiOutputReceiver receiver) {
Row mappedRow = null;
try {
receiver.get(OUTPUT_TAG).output(valueMapper.apply(msg));
mappedRow = valueMapper.apply(msg);
} catch (Exception e) {
if (!handleErrors) {
throw new RuntimeException(e);
}
errorsInBundle += 1;
LOG.warn("Error while parsing the element", e);
receiver
.get(ERROR_TAG)
.output(Row.withSchema(ERROR_SCHEMA).addValues(e.toString(), msg).build());
receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, msg, e));
}
if (mappedRow != null) {
receiver.get(OUTPUT_TAG).output(mappedRow);
}
}

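
For reference, when error handling is enabled the read transform now returns the parsed rows under "output" and the dead-letter rows under the name given by the configuration's error_handling output, instead of a fixed "errors" tag. A hedged consumption sketch (the transform variable and output name are illustrative):

// assume kafkaReadTransform is the SchemaTransform this provider builds from the configuration above
PCollectionRowTuple result = PCollectionRowTuple.empty(pipeline).apply(kafkaReadTransform);
PCollection<Row> goodRows = result.get("output");
// Present only when error handling is configured; the name matches getErrorHandling().getOutput().
PCollection<Row> failedRows = result.get("my_errors");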
KafkaWriteSchemaTransformProvider.java
@@ -36,13 +36,14 @@
import org.apache.beam.sdk.schemas.transforms.SchemaTransform;
import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.TypedSchemaTransformProvider;
import org.apache.beam.sdk.schemas.transforms.providers.ErrorHandling;
import org.apache.beam.sdk.schemas.utils.JsonUtils;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFn.ProcessElement;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionRowTuple;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.Row;
Expand All @@ -67,8 +68,6 @@ public class KafkaWriteSchemaTransformProvider
public static final TupleTag<Row> ERROR_TAG = new TupleTag<Row>() {};
public static final TupleTag<KV<byte[], byte[]>> OUTPUT_TAG =
new TupleTag<KV<byte[], byte[]>>() {};
public static final Schema ERROR_SCHEMA =
Schema.builder().addStringField("error").addNullableByteArrayField("row").build();
private static final Logger LOG =
LoggerFactory.getLogger(KafkaWriteSchemaTransformProvider.class);

@@ -100,25 +99,38 @@ static final class KafkaWriteSchemaTransform extends SchemaTransform implements
}

public static class ErrorCounterFn extends DoFn<Row, KV<byte[], byte[]>> {
private SerializableFunction<Row, byte[]> toBytesFn;
private Counter errorCounter;
private final SerializableFunction<Row, byte[]> toBytesFn;
private final Counter errorCounter;
private Long errorsInBundle = 0L;
private final boolean handleErrors;
private final Schema errorSchema;

public ErrorCounterFn(String name, SerializableFunction<Row, byte[]> toBytesFn) {
public ErrorCounterFn(
String name,
SerializableFunction<Row, byte[]> toBytesFn,
Schema errorSchema,
boolean handleErrors) {
this.toBytesFn = toBytesFn;
errorCounter = Metrics.counter(KafkaWriteSchemaTransformProvider.class, name);
this.errorCounter = Metrics.counter(KafkaWriteSchemaTransformProvider.class, name);
this.handleErrors = handleErrors;
this.errorSchema = errorSchema;
}

@ProcessElement
public void process(@DoFn.Element Row row, MultiOutputReceiver receiver) {
KV<byte[], byte[]> output = null;
try {
receiver.get(OUTPUT_TAG).output(KV.of(new byte[1], toBytesFn.apply(row)));
output = KV.of(new byte[1], toBytesFn.apply(row));
} catch (Exception e) {
if (!handleErrors) {
throw new RuntimeException(e);
}
errorsInBundle += 1;
LOG.warn("Error while processing the element", e);
receiver
.get(ERROR_TAG)
.output(Row.withSchema(ERROR_SCHEMA).addValues(e.toString(), row.toString()).build());
receiver.get(ERROR_TAG).output(ErrorHandling.errorRecord(errorSchema, row, e));
}
if (output != null) {
receiver.get(OUTPUT_TAG).output(output);
}
}

@@ -129,6 +141,9 @@ public void finish() {
}
}

@SuppressWarnings({
"nullness" // TODO(https://github.com/apache/beam/issues/20497)
})
@Override
public PCollectionRowTuple expand(PCollectionRowTuple input) {
Schema inputSchema = input.get("input").getSchema();
@@ -145,13 +160,17 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
toBytesFn = AvroUtils.getRowToAvroBytesFunction(inputSchema);
}

boolean handleErrors = ErrorHandling.hasOutput(configuration.getErrorHandling());
final Map<String, String> configOverrides = configuration.getProducerConfigUpdates();
Schema errorSchema = ErrorHandling.errorSchema(inputSchema);
PCollectionTuple outputTuple =
input
.get("input")
.apply(
"Map rows to Kafka messages",
ParDo.of(new ErrorCounterFn("Kafka-write-error-counter", toBytesFn))
ParDo.of(
new ErrorCounterFn(
"Kafka-write-error-counter", toBytesFn, errorSchema, handleErrors))
.withOutputTags(OUTPUT_TAG, TupleTagList.of(ERROR_TAG)));

outputTuple
@@ -167,8 +186,11 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) {
.withKeySerializer(ByteArraySerializer.class)
.withValueSerializer(ByteArraySerializer.class));

// TODO: include output from KafkaIO Write once updated from PDone
PCollection<Row> errorOutput =
outputTuple.get(ERROR_TAG).setRowSchema(ErrorHandling.errorSchema(errorSchema));
return PCollectionRowTuple.of(
"errors", outputTuple.get(ERROR_TAG).setRowSchema(ERROR_SCHEMA));
handleErrors ? configuration.getErrorHandling().getOutput() : "errors", errorOutput);
}
}

@@ -227,6 +249,10 @@ public abstract static class KafkaWriteSchemaTransformConfiguration implements S
@Nullable
public abstract Map<String, String> getProducerConfigUpdates();

@SchemaFieldDescription("This option specifies whether and where to output unwritable rows.")
@Nullable
public abstract ErrorHandling getErrorHandling();

public static Builder builder() {
return new AutoValue_KafkaWriteSchemaTransformProvider_KafkaWriteSchemaTransformConfiguration
.Builder();
@@ -242,6 +268,8 @@ public abstract static class Builder {

public abstract Builder setProducerConfigUpdates(Map<String, String> producerConfigUpdates);

public abstract Builder setErrorHandling(ErrorHandling errorHandling);

public abstract KafkaWriteSchemaTransformConfiguration build();
}
}
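
On the write path the failed element is the original input Row rather than raw bytes, so the dead-letter schema is derived from the input schema. A rough sketch of what an error record would look like (the schema and values are made up, and the exact shape of ErrorHandling.errorSchema is assumed to pair the failed row with an error message):

Schema inputSchema = Schema.of(Schema.Field.of("name", Schema.FieldType.STRING));
Schema errorSchema = ErrorHandling.errorSchema(inputSchema);
Row failed = Row.withSchema(inputSchema).withFieldValue("name", "bad-record").build();
// Carries the failed row plus the exception's message under "error_message".
Row errorRow = ErrorHandling.errorRecord(errorSchema, failed, new RuntimeException("serialization failed"));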