diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index f2dd8027a4cc..e4311a8da33b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -110,6 +110,7 @@ import java.util.Optional; import java.util.OptionalLong; import java.util.Set; +import java.util.stream.Stream; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; @@ -232,7 +233,7 @@ public class Analysis private final Map, List> checkConstraints = new LinkedHashMap<>(); private final Multiset columnMaskScopes = HashMultiset.create(); - private final Map, Map> columnMasks = new LinkedHashMap<>(); + private final Map, List> columnMasks = new LinkedHashMap<>(); private final Map, UnnestAnalysis> unnestAnalysis = new LinkedHashMap<>(); private Optional create = Optional.empty(); @@ -1161,17 +1162,15 @@ public void unregisterTableForColumnMasking(QualifiedObjectName table, String co referenceChain.pop(); } - public void addColumnMask(Table table, String column, Expression mask) + public void addColumnMask(Table table, Field column, Expression mask) { - Map masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>()); - checkArgument(!masks.containsKey(column), "Mask already exists for column %s", column); - - masks.put(column, mask); + columnMasks.computeIfAbsent(NodeRef.of(table), node -> new ArrayList<>()) + .add(new FieldExpression(column, mask)); } - public Map getColumnMasks(Table table) + public List getColumnMasks(Table table) { - return unmodifiableMap(columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of())); + return unmodifiableList(columnMasks.getOrDefault(NodeRef.of(table), ImmutableList.of())); } public List getReferencedTables() @@ -1189,7 +1188,7 @@ public List getReferencedTables() .distinct() .map(fieldName -> new ColumnInfo( fieldName, - Optional.ofNullable(columnMasks.getOrDefault(table, ImmutableMap.of()).get(fieldName)) + resolveColumnMask(table.getNode().getName(), fieldName, columnMasks.getOrDefault(table, ImmutableList.of())) .map(Expression::toString))) .collect(toImmutableList()); @@ -1210,6 +1209,22 @@ public List getReferencedTables() .collect(toImmutableList()); } + public static Optional resolveColumnMask(QualifiedName tableName, String fieldName, List expressions) + { + return expressions.stream() + .filter(fieldExpression -> fieldExpression.field().canResolve(concatIdentifier(tableName, fieldName))) + .findFirst() + .map(FieldExpression::expression); + } + + private static QualifiedName concatIdentifier(QualifiedName tableName, String fieldName) + { + return QualifiedName.of(Stream.concat( + tableName.getOriginalParts().stream(), + Stream.of(new Identifier(fieldName))) + .collect(toImmutableList())); + } + public List getRoutines() { return resolvedFunctions.values().stream() @@ -2528,4 +2543,13 @@ public record JsonTableAnalysis( requireNonNull(orderedOutputColumns, "orderedOutputColumns is null"); } } + + public record FieldExpression(Field field, Expression expression) + { + public FieldExpression + { + requireNonNull(field, "field is null"); + requireNonNull(expression, "expression is null"); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index 6f7cc0422026..057b47036fdc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -2377,23 +2377,33 @@ private void checkStorageTableNotRedirected(QualifiedObjectName source) private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, RelationType relationType, Scope accessControlScope) { + ImmutableList.Builder fieldBuilder = ImmutableList.builder(); ImmutableList.Builder columnSchemaBuilder = ImmutableList.builder(); for (int index = 0; index < relationType.getAllFieldCount(); index++) { Field field = relationType.getFieldByIndex(index); - field.getName().ifPresent(fieldName -> columnSchemaBuilder.add(ColumnSchema.builder() - .setName(fieldName) - .setType(field.getType()) - .setHidden(field.isHidden()) - .build())); + field.getName().ifPresent(fieldName -> { + fieldBuilder.add(field); + columnSchemaBuilder.add(ColumnSchema.builder() + .setName(fieldName) + .setType(field.getType()) + .setHidden(field.isHidden()) + .build()); + }); } + List fields = fieldBuilder.build(); List columnSchemas = columnSchemaBuilder.build(); Map masks = accessControl.getColumnMasks(session.toSecurityContext(), name, columnSchemas); - for (ColumnSchema columnSchema : columnSchemas) { + for (Field field : fields) { + ColumnSchema columnSchema = ColumnSchema.builder() + .setName(field.getName().orElseThrow()) + .setType(field.getType()) + .setHidden(field.isHidden()) + .build(); Optional.ofNullable(masks.get(columnSchema)).ifPresent(mask -> { if (checkCanSelectFromColumn(name, columnSchema.getName())) { - analyzeColumnMask(session.getIdentity().getUser(), table, name, columnSchema, accessControlScope, mask); + analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask); } }); } @@ -5222,9 +5232,9 @@ private void analyzeCheckConstraint(Table table, QualifiedObjectName name, Scope analysis.addCheckConstraints(table, expression); } - private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, ColumnSchema columnSchema, Scope scope, ViewExpression mask) + private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, Field columnSchema, Scope scope, ViewExpression mask) { - String column = columnSchema.getName(); + String column = columnSchema.getName().orElseThrow(); if (analysis.hasColumnMask(tableName, column, currentIdentity)) { throw new TrinoException(INVALID_COLUMN_MASK, extractLocation(table), format("Column mask for '%s.%s' is recursive", tableName, column), null); } @@ -5284,7 +5294,7 @@ private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObj analysis.addCoercion(expression, expectedType); } - analysis.addColumnMask(table, column, expression); + analysis.addColumnMask(table, columnSchema, expression); } private List descriptorToFields(Scope scope) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 1db38dfa13b4..6e0b351e8bdd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -422,7 +422,7 @@ public RelationPlan addCheckConstraints(List const private RelationPlan addColumnMasks(Table table, RelationPlan plan) { - Map columnMasks = analysis.getColumnMasks(table); + List columnMasks = analysis.getColumnMasks(table); // A Table can represent a WITH query, which can have anonymous fields. On the other hand, // it can't have masks. The loop below expects fields to have proper names, so bail out @@ -441,10 +441,11 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan) for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) { Field field = plan.getDescriptor().getFieldByIndex(i); - io.trino.sql.tree.Expression mask = columnMasks.get(field.getName().orElseThrow()); + Optional columnMask = Analysis.resolveColumnMask(table.getName(), field.getName().orElseThrow(), columnMasks); Symbol symbol = plan.getFieldMappings().get(i); Expression projection = symbol.toSymbolReference(); - if (mask != null) { + if (columnMask.isPresent()) { + io.trino.sql.tree.Expression mask = columnMask.get(); planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask)); symbol = symbolAllocator.newSymbol(symbol); projection = coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask)); diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java index 94eb4684502c..2dbdb74bf9ef 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestColumnMask.java @@ -89,6 +89,17 @@ public TestColumnMask() Optional.of(VIEW_OWNER), false, ImmutableList.of()); + ConnectorViewDefinition viewWithDifferentCase = new ConnectorViewDefinition( + "SELECT NATIONKEY, NAME FROM local.tiny.nation", + Optional.empty(), + Optional.empty(), + ImmutableList.of( + new ConnectorViewDefinition.ViewColumn("NATIONKEY", BigintType.BIGINT.getTypeId(), Optional.empty()), + new ConnectorViewDefinition.ViewColumn("NAME", VarcharType.createVarcharType(25).getTypeId(), Optional.empty())), + Optional.empty(), + Optional.of(VIEW_OWNER), + false, + ImmutableList.of()); ConnectorViewDefinition viewWithNested = new ConnectorViewDefinition( """ @@ -171,6 +182,7 @@ public TestColumnMask() }) .withGetViews((s, prefix) -> ImmutableMap.of( new SchemaTableName("default", "nation_view"), view, + new SchemaTableName("default", "nation_view_uppercase"), viewWithDifferentCase, new SchemaTableName("default", "view_with_nested"), viewWithNested)) .withGetMaterializedViews((s, prefix) -> ImmutableMap.of( new SchemaTableName("default", "nation_materialized_view"), materializedView, @@ -456,6 +468,23 @@ public void testView() assertThat(assertions.query("SELECT name FROM mock.default.nation_view WHERE nationkey = 1")).matches("VALUES CAST('ANITNEGRA' AS VARCHAR(25))"); } + @Test + public void testViewWithUppercaseColumnName() + { + accessControl.reset(); + accessControl.columnMask( + new QualifiedObjectName(MOCK_CATALOG, "default", "nation_view_uppercase"), + "name", + USER, + ViewExpression.builder() + .identity(USER) + .catalog(LOCAL_CATALOG) + .schema("tiny") + .expression("reverse(name)") + .build()); + assertThat(assertions.query("SELECT name FROM mock.default.nation_view_uppercase WHERE nationkey = 1")).matches("VALUES CAST('ANITNEGRA' AS VARCHAR(25))"); + } + @Test public void testTableReferenceInWithClause() {