Skip to content

Commit

Permalink
Fix dereference operations for union type in Hive Connector
Browse files Browse the repository at this point in the history
  • Loading branch information
leetcode-1533 committed Dec 9, 2022
1 parent 864d567 commit d97ad7d
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 17 deletions.
72 changes: 57 additions & 15 deletions plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,14 @@
import java.util.Optional;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.lenientFormat;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.plugin.hive.HiveStorageFormat.AVRO;
import static io.trino.plugin.hive.HiveStorageFormat.ORC;
import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION;
import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_FIELD_PREFIX;
import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_NAME;
import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_TYPE;
import static io.trino.plugin.hive.util.HiveTypeTranslator.fromPrimitiveType;
import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeInfo;
import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeSignature;
Expand Down Expand Up @@ -219,13 +223,32 @@ public Optional<HiveType> getHiveTypeForDereferences(List<Integer> dereferences)
{
TypeInfo typeInfo = getTypeInfo();
for (int fieldIndex : dereferences) {
checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo);
StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;
try {
typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
if (typeInfo instanceof StructTypeInfo structTypeInfo) {
try {
typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
}
catch (RuntimeException e) {
// return empty when failed to dereference, this could happen when partition and table schema mismatch
return Optional.empty();
}
}
catch (RuntimeException e) {
return Optional.empty();
else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) {
try {
if (fieldIndex == 0) {
// union's tag field, defined in {@link io.trino.plugin.hive.util.HiveTypeTranslator#toTypeSignature}
return Optional.of(HiveType.toHiveType(UNION_FIELD_TAG_TYPE));
}
else {
typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1);
}
}
catch (RuntimeException e) {
// return empty when failed to dereference, this could happen when partition and table schema mismatch
return Optional.empty();
}
}
else {
throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo));
}
}
return Optional.of(toHiveType(typeInfo));
Expand All @@ -235,16 +258,35 @@ public List<String> getHiveDereferenceNames(List<Integer> dereferences)
{
ImmutableList.Builder<String> dereferenceNames = ImmutableList.builder();
TypeInfo typeInfo = getTypeInfo();
for (int fieldIndex : dereferences) {
checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo);
StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo;

for (int i = 0; i < dereferences.size(); i++) {
int fieldIndex = dereferences.get(i);
checkArgument(fieldIndex >= 0, "fieldIndex cannot be negative");
checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(),
"fieldIndex should be less than the number of fields in the struct");
String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
dereferenceNames.add(fieldName);
typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);

if (typeInfo instanceof StructTypeInfo structTypeInfo) {
checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(),
"fieldIndex should be less than the number of fields in the struct");

String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex);
dereferenceNames.add(fieldName);
typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex);
}
else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) {
checkArgument((fieldIndex - 1) < unionTypeInfo.getAllUnionObjectTypeInfos().size(),
"fieldIndex should be less than the number of fields in the union plus tag field");

if (fieldIndex == 0) {
checkArgument(i == (dereferences.size() - 1), "Union's tag field should not have more subfields");
dereferenceNames.add(UNION_FIELD_TAG_NAME);
break;
}
else {
typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1);
dereferenceNames.add(UNION_FIELD_FIELD_PREFIX + (fieldIndex - 1));
}
}
else {
throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo));
}
}

