Skip to content

Commit

Permalink
[CsvIO]: Implement CsvIOParse::withCustomRecordParsing method (#32142)
Browse files Browse the repository at this point in the history
* completed implementation without tests

Co-authored-by: Lahari Guduru <lahariguduru@google.com>

* intermediate stage

Co-authored-by: Lahari Guduru <lahariguduru@google.com>

* Implement CsvIOParse.withCustomRecordParsing

Co-authored-by: Lahari Guduru <lahariguduru@google.com>

---------

Co-authored-by: Lahari Guduru <lahariguduru@google.com>
  • Loading branch information
francisohara24 and lahariguduru committed Aug 9, 2024
1 parent b21a84a commit fc5a71d
Show file tree
Hide file tree
Showing 3 changed files with 158 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.apache.beam.sdk.values.PCollectionList;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.checkerframework.checker.nullness.qual.NonNull;

/**
* {@link PTransform} for Parsing CSV Record Strings into {@link Schema}-mapped target types. {@link
Expand All @@ -43,9 +44,30 @@ static <T> CsvIOParse.Builder<T> builder() {
return new AutoValue_CsvIOParse.Builder<>();
}

// TODO(https://github.com/apache/beam/issues/31875): Implement in future PR.
public CsvIOParse<T> withCustomRecordParsing(
Map<String, SerializableFunction<String, Object>> customProcessingMap) {
/**
* Configures custom cell parsing.
*
* <h2>Example</h2>
*
* <pre>{@code
* CsvIO.parse().withCustomRecordParsing("listOfInts", cell-> {
*
* List<Integer> result = new ArrayList<>();
* for (String stringValue: Splitter.on(";").split(cell)) {
* result.add(Integer.parseInt(stringValue));
* }
*
* });
* }</pre>
*/
public <OutputT extends @NonNull Object> CsvIOParse<T> withCustomRecordParsing(
String fieldName, SerializableFunction<String, OutputT> customRecordParsingFn) {

Map<String, SerializableFunction<String, Object>> customProcessingMap =
getConfigBuilder().getOrCreateCustomProcessingMap();

customProcessingMap.put(fieldName, customRecordParsingFn::apply);
getConfigBuilder().setCustomProcessingMap(customProcessingMap);
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,26 @@ abstract static class Builder<T> implements Serializable {
abstract Builder<T> setCustomProcessingMap(
Map<String, SerializableFunction<String, Object>> customProcessingMap);

abstract Optional<Map<String, SerializableFunction<String, Object>>> getCustomProcessingMap();

final Map<String, SerializableFunction<String, Object>> getOrCreateCustomProcessingMap() {
if (!getCustomProcessingMap().isPresent()) {
setCustomProcessingMap(new HashMap<>());
}
return getCustomProcessingMap().get();
}

abstract Builder<T> setCoder(Coder<T> coder);

abstract Builder<T> setFromRowFn(SerializableFunction<Row, T> fromRowFn);

abstract Optional<Map<String, SerializableFunction<String, Object>>> getCustomProcessingMap();

abstract CsvIOParseConfiguration<T> autoBuild();

final CsvIOParseConfiguration<T> build() {
if (!getCustomProcessingMap().isPresent()) {
setCustomProcessingMap(new HashMap<>());
}

return autoBuild();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@

import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.NULLABLE_ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.TIME_CONTAINING_SCHEMA;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.TIME_CONTAINING_TYPE_DESCRIPTOR;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.TimeContaining;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.nullableAllPrimitiveDataTypes;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.nullableAllPrimitiveDataTypesFromRowFn;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.nullableAllPrimitiveDataTypesToRowFn;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.timeContaining;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.timeContainingFromRowFn;
import static org.apache.beam.sdk.io.common.SchemaAwareJavaBeans.timeContainingToRowFn;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
Expand All @@ -38,17 +45,22 @@
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Splitter;
import org.apache.commons.csv.CSVFormat;
import org.joda.time.Instant;
import org.joda.time.format.DateTimeFormat;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link CsvIOParse}. */
@RunWith(JUnit4.class)
public class CsvIOParseTest {

Expand All @@ -61,6 +73,12 @@ public class CsvIOParseTest {
NULLABLE_ALL_PRIMITIVE_DATA_TYPES_TYPE_DESCRIPTOR,
nullableAllPrimitiveDataTypesToRowFn(),
nullableAllPrimitiveDataTypesFromRowFn());
private static final Coder<TimeContaining> TIME_CONTAINING_CODER =
SchemaCoder.of(
TIME_CONTAINING_SCHEMA,
TIME_CONTAINING_TYPE_DESCRIPTOR,
timeContainingToRowFn(),
timeContainingFromRowFn());
private static final SerializableFunction<Row, Row> ROW_ROW_SERIALIZABLE_FUNCTION = row -> row;
@Rule public final TestPipeline pipeline = TestPipeline.create();

Expand Down Expand Up @@ -120,7 +138,7 @@ public void parseRows() {
underTest(
NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA,
csvFormat(),
emptyCustomProcessingMap(),
new HashMap<>(),
ROW_ROW_SERIALIZABLE_FUNCTION,
RowCoder.of(NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA)));
PAssert.that(result.getOutput()).containsInAnyOrder(want);
Expand Down Expand Up @@ -152,7 +170,7 @@ public void parsePOJOs() {
underTest(
NULLABLE_ALL_PRIMITIVE_DATA_TYPES_SCHEMA,
csvFormat(),
emptyCustomProcessingMap(),
new HashMap<>(),
nullableAllPrimitiveDataTypesFromRowFn(),
NULLABLE_ALL_PRIMITIVE_DATA_TYPES_CODER));
PAssert.that(result.getOutput()).containsInAnyOrder(want);
Expand All @@ -161,6 +179,98 @@ public void parsePOJOs() {
pipeline.run();
}

@Test
public void givenSingleCustomParsingLambda_parsesPOJOs() {
PCollection<String> records =
csvRecords(
pipeline,
"instant,instantList",
"2024-01-23T10:00:05.000Z,10-00-05-2024-01-23;12-59-59-2024-01-24");
TimeContaining want =
timeContaining(
Instant.parse("2024-01-23T10:00:05.000Z"),
Arrays.asList(
Instant.parse("2024-01-23T10:00:05.000Z"),
Instant.parse("2024-01-24T12:59:59.000Z")));

CsvIOParse<TimeContaining> underTest =
underTest(
TIME_CONTAINING_SCHEMA,
CSVFormat.DEFAULT
.withHeader("instant", "instantList")
.withAllowDuplicateHeaderNames(false),
new HashMap<>(),
timeContainingFromRowFn(),
TIME_CONTAINING_CODER)
.withCustomRecordParsing("instantList", instantListParsingLambda());

CsvIOParseResult<TimeContaining> result = records.apply(underTest);
PAssert.that(result.getOutput()).containsInAnyOrder(want);
PAssert.that(result.getErrors()).empty();

pipeline.run();
}

@Test
public void givenMultipleCustomParsingLambdas_parsesPOJOs() {
PCollection<String> records =
csvRecords(
pipeline,
"instant,instantList",
"2024-01-23@10:00:05,10-00-05-2024-01-23;12-59-59-2024-01-24");
TimeContaining want =
timeContaining(
Instant.parse("2024-01-23T10:00:05.000Z"),
Arrays.asList(
Instant.parse("2024-01-23T10:00:05.000Z"),
Instant.parse("2024-01-24T12:59:59.000Z")));

CsvIOParse<TimeContaining> underTest =
underTest(
TIME_CONTAINING_SCHEMA,
CSVFormat.DEFAULT
.withHeader("instant", "instantList")
.withAllowDuplicateHeaderNames(false),
new HashMap<>(),
timeContainingFromRowFn(),
TIME_CONTAINING_CODER)
.withCustomRecordParsing(
"instant",
input ->
DateTimeFormat.forPattern("yyyy-MM-dd@HH:mm:ss")
.parseDateTime(input)
.toInstant())
.withCustomRecordParsing("instantList", instantListParsingLambda());

CsvIOParseResult<TimeContaining> result = records.apply(underTest);
PAssert.that(result.getOutput()).containsInAnyOrder(want);
PAssert.that(result.getErrors()).empty();

pipeline.run();
}

@Test
public void givenCustomParsingError_emits() {
PCollection<String> records =
csvRecords(pipeline, "instant,instantList", "2024-01-23T10:00:05.000Z,BAD CELL");
CsvIOParse<TimeContaining> underTest =
underTest(
TIME_CONTAINING_SCHEMA,
CSVFormat.DEFAULT
.withHeader("instant", "instantList")
.withAllowDuplicateHeaderNames(false),
new HashMap<>(),
timeContainingFromRowFn(),
TIME_CONTAINING_CODER)
.withCustomRecordParsing("instantList", instantListParsingLambda());

CsvIOParseResult<TimeContaining> result = records.apply(underTest);
PAssert.that(result.getOutput()).empty();
PAssert.thatSingleton(result.getErrors().apply(Count.globally())).isEqualTo(1L);

pipeline.run();
}

private static CSVFormat csvFormat() {
return CSVFormat.DEFAULT
.withAllowDuplicateHeaderNames(false)
Expand Down Expand Up @@ -191,7 +301,16 @@ private static <T> CsvIOParse<T> underTest(
return CsvIOParse.<T>builder().setConfigBuilder(configBuilder).build();
}

private static Map<String, SerializableFunction<String, Object>> emptyCustomProcessingMap() {
return new HashMap<>();
private static SerializableFunction<String, List<Instant>> instantListParsingLambda() {
return input -> {
Iterable<String> cells = Splitter.on(';').split(input);
;
List<Instant> output = new ArrayList<>();
for (String cell : cells) {
output.add(
DateTimeFormat.forPattern("HH-mm-ss-yyyy-MM-dd").parseDateTime(cell).toInstant());
}
return output;
};
}
}

0 comments on commit fc5a71d

Please sign in to comment.