From 3efdb05450f27c7927880c9cc7b5d35fcf0f881d Mon Sep 17 00:00:00 2001
From: Yuya Ebihara
Date: Tue, 10 Jan 2023 19:14:02 +0900
Subject: [PATCH] Support reading float as double in ORC and Parquet for
 Iceberg type change

---
 .../trino/orc/reader/FloatColumnReader.java  | 39 ++++++++++++++---
 .../parquet/reader/FloatColumnReader.java    | 12 +++++-
 .../TestIcebergSparkCompatibility.java       | 43 +++++++++++++++++++
 3 files changed, 86 insertions(+), 8 deletions(-)

diff --git a/lib/trino-orc/src/main/java/io/trino/orc/reader/FloatColumnReader.java b/lib/trino-orc/src/main/java/io/trino/orc/reader/FloatColumnReader.java
index 9d3895b051aa..a2173625feaa 100644
--- a/lib/trino-orc/src/main/java/io/trino/orc/reader/FloatColumnReader.java
+++ b/lib/trino-orc/src/main/java/io/trino/orc/reader/FloatColumnReader.java
@@ -24,8 +24,8 @@
 import io.trino.orc.stream.InputStreamSources;
 import io.trino.spi.block.Block;
 import io.trino.spi.block.IntArrayBlock;
+import io.trino.spi.block.LongArrayBlock;
 import io.trino.spi.block.RunLengthEncodedBlock;
-import io.trino.spi.type.RealType;
 import io.trino.spi.type.Type;
 import org.openjdk.jol.info.ClassLayout;
 
@@ -43,7 +43,10 @@
 import static io.trino.orc.reader.ReaderUtils.minNonNullValueSize;
 import static io.trino.orc.reader.ReaderUtils.verifyStreamType;
 import static io.trino.orc.stream.MissingInputStreamSource.missingStreamSource;
+import static io.trino.spi.type.DoubleType.DOUBLE;
 import static io.trino.spi.type.RealType.REAL;
+import static java.lang.Double.doubleToRawLongBits;
+import static java.lang.Float.intBitsToFloat;
 import static java.lang.Math.toIntExact;
 import static java.util.Objects.requireNonNull;
 
@@ -52,6 +55,8 @@ public class FloatColumnReader
 {
     private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(FloatColumnReader.class).instanceSize());
 
+    private final Type type;
+
     private final OrcColumn column;
 
     private int readOffset;
@@ -74,8 +79,8 @@ public class FloatColumnReader
     public FloatColumnReader(Type type, OrcColumn column, LocalMemoryContext memoryContext)
             throws OrcCorruptionException
    {
-        requireNonNull(type, "type is null");
-        verifyStreamType(column, type, RealType.class::isInstance);
+        this.type = requireNonNull(type, "type is null");
+        verifyStreamType(column, type, t -> t == REAL || t == DOUBLE);
 
         this.column = requireNonNull(column, "column is null");
         this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
@@ -116,7 +121,7 @@ public Block readBlock()
                 throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing");
             }
             presentStream.skip(nextBatchSize);
-            block = RunLengthEncodedBlock.create(REAL, null, nextBatchSize);
+            block = RunLengthEncodedBlock.create(type, null, nextBatchSize);
         }
         else if (presentStream == null) {
             block = readNonNullBlock();
@@ -131,7 +136,7 @@ else if (nullCount != nextBatchSize) {
                 block = readNullBlock(isNull, nextBatchSize - nullCount);
             }
             else {
-                block = RunLengthEncodedBlock.create(REAL, null, nextBatchSize);
+                block = RunLengthEncodedBlock.create(type, null, nextBatchSize);
             }
         }
 
@@ -147,7 +152,13 @@ private Block readNonNullBlock()
         verifyNotNull(dataStream);
         int[] values = new int[nextBatchSize];
         dataStream.next(values, nextBatchSize);
-        return new IntArrayBlock(nextBatchSize, Optional.empty(), values);
+        if (type == REAL) {
+            return new IntArrayBlock(nextBatchSize, Optional.empty(), values);
+        }
+        if (type == DOUBLE) {
+            return new LongArrayBlock(nextBatchSize, Optional.empty(), convertToLongArray(values));
+        }
+        throw new VerifyError("Unsupported type " + type);
     }
 
     private Block readNullBlock(boolean[] isNull, int nonNullCount)
@@ -163,8 +174,22 @@ private Block readNullBlock(boolean[] isNull, int nonNullCount)
         dataStream.next(nonNullValueTemp, nonNullCount);
 
         int[] result = ReaderUtils.unpackIntNulls(nonNullValueTemp, isNull);
+        if (type == REAL) {
+            return new IntArrayBlock(isNull.length, Optional.of(isNull), result);
+        }
+        if (type == DOUBLE) {
+            return new LongArrayBlock(isNull.length, Optional.of(isNull), convertToLongArray(result));
+        }
+        throw new VerifyError("Unsupported type " + type);
+    }
 