return dereferenceNames.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ public final class HiveTypeTranslator
{
private HiveTypeTranslator() {}

// Name of the first field of the row type that toTypeSignature produces for a Hive union type.
public static final String UNION_FIELD_TAG_NAME = "tag";
// Prefix of the remaining row fields; toTypeSignature appends the zero-based index of the union member ("field0", "field1", ...).
public static final String UNION_FIELD_FIELD_PREFIX = "field";
// Trino type of the tag field in the row type that toTypeSignature produces for a Hive union type.
public static final Type UNION_FIELD_TAG_TYPE = TINYINT;

public static TypeInfo toTypeInfo(Type type)
{
requireNonNull(type, "type is null");
Expand Down Expand Up @@ -213,10 +217,10 @@ public static TypeSignature toTypeSignature(TypeInfo typeInfo, HiveTimestampPrec
UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo;
List<TypeInfo> unionObjectTypes = unionTypeInfo.getAllUnionObjectTypeInfos();
ImmutableList.Builder<TypeSignatureParameter> typeSignatures = ImmutableList.builder();
typeSignatures.add(namedField("tag", TINYINT.getTypeSignature()));
typeSignatures.add(namedField(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE.getTypeSignature()));
for (int i = 0; i < unionObjectTypes.size(); i++) {
TypeInfo unionObjectType = unionObjectTypes.get(i);
typeSignatures.add(namedField("field" + i, toTypeSignature(unionObjectType, timestampPrecision)));
typeSignatures.add(namedField(UNION_FIELD_FIELD_PREFIX + i, toTypeSignature(unionObjectType, timestampPrecision)));
}
return rowType(typeSignatures.build());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import java.util.Arrays;
import java.util.List;

import static io.trino.testing.TestingNames.randomNameSuffix;
import static io.trino.tests.product.TestGroups.SMOKE;
import static io.trino.tests.product.utils.QueryExecutors.onHive;
import static io.trino.tests.product.utils.QueryExecutors.onTrino;
Expand Down Expand Up @@ -51,6 +52,87 @@ public static Object[][] storageFormats()
return new String[][] {{"ORC"}, {"AVRO"}};
}

/**
 * Supplies the rows for the "union_dereference_test_cases" data provider.
 * <p>
 * Every row holds seven elements, in this order:
 * <ol>
 * <li>CREATE TABLE statement</li>
 * <li>INSERT statement</li>
 * <li>SELECT statement that dereferences a member field of a union</li>
 * <li>expected values of the first column of that SELECT</li>
 * <li>SELECT statement that dereferences the union's {@code tag} field</li>
 * <li>expected values of the first column of the tag SELECT</li>
 * <li>DROP TABLE IF EXISTS statement</li>
 * </ol>
 * The first two rows create AVRO tables, the last two create ORC tables.
 * <p>
 * NOTE: the table name is generated once per invocation of this method and is shared by all
 * rows, so a row whose table was not dropped makes the CREATE TABLE of the following rows collide.
 */
@DataProvider(name = "union_dereference_test_cases")
public static Object[][] unionDereferenceTestCases()
{
// single random table name reused by every row below
String tableUnionDereference = "test_union_dereference" + randomNameSuffix();
// Hive insertion for union type in AVRO format has bugs, so we test on different table schemas for AVRO than ORC.
return new Object[][] {{
// AVRO: top-level union of two primitives
format(
"CREATE TABLE %s (unionLevel0 UNIONTYPE<" +
"INT, STRING>)" +
"STORED AS %s",
tableUnionDereference,
"AVRO"),
format(
"INSERT INTO TABLE %s " +
"SELECT create_union(0, 321, 'row1') " +
"UNION ALL " +
"SELECT create_union(1, 55, 'row2') ",
tableUnionDereference),
format("SELECT unionLevel0.field0 FROM %s WHERE unionLevel0.field0 IS NOT NULL", tableUnionDereference),
Arrays.asList(321),
format("SELECT unionLevel0.tag FROM %s", tableUnionDereference),
Arrays.asList((byte) 0, (byte) 1),
"DROP TABLE IF EXISTS " + tableUnionDereference},
// AVRO: union nested inside a struct that is itself a union member
// there is an internal issue in Hive 1.2:
// unionLevel1 is declared as unionType<String, Int>, but has to be inserted by create_union(tagId, Int, String)
{
format(
"CREATE TABLE %s (unionLevel0 UNIONTYPE<INT, STRING," +
"STRUCT<intLevel1:INT, stringLevel1:STRING, unionLevel1:UNIONTYPE<STRING, INT>>>, intLevel0 INT )" +
"STORED AS %s",
tableUnionDereference,
"AVRO"),
format(
"INSERT INTO TABLE %s " +
"SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 5, 'testString'))), 8 " +
"UNION ALL " +
"SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 5, 'testString'))), 8 ",
tableUnionDereference),
format("SELECT unionLevel0.field2.unionLevel1.field1 FROM %s WHERE unionLevel0.field2.unionLevel1.field1 IS NOT NULL", tableUnionDereference),
Arrays.asList(5),
format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference),
Arrays.asList((byte) 0, (byte) 1),
"DROP TABLE IF EXISTS " + tableUnionDereference},
// ORC: union whose only member is a struct containing another union
{
format(
"CREATE TABLE %s (unionLevel0 UNIONTYPE<" +
"STRUCT<unionLevel1:UNIONTYPE<STRING, INT>>>)" +
"STORED AS %s",
tableUnionDereference,
"ORC"),
format(
"INSERT INTO TABLE %s " +
"SELECT create_union(0, named_struct('unionLevel1', create_union(0, 'testString1', 23))) " +
"UNION ALL " +
"SELECT create_union(0, named_struct('unionLevel1', create_union(1, 'testString2', 45))) ",
tableUnionDereference),
format("SELECT unionLevel0.field0.unionLevel1.field0 FROM %s WHERE unionLevel0.field0.unionLevel1.field0 IS NOT NULL", tableUnionDereference),
Arrays.asList("testString1"),
format("SELECT unionLevel0.field0.unionLevel1.tag FROM %s", tableUnionDereference),
Arrays.asList((byte) 0, (byte) 1),
"DROP TABLE IF EXISTS " + tableUnionDereference},
// ORC: same table schema as the second AVRO row, with create_union arguments in declared order
{
format(
"CREATE TABLE %s (unionLevel0 UNIONTYPE<INT, STRING," +
"STRUCT<intLevel1:INT, stringLevel1:STRING, unionLevel1:UNIONTYPE<STRING, INT>>>, intLevel0 INT )" +
"STORED AS %s",
tableUnionDereference,
"ORC"),
format(
"INSERT INTO TABLE %s " +
"SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 'testString', 5))), 8 " +
"UNION ALL " +
"SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 'testString', 5))), 8 ",
tableUnionDereference),
format("SELECT unionLevel0.field2.unionLevel1.field0 FROM %s WHERE unionLevel0.field2.unionLevel1.field0 IS NOT NULL", tableUnionDereference),
Arrays.asList("testString"),
format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference),
Arrays.asList((byte) 0, (byte) 1),
"DROP TABLE IF EXISTS " + tableUnionDereference}};
}

