Support reading float as double in ORC and Parquet for Iceberg type change
ebyhr committed Jan 16, 2023
1 parent 908483c commit 3efdb05
Showing 3 changed files with 86 additions and 8 deletions.
@@ -24,8 +24,8 @@
import io.trino.orc.stream.InputStreamSources;
import io.trino.spi.block.Block;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.block.LongArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.type.RealType;
import io.trino.spi.type.Type;
import org.openjdk.jol.info.ClassLayout;

@@ -43,7 +43,10 @@
import static io.trino.orc.reader.ReaderUtils.minNonNullValueSize;
import static io.trino.orc.reader.ReaderUtils.verifyStreamType;
import static io.trino.orc.stream.MissingInputStreamSource.missingStreamSource;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.Double.doubleToRawLongBits;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;

@@ -52,6 +55,8 @@ public class FloatColumnReader
{
private static final int INSTANCE_SIZE = toIntExact(ClassLayout.parseClass(FloatColumnReader.class).instanceSize());

private final Type type;

private final OrcColumn column;

private int readOffset;
@@ -74,8 +79,8 @@ public class FloatColumnReader
public FloatColumnReader(Type type, OrcColumn column, LocalMemoryContext memoryContext)
throws OrcCorruptionException
{
requireNonNull(type, "type is null");
verifyStreamType(column, type, RealType.class::isInstance);
this.type = requireNonNull(type, "type is null");
verifyStreamType(column, type, t -> t == REAL || t == DOUBLE);

this.column = requireNonNull(column, "column is null");
this.memoryContext = requireNonNull(memoryContext, "memoryContext is null");
@@ -116,7 +121,7 @@ public Block readBlock()
throw new OrcCorruptionException(column.getOrcDataSourceId(), "Value is null but present stream is missing");
}
presentStream.skip(nextBatchSize);
block = RunLengthEncodedBlock.create(REAL, null, nextBatchSize);
block = RunLengthEncodedBlock.create(type, null, nextBatchSize);
}
else if (presentStream == null) {
block = readNonNullBlock();
@@ -131,7 +136,7 @@ else if (nullCount != nextBatchSize) {
block = readNullBlock(isNull, nextBatchSize - nullCount);
}
else {
block = RunLengthEncodedBlock.create(REAL, null, nextBatchSize);
block = RunLengthEncodedBlock.create(type, null, nextBatchSize);
}
}

@@ -147,7 +152,13 @@ private Block readNonNullBlock()
verifyNotNull(dataStream);
int[] values = new int[nextBatchSize];
dataStream.next(values, nextBatchSize);
return new IntArrayBlock(nextBatchSize, Optional.empty(), values);
if (type == REAL) {
return new IntArrayBlock(nextBatchSize, Optional.empty(), values);
}
if (type == DOUBLE) {
return new LongArrayBlock(nextBatchSize, Optional.empty(), convertToLongArray(values));
}
throw new VerifyError("Unsupported type " + type);
}

private Block readNullBlock(boolean[] isNull, int nonNullCount)
@@ -163,8 +174,22 @@ private Block readNullBlock(boolean[] isNull, int nonNullCount)
dataStream.next(nonNullValueTemp, nonNullCount);

int[] result = ReaderUtils.unpackIntNulls(nonNullValueTemp, isNull);
if (type == REAL) {
return new IntArrayBlock(isNull.length, Optional.of(isNull), result);
}
if (type == DOUBLE) {
return new LongArrayBlock(isNull.length, Optional.of(isNull), convertToLongArray(result));
}
throw new VerifyError("Unsupported type " + type);
}

return new IntArrayBlock(isNull.length, Optional.of(isNull), result);
private static long[] convertToLongArray(int[] intValues)
{
long[] values = new long[intValues.length];
for (int i = 0; i < intValues.length; i++) {
values[i] = doubleToRawLongBits(intBitsToFloat(intValues[i]));
}
return values;
}

private void openRowGroup()
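
The ORC changes above teach FloatColumnReader to honor a declared type of DOUBLE in addition to REAL: the data stream still produces raw float bits in an int[], and when the column is read as DOUBLE each value is reinterpreted as a float, widened, and re-encoded as raw double bits for a LongArrayBlock. A minimal standalone sketch of that widening step in plain Java (no Trino classes; names and sample values are illustrative only):

import static java.lang.Double.doubleToRawLongBits;
import static java.lang.Double.longBitsToDouble;
import static java.lang.Float.floatToRawIntBits;
import static java.lang.Float.intBitsToFloat;

public class FloatBitsWideningDemo
{
    // Same idea as convertToLongArray: int slots holding raw float bits become
    // long slots holding the raw double bits of the widened value.
    static long[] widenFloatBits(int[] floatBits)
    {
        long[] doubleBits = new long[floatBits.length];
        for (int i = 0; i < floatBits.length; i++) {
            doubleBits[i] = doubleToRawLongBits(intBitsToFloat(floatBits[i]));
        }
        return doubleBits;
    }

    public static void main(String[] args)
    {
        // Sample values as they would sit in an IntArrayBlock for a REAL column
        int[] floatBits = {floatToRawIntBits(1.5f), floatToRawIntBits(Float.NaN)};
        for (long bits : widenFloatBits(floatBits)) {
            System.out.println(longBitsToDouble(bits)); // prints 1.5, then NaN
        }
    }
}

Going through intBitsToFloat before doubleToRawLongBits is what makes the widening value-preserving; copying the 32-bit pattern straight into a 64-bit slot would produce meaningless doubles.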
@@ -17,6 +17,8 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.Type;

import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.RealType.REAL;
import static java.lang.Float.floatToRawIntBits;

public class FloatColumnReader
@@ -30,6 +32,14 @@ public FloatColumnReader(PrimitiveField field)
@Override
protected void readValue(BlockBuilder blockBuilder, Type type)
{
type.writeLong(blockBuilder, floatToRawIntBits(valuesReader.readFloat()));
if (type == REAL) {
type.writeLong(blockBuilder, floatToRawIntBits(valuesReader.readFloat()));
}
else if (type == DOUBLE) {
type.writeDouble(blockBuilder, valuesReader.readFloat());
}
else {
throw new VerifyError("Unsupported type " + type);
}
}
}
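
The Parquet FloatColumnReader applies the same widening per value. The branch on the declared type reflects how Trino represents the two types internally: REAL travels as a long containing the float's raw bits (hence writeLong with floatToRawIntBits), while DOUBLE travels as an ordinary Java double, so the float read from the page is simply widened. A small plain-Java sketch of the two encodings (no Trino SPI involved; helper and variable names are illustrative):

import static java.lang.Float.floatToRawIntBits;
import static java.lang.Float.intBitsToFloat;

public class RealVsDoubleEncodingDemo
{
    // REAL: carry the float's raw bit pattern in a long slot
    static long encodeAsReal(float value)
    {
        return floatToRawIntBits(value);
    }

    // DOUBLE: an ordinary widening conversion, no bit tricks needed
    static double encodeAsDouble(float value)
    {
        return value;
    }

    public static void main(String[] args)
    {
        float fromPage = 3.25f; // stand-in for valuesReader.readFloat()

        long asReal = encodeAsReal(fromPage);
        System.out.println(intBitsToFloat((int) asReal)); // 3.25 - the REAL encoding round-trips

        double asDouble = encodeAsDouble(fromPage);
        System.out.println(asDouble);                      // 3.25 - every float value is exactly representable as a double
    }
}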
@@ -2459,6 +2459,49 @@ public void testRegisterTableWithMetadataFile(StorageFormat storageFormat)
onTrino().executeQuery(format("DROP TABLE %s", trinoTableName));
}

@Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "testSparkAlterColumnType")
public void testSparkAlterColumnType(StorageFormat storageFormat, String sourceColumnType, String sourceValueLiteral, String newColumnType, Object newValue)
{
String baseTableName = "test_spark_alter_column_type_" + randomNameSuffix();
String trinoTableName = trinoTableName(baseTableName);
String sparkTableName = sparkTableName(baseTableName);

onSpark().executeQuery("CREATE TABLE " + sparkTableName + " TBLPROPERTIES ('write.format.default' = '" + storageFormat + "') " +
"AS SELECT CAST(" + sourceValueLiteral + " AS " + sourceColumnType + ") AS col");

onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ALTER COLUMN col TYPE " + newColumnType);

assertEquals(getColumnType(baseTableName, "col"), newColumnType);
assertThat(onSpark().executeQuery("SELECT * FROM " + sparkTableName)).containsOnly(row(newValue));
assertThat(onTrino().executeQuery("SELECT * FROM " + trinoTableName)).containsOnly(row(newValue));

onSpark().executeQuery("DROP TABLE " + sparkTableName);
}

@DataProvider
public static Object[][] testSparkAlterColumnType()
{
Object[][] alterColumnTypeData = new Object[][] {
{"integer", "2147483647", "bigint", 2147483647L},
{"float", "10.3", "double", 10.3},
{"float", "'NaN'", "double", Double.NaN},
{"decimal(5,3)", "'12.345'", "decimal(10,3)", BigDecimal.valueOf(12.345)}
};

return Stream.of(StorageFormat.values())
.flatMap(storageFormat -> Arrays.stream(alterColumnTypeData).map(data -> new Object[] {storageFormat, data[0], data[1], data[2], data[3]}))
.toArray(Object[][]::new);
}

private String getColumnType(String tableName, String columnName)
{
return (String) onTrino().executeQuery("SELECT data_type FROM " + TRINO_CATALOG + ".information_schema.columns " +
"WHERE table_schema = '" + TEST_SCHEMA_NAME + "' AND " +
"table_name = '" + tableName + "' AND " +
"column_name = '" + columnName + "'")
.getOnlyValue();
}

private int calculateMetadataFilesForPartitionedTable(String tableName)
{
String dataFilePath = (String) onTrino().executeQuery(format("SELECT file_path FROM iceberg.default.\"%s$files\" limit 1", tableName)).getOnlyValue();
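
One row of the data provider above alters a float column holding NaN, which exercises the least obvious part of the new read path: the widening has to go through the float value rather than the raw bits, so that a 32-bit NaN becomes a valid 64-bit NaN. A quick plain-Java check of that property (class name and values chosen here purely for illustration):

import static java.lang.Double.doubleToRawLongBits;
import static java.lang.Double.longBitsToDouble;
import static java.lang.Float.floatToRawIntBits;
import static java.lang.Float.intBitsToFloat;

public class NanWideningCheck
{
    public static void main(String[] args)
    {
        int nanFloatBits = floatToRawIntBits(Float.NaN); // 0x7fc00000

        // What the readers do: reinterpret as float, widen, then take the double's bits
        double widened = longBitsToDouble(doubleToRawLongBits(intBitsToFloat(nanFloatBits)));
        System.out.println(Double.isNaN(widened)); // true

        // What a naive bit copy would do: the float's NaN bits read as a tiny subnormal double
        double misread = longBitsToDouble(nanFloatBits);
        System.out.println(Double.isNaN(misread)); // false
    }
}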