-        return new IntArrayBlock(isNull.length, Optional.of(isNull), result);
+    private static long[] convertToLongArray(int[] intValues)
+    {
+        long[] values = new long[intValues.length];
+        for (int i = 0; i < intValues.length; i++) {
+            values[i] = doubleToRawLongBits(intBitsToFloat(intValues[i]));
+        }
+        return values;
     }
 
     private void openRowGroup()
diff --git a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/FloatColumnReader.java b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/FloatColumnReader.java
index 5dbac86d2805..e493087da027 100644
--- a/lib/trino-parquet/src/main/java/io/trino/parquet/reader/FloatColumnReader.java
+++ b/lib/trino-parquet/src/main/java/io/trino/parquet/reader/FloatColumnReader.java
@@ -17,6 +17,8 @@
 import io.trino.spi.block.BlockBuilder;
 import io.trino.spi.type.Type;
 
+import static io.trino.spi.type.DoubleType.DOUBLE;
+import static io.trino.spi.type.RealType.REAL;
 import static java.lang.Float.floatToRawIntBits;
 
 public class FloatColumnReader
@@ -30,6 +32,14 @@ public FloatColumnReader(PrimitiveField field)
     @Override
     protected void readValue(BlockBuilder blockBuilder, Type type)
     {
-        type.writeLong(blockBuilder, floatToRawIntBits(valuesReader.readFloat()));
+        if (type == REAL) {
+            type.writeLong(blockBuilder, floatToRawIntBits(valuesReader.readFloat()));
+        }
+        else if (type == DOUBLE) {
+            type.writeDouble(blockBuilder, valuesReader.readFloat());
+        }
+        else {
+            throw new VerifyError("Unsupported type " + type);
+        }
     }
 }
diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
index 3be74d340ca9..c944c1d3e11e 100644
--- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
+++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
@@ -2459,6 +2459,49 @@ public void testRegisterTableWithMetadataFile(StorageFormat storageFormat)
         onTrino().executeQuery(format("DROP TABLE %s", trinoTableName));
     }
 
+    @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "testSparkAlterColumnType")
+    public void testSparkAlterColumnType(StorageFormat storageFormat, String sourceColumnType, String sourceValueLiteral, String newColumnType, Object newValue)
+    {
+        String baseTableName = "test_spark_alter_column_type_" + randomNameSuffix();
+        String trinoTableName = trinoTableName(baseTableName);
+        String sparkTableName = sparkTableName(baseTableName);
+
+        onSpark().executeQuery("CREATE TABLE " + sparkTableName + " TBLPROPERTIES ('write.format.default' = '" + storageFormat + "') " +
+                "AS SELECT CAST(" + sourceValueLiteral + " AS " + sourceColumnType + ") AS col");
+
+        onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ALTER COLUMN col TYPE " + newColumnType);
+
+        assertEquals(getColumnType(baseTableName, "col"), newColumnType);
+        assertThat(onSpark().executeQuery("SELECT * FROM " + sparkTableName)).containsOnly(row(newValue));
+        assertThat(onTrino().executeQuery("SELECT * FROM " + trinoTableName)).containsOnly(row(newValue));
+
+        onSpark().executeQuery("DROP TABLE " + sparkTableName);
+    }
+
+    @DataProvider
+    public static Object[][] testSparkAlterColumnType()
+    {
+        Object[][] alterColumnTypeData = new Object[][] {
+                {"integer", "2147483647", "bigint", 2147483647L},
+                {"float", "10.3", "double", 10.3},
+                {"float", "'NaN'", "double", Double.NaN},
+                {"decimal(5,3)", "'12.345'", "decimal(10,3)", BigDecimal.valueOf(12.345)}
+        };
+
+        return Stream.of(StorageFormat.values())
+                .flatMap(storageFormat -> Arrays.stream(alterColumnTypeData).map(data -> new Object[] {storageFormat, data[0], data[1], data[2], data[3]}))
+                .toArray(Object[][]::new);
+    }
+
+    private String getColumnType(String tableName, String columnName)
+    {
+        return (String) onTrino().executeQuery("SELECT data_type FROM " + TRINO_CATALOG + ".information_schema.columns " +
+                "WHERE table_schema = '" + TEST_SCHEMA_NAME + "' AND " +
+                "table_name = '" + tableName + "' AND " +
+                "column_name = '" + columnName + "'")
+                .getOnlyValue();
+    }
+
     private int calculateMetadataFilesForPartitionedTable(String tableName)
     {
         String dataFilePath = (String) onTrino().executeQuery(format("SELECT file_path FROM iceberg.default.\"%s$files\" limit 1", tableName)).getOnlyValue();