@Test(dataProvider = "storage_formats", groups = SMOKE)
public void testReadUniontype(String storageFormat)
{
Expand Down Expand Up @@ -137,6 +219,25 @@ public void testReadUniontype(String storageFormat)
}
}

/**
 * Creates and populates a table through Hive, then checks through Trino that dereferencing
 * a union member field and the union's tag field returns the expected values.
 * <p>
 * The cleanup statement runs in a {@code finally} block so that a failed insert or assertion
 * does not leave the table behind; the data provider reuses one table name for all of its
 * rows, so a leaked table would make the following rows fail on CREATE TABLE.
 *
 * @param createTableSql Hive statement that creates the test table
 * @param insertSql Hive statement that populates the test table
 * @param selectSql Trino query dereferencing a union member field
 * @param expectedResult expected values of the first column of {@code selectSql}, in any order
 * @param selectTagSql Trino query dereferencing the union's tag field
 * @param expectedTagResult expected values of the first column of {@code selectTagSql}, in any order
 * @param dropTableSql Trino statement that removes the test table
 * @throws SkipException when the Hive version is not 1.2
 */
@Test(dataProvider = "union_dereference_test_cases", groups = SMOKE)
public void testReadUniontypeWithDereference(String createTableSql, String insertSql, String selectSql, List<Object> expectedResult, String selectTagSql, List<Object> expectedTagResult, String dropTableSql)
{
    // According to testing results, the Hive INSERT queries here only work in Hive 1.2
    if (getHiveVersionMajor() != 1 || getHiveVersionMinor() != 2) {
        throw new SkipException("This test can only be run with Hive 1.2 (default config)");
    }

    // Kept outside the try block: if creation itself fails there is nothing to clean up,
    // and a failing cleanup must not mask the original error.
    onHive().executeQuery(createTableSql);
    try {
        onHive().executeQuery(insertSql);

        QueryResult result = onTrino().executeQuery(selectSql);
        assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedResult);
        result = onTrino().executeQuery(selectTagSql);
        assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedTagResult);
    }
    finally {
        // Always remove the table, even when the insert or an assertion above failed
        onTrino().executeQuery(dropTableSql);
    }
}

@Test(dataProvider = "storage_formats", groups = SMOKE)
public void testUnionTypeSchemaEvolution(String storageFormat)
{
Expand Down

0 comments on commit d97ad7d

Please sign in to comment.