diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java index 9230d9e26013..45afb7bb299a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java @@ -32,10 +32,14 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.lenientFormat; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.hive.HiveStorageFormat.AVRO; import static io.trino.plugin.hive.HiveStorageFormat.ORC; import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_FIELD_PREFIX; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_NAME; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_TYPE; import static io.trino.plugin.hive.util.HiveTypeTranslator.fromPrimitiveType; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeInfo; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeSignature; @@ -219,13 +223,32 @@ public Optional getHiveTypeForDereferences(List dereferences) { TypeInfo typeInfo = getTypeInfo(); for (int fieldIndex : dereferences) { - checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo); - StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; - try { - typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + if (typeInfo instanceof StructTypeInfo structTypeInfo) { + try { + typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + } + catch (RuntimeException e) { + // return empty when failed to dereference, this could happen when partition and table schema mismatch + return Optional.empty(); + } } - catch (RuntimeException e) { - return Optional.empty(); + else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) { + try { + if (fieldIndex == 0) { + // union's tag field, defined in {@link io.trino.plugin.hive.util.HiveTypeTranslator#toTypeSignature} + return Optional.of(HiveType.toHiveType(UNION_FIELD_TAG_TYPE)); + } + else { + typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1); + } + } + catch (RuntimeException e) { + // return empty when failed to dereference, this could happen when partition and table schema mismatch + return Optional.empty(); + } + } + else { + throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo)); } } return Optional.of(toHiveType(typeInfo)); @@ -235,16 +258,35 @@ public List getHiveDereferenceNames(List dereferences) { ImmutableList.Builder dereferenceNames = ImmutableList.builder(); TypeInfo typeInfo = getTypeInfo(); - for (int fieldIndex : dereferences) { - checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo); - StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; - + for (int i = 0; i < dereferences.size(); i++) { + int fieldIndex = dereferences.get(i); checkArgument(fieldIndex >= 0, "fieldIndex cannot be negative"); - checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(), - "fieldIndex should be less than the number of fields in the struct"); - String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); - dereferenceNames.add(fieldName); - typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + + if (typeInfo instanceof StructTypeInfo structTypeInfo) { + checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(), + "fieldIndex should be less than the number of fields in the struct"); + + String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); + dereferenceNames.add(fieldName); + typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + } + else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) { + checkArgument((fieldIndex - 1) < unionTypeInfo.getAllUnionObjectTypeInfos().size(), + "fieldIndex should be less than the number of fields in the union plus tag field"); + + if (fieldIndex == 0) { + checkArgument(i == (dereferences.size() - 1), "Union's tag field should not have more subfields"); + dereferenceNames.add(UNION_FIELD_TAG_NAME); + break; + } + else { + typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1); + dereferenceNames.add(UNION_FIELD_FIELD_PREFIX + (fieldIndex - 1)); + } + } + else { + throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo)); + } } return dereferenceNames.build(); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java index 22d1e5f4ce2d..4d165511130a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java @@ -91,6 +91,10 @@ public final class HiveTypeTranslator { private HiveTypeTranslator() {} + public static final String UNION_FIELD_TAG_NAME = "tag"; + public static final String UNION_FIELD_FIELD_PREFIX = "field"; + public static final Type UNION_FIELD_TAG_TYPE = TINYINT; + public static TypeInfo toTypeInfo(Type type) { requireNonNull(type, "type is null"); @@ -213,10 +217,10 @@ public static TypeSignature toTypeSignature(TypeInfo typeInfo, HiveTimestampPrec UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; List unionObjectTypes = unionTypeInfo.getAllUnionObjectTypeInfos(); ImmutableList.Builder typeSignatures = ImmutableList.builder(); - typeSignatures.add(namedField("tag", TINYINT.getTypeSignature())); + typeSignatures.add(namedField(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE.getTypeSignature())); for (int i = 0; i < unionObjectTypes.size(); i++) { TypeInfo unionObjectType = unionObjectTypes.get(i); - typeSignatures.add(namedField("field" + i, toTypeSignature(unionObjectType, timestampPrecision))); + typeSignatures.add(namedField(UNION_FIELD_FIELD_PREFIX + i, toTypeSignature(unionObjectType, timestampPrecision))); } return rowType(typeSignatures.build()); } diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java index a7ec94a51eb3..062ae91d9c3b 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.List; +import static io.trino.testing.TestingNames.randomNameSuffix; import static io.trino.tests.product.TestGroups.SMOKE; import static io.trino.tests.product.utils.QueryExecutors.onHive; import static io.trino.tests.product.utils.QueryExecutors.onTrino; @@ -51,6 +52,87 @@ public static Object[][] storageFormats() return new String[][] {{"ORC"}, {"AVRO"}}; } + @DataProvider(name = "union_dereference_test_cases") + public static Object[][] unionDereferenceTestCases() + { + String tableUnionDereference = "test_union_dereference" + randomNameSuffix(); + // Hive insertion for union type in AVRO format has bugs, so we test on different table schemas for AVRO than ORC. + return new Object[][] {{ + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE<" + + "INT, STRING>)" + + "STORED AS %s", + tableUnionDereference, + "AVRO"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(0, 321, 'row1') " + + "UNION ALL " + + "SELECT create_union(1, 55, 'row2') ", + tableUnionDereference), + format("SELECT unionLevel0.field0 FROM %s WHERE unionLevel0.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList(321), + format("SELECT unionLevel0.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + // there is an internal issue in Hive 1.2: + // unionLevel1 is declared as unionType, but has to be inserted by create_union(tagId, Int, String) + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE>>, intLevel0 INT )" + + "STORED AS %s", + tableUnionDereference, + "AVRO"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 5, 'testString'))), 8 " + + "UNION ALL " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 5, 'testString'))), 8 ", + tableUnionDereference), + format("SELECT unionLevel0.field2.unionLevel1.field1 FROM %s WHERE unionLevel0.field2.unionLevel1.field1 IS NOT NULL", tableUnionDereference), + Arrays.asList(5), + format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE<" + + "STRUCT>>)" + + "STORED AS %s", + tableUnionDereference, + "ORC"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(0, named_struct('unionLevel1', create_union(0, 'testString1', 23))) " + + "UNION ALL " + + "SELECT create_union(0, named_struct('unionLevel1', create_union(1, 'testString2', 45))) ", + tableUnionDereference), + format("SELECT unionLevel0.field0.unionLevel1.field0 FROM %s WHERE unionLevel0.field0.unionLevel1.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList("testString1"), + format("SELECT unionLevel0.field0.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE>>, intLevel0 INT )" + + "STORED AS %s", + tableUnionDereference, + "ORC"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 'testString', 5))), 8 " + + "UNION ALL " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 'testString', 5))), 8 ", + tableUnionDereference), + format("SELECT unionLevel0.field2.unionLevel1.field0 FROM %s WHERE unionLevel0.field2.unionLevel1.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList("testString"), + format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}}; + } + @Test(dataProvider = "storage_formats", groups = SMOKE) public void testReadUniontype(String storageFormat) { @@ -137,6 +219,25 @@ public void testReadUniontype(String storageFormat) } } + @Test(dataProvider = "union_dereference_test_cases", groups = SMOKE) + public void testReadUniontypeWithDereference(String createTableSql, String insertSql, String selectSql, List expectedResult, String selectTagSql, List expectedTagResult, String dropTableSql) + { + // According to testing results, the Hive INSERT queries here only work in Hive 1.2 + if (getHiveVersionMajor() != 1 || getHiveVersionMinor() != 2) { + throw new SkipException("This test can only be run with Hive 1.2 (default config)"); + } + + onHive().executeQuery(createTableSql); + onHive().executeQuery(insertSql); + + QueryResult result = onTrino().executeQuery(selectSql); + assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedResult); + result = onTrino().executeQuery(selectTagSql); + assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedTagResult); + + onTrino().executeQuery(dropTableSql); + } + @Test(dataProvider = "storage_formats", groups = SMOKE) public void testUnionTypeSchemaEvolution(String storageFormat) {