[Bug][Connector][FileBase] Parquet reader parsing array type exception #4457

Merged · 14 commits · Nov 28, 2023
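Root cause, in brief: parquet-avro decodes string-typed array elements as org.apache.avro.util.Utf8 rather than java.lang.String, so copying the elements verbatim leaves Utf8 objects in the resolved array. Below is a minimal standalone sketch of that behavior and of the conversion this PR applies; the demo class name is hypothetical, only the Avro behavior shown is from the PR.

```java
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.util.Utf8;

import java.util.ArrayList;
import java.util.List;

// Hypothetical demo class illustrating the Avro behavior behind the bug.
public class Utf8ArrayDemo {
    public static void main(String[] args) {
        Schema arraySchema =
                new Schema.Parser().parse("{\"type\":\"array\",\"items\":\"string\"}");
        GenericData.Array<Object> avroArray = new GenericData.Array<>(2, arraySchema);
        avroArray.add(new Utf8("Java"));
        avroArray.add(new Utf8("Python"));

        // Pre-fix: elements copied as-is stay Utf8, not String.
        System.out.println(avroArray.get(0) instanceof Utf8); // true

        // Post-fix: convert each Utf8 element via toString(), as resolveObject now does.
        List<Object> converted = new ArrayList<>();
        avroArray.forEach(ele ->
                converted.add(ele instanceof Utf8 ? ele.toString() : ele));
        System.out.println(converted.get(0) instanceof String); // true
    }
}
```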
ParquetReadStrategy.java
@@ -37,6 +37,7 @@
import org.apache.avro.data.TimeConversions;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
@@ -132,7 +133,16 @@ private Object resolveObject(Object field, SeaTunnelDataType<?> fieldType) {
switch (fieldType.getSqlType()) {
case ARRAY:
ArrayList<Object> origArray = new ArrayList<>();
-((GenericData.Array<?>) field).iterator().forEachRemaining(origArray::add);
+((GenericData.Array<?>) field)
+        .iterator()
+        .forEachRemaining(
+                ele -> {
+                    if (ele instanceof Utf8) {
+                        origArray.add(ele.toString());
+                    } else {
+                        origArray.add(ele);
+                    }
+                });
SeaTunnelDataType<?> elementType = ((ArrayType<?, ?>) fieldType).getElementType();
switch (elementType.getSqlType()) {
case STRING:
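The hunk above is truncated before the per-type branches, but each branch presumably materializes origArray as a typed Java array (the new test below casts the field to String[]). Under that assumption, here is a sketch of why the verbatim copy failed: List.toArray(new String[0]) throws ArrayStoreException when the list still contains Utf8 elements, which would surface as the reported array-parsing exception. The helper and class names are hypothetical, not the exact SeaTunnel code.

```java
import org.apache.avro.util.Utf8;

import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch mirroring what a STRING branch must do.
public class TypedArraySketch {
    static String[] toStringArray(List<Object> origArray) {
        // Throws ArrayStoreException if any element is not a java.lang.String.
        return origArray.toArray(new String[0]);
    }

    public static void main(String[] args) {
        List<Object> preFix = new ArrayList<>();
        preFix.add(new Utf8("Java")); // what the list held before the fix
        try {
            toStringArray(preFix);
        } catch (ArrayStoreException e) {
            System.out.println("pre-fix copy fails: " + e);
        }

        List<Object> postFix = new ArrayList<>();
        postFix.add("Java"); // after the Utf8 -> String conversion
        System.out.println("post-fix copy works: " + toStringArray(postFix)[0]);
    }
}
```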
ParquetReadStrategyTest.java
@@ -21,15 +21,30 @@
import org.apache.seatunnel.shade.com.typesafe.config.ConfigFactory;

import org.apache.seatunnel.api.source.Collector;
import org.apache.seatunnel.api.table.type.ArrayType;
import org.apache.seatunnel.api.table.type.SeaTunnelRow;
import org.apache.seatunnel.api.table.type.SeaTunnelRowType;
import org.apache.seatunnel.connectors.seatunnel.file.config.HadoopConf;
import org.apache.seatunnel.connectors.seatunnel.file.source.reader.ParquetReadStrategy;

import org.apache.avro.Schema;
import org.apache.avro.generic.GenericArray;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.util.Utf8;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.avro.AvroParquetWriter;
import org.apache.parquet.hadoop.ParquetWriter;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;

import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledOnOs;
import org.junit.jupiter.api.condition.OS;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.nio.file.Paths;
import java.time.LocalDateTime;
@@ -154,6 +169,29 @@ public void testParquetReadProjection2() throws Exception {
parquetReadStrategy.read(path, "", testCollector);
}

@DisabledOnOs(OS.WINDOWS)
@Test
public void testParquetReadArray() throws Exception {
AutoGenerateParquetData.generateTestData();
ParquetReadStrategy parquetReadStrategy = new ParquetReadStrategy();
LocalConf localConf = new LocalConf(FS_DEFAULT_NAME_DEFAULT);
parquetReadStrategy.init(localConf);
SeaTunnelRowType seaTunnelRowTypeInfo =
parquetReadStrategy.getSeaTunnelRowTypeInfo(
localConf, AutoGenerateParquetData.DATA_FILE_PATH);
Assertions.assertNotNull(seaTunnelRowTypeInfo);
Assertions.assertEquals(seaTunnelRowTypeInfo.getFieldType(3).getClass(), ArrayType.class);
TestCollector testCollector = new TestCollector();
parquetReadStrategy.read(AutoGenerateParquetData.DATA_FILE_PATH, "1", testCollector);
List<SeaTunnelRow> rows = testCollector.getRows();
SeaTunnelRow seaTunnelRow = rows.get(0);
Assertions.assertEquals(seaTunnelRow.getField(1).toString(), "Alice");
String[] arrayData = (String[]) seaTunnelRow.getField(3);
Assertions.assertEquals(arrayData.length, 2);
Assertions.assertEquals(arrayData[0], "Java");
AutoGenerateParquetData.deleteFile();
}

public static class TestCollector implements Collector<SeaTunnelRow> {

private final List<SeaTunnelRow> rows = new ArrayList<>();
@@ -192,4 +230,58 @@ public String getSchema() {
return SCHEMA;
}
}

public static class AutoGenerateParquetData {

public static final String DATA_FILE_PATH = "/tmp/data.parquet";

public static void generateTestData() throws IOException {
deleteFile();
String schemaString =
"{\"type\":\"record\",\"name\":\"User\",\"fields\":[{\"name\":\"id\",\"type\":\"int\"},{\"name\":\"name\",\"type\":\"string\"},{\"name\":\"salary\",\"type\":\"double\"},{\"name\":\"skills\",\"type\":{\"type\":\"array\",\"items\":\"string\"}}]}";
Schema schema = new Schema.Parser().parse(schemaString);

Configuration conf = new Configuration();

Path file = new Path(DATA_FILE_PATH);

ParquetWriter<GenericRecord> writer =
AvroParquetWriter.<GenericRecord>builder(file)
.withSchema(schema)
.withConf(conf)
.withCompressionCodec(CompressionCodecName.SNAPPY)
.build();

GenericRecord record1 = new GenericData.Record(schema);
record1.put("id", 1);
record1.put("name", "Alice");
record1.put("salary", 50000.0);
GenericArray<Utf8> skills1 =
new GenericData.Array<>(2, schema.getField("skills").schema());
skills1.add(new Utf8("Java"));
skills1.add(new Utf8("Python"));
record1.put("skills", skills1);
writer.write(record1);

GenericRecord record2 = new GenericData.Record(schema);
record2.put("id", 2);
record2.put("name", "Bob");
record2.put("salary", 60000.0);
GenericArray<Utf8> skills2 =
new GenericData.Array<>(2, schema.getField("skills").schema());
skills2.add(new Utf8("C++"));
skills2.add(new Utf8("Go"));
record2.put("skills", skills2);
writer.write(record2);

writer.close();
}

public static void deleteFile() {
File parquetFile = new File(DATA_FILE_PATH);
if (parquetFile.exists()) {
parquetFile.delete();
}
}
}
}
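For an end-to-end check outside the read strategy, the generated file can be read back with plain parquet-avro, the same library the test already depends on. This snippet is illustrative and not part of the PR; it shows that the array elements come back as Utf8 instances, which is exactly what resolveObject now normalizes.

```java
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
import org.apache.hadoop.fs.Path;
import org.apache.parquet.avro.AvroParquetReader;
import org.apache.parquet.hadoop.ParquetReader;

import java.io.IOException;

public class ReadBackDemo {
    public static void main(String[] args) throws IOException {
        try (ParquetReader<GenericRecord> reader =
                AvroParquetReader.<GenericRecord>builder(new Path("/tmp/data.parquet"))
                        .build()) {
            GenericRecord record;
            while ((record = reader.read()) != null) {
                GenericData.Array<?> skills = (GenericData.Array<?>) record.get("skills");
                // Elements print as text but are Utf8 instances, hence the fix.
                System.out.println(skills.get(0).getClass().getName());
            }
        }
    }
}
```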