Skip to content

Commit

Permalink
Fix case-sensitivity issue with views and column masks
Browse files Browse the repository at this point in the history
The table column reference was registered incorectly with the original
case taken from the view definition. It then failed to match the column
schema returned from `SystemAccessControl` and the mask was not applied.

Instead of sprinkling `toLowerCase()` here and there, we will associate
the original `Field` with the column mask and use `Field#canResove` to
do the matching. The problem with this is that there's no way to do
efficient lookups by name in a case-insensitive way, so we have to
iterate the list of `Field`-`Expression` pairs to find a match.

Fixes trinodb#24054.
  • Loading branch information
ksobolew committed Dec 18, 2024
1 parent 893fc42 commit 32dd720
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 22 deletions.
42 changes: 33 additions & 9 deletions core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
import java.util.Optional;
import java.util.OptionalLong;
import java.util.Set;
import java.util.stream.Stream;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
Expand Down Expand Up @@ -232,7 +233,7 @@ public class Analysis
private final Map<NodeRef<Table>, List<Expression>> checkConstraints = new LinkedHashMap<>();

private final Multiset<ColumnMaskScopeEntry> columnMaskScopes = HashMultiset.create();
private final Map<NodeRef<Table>, Map<String, Expression>> columnMasks = new LinkedHashMap<>();
private final Map<NodeRef<Table>, List<FieldExpression>> columnMasks = new LinkedHashMap<>();

private final Map<NodeRef<Unnest>, UnnestAnalysis> unnestAnalysis = new LinkedHashMap<>();
private Optional<Create> create = Optional.empty();
Expand Down Expand Up @@ -1161,17 +1162,15 @@ public void unregisterTableForColumnMasking(QualifiedObjectName table, String co
referenceChain.pop();
}

public void addColumnMask(Table table, String column, Expression mask)
public void addColumnMask(Table table, Field column, Expression mask)
{
Map<String, Expression> masks = columnMasks.computeIfAbsent(NodeRef.of(table), node -> new LinkedHashMap<>());
checkArgument(!masks.containsKey(column), "Mask already exists for column %s", column);

masks.put(column, mask);
columnMasks.computeIfAbsent(NodeRef.of(table), node -> new ArrayList<>())
.add(new FieldExpression(column, mask));
}

public Map<String, Expression> getColumnMasks(Table table)
public List<FieldExpression> getColumnMasks(Table table)
{
return unmodifiableMap(columnMasks.getOrDefault(NodeRef.of(table), ImmutableMap.of()));
return unmodifiableList(columnMasks.getOrDefault(NodeRef.of(table), ImmutableList.of()));
}

public List<TableInfo> getReferencedTables()
Expand All @@ -1189,7 +1188,7 @@ public List<TableInfo> getReferencedTables()
.distinct()
.map(fieldName -> new ColumnInfo(
fieldName,
Optional.ofNullable(columnMasks.getOrDefault(table, ImmutableMap.of()).get(fieldName))
resolveColumnMask(table.getNode().getName(), fieldName, columnMasks.getOrDefault(table, ImmutableList.of()))
.map(Expression::toString)))
.collect(toImmutableList());

Expand All @@ -1210,6 +1209,22 @@ public List<TableInfo> getReferencedTables()
.collect(toImmutableList());
}

public static Optional<Expression> resolveColumnMask(QualifiedName tableName, String fieldName, List<FieldExpression> expressions)
{
return expressions.stream()
.filter(fieldExpression -> fieldExpression.field().canResolve(concatIdentifier(tableName, fieldName)))
.findFirst()
.map(FieldExpression::expression);
}

private static QualifiedName concatIdentifier(QualifiedName tableName, String fieldName)
{
return QualifiedName.of(Stream.concat(
tableName.getOriginalParts().stream(),
Stream.of(new Identifier(fieldName)))
.collect(toImmutableList()));
}

public List<RoutineInfo> getRoutines()
{
return resolvedFunctions.values().stream()
Expand Down Expand Up @@ -2528,4 +2543,13 @@ public record JsonTableAnalysis(
requireNonNull(orderedOutputColumns, "orderedOutputColumns is null");
}
}

public record FieldExpression(Field field, Expression expression)
{
public FieldExpression
{
requireNonNull(field, "field is null");
requireNonNull(expression, "expression is null");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2377,23 +2377,33 @@ private void checkStorageTableNotRedirected(QualifiedObjectName source)

private void analyzeFiltersAndMasks(Table table, QualifiedObjectName name, RelationType relationType, Scope accessControlScope)
{
ImmutableList.Builder<Field> fieldBuilder = ImmutableList.builder();
ImmutableList.Builder<ColumnSchema> columnSchemaBuilder = ImmutableList.builder();
for (int index = 0; index < relationType.getAllFieldCount(); index++) {
Field field = relationType.getFieldByIndex(index);
field.getName().ifPresent(fieldName -> columnSchemaBuilder.add(ColumnSchema.builder()
.setName(fieldName)
.setType(field.getType())
.setHidden(field.isHidden())
.build()));
field.getName().ifPresent(fieldName -> {
fieldBuilder.add(field);
columnSchemaBuilder.add(ColumnSchema.builder()
.setName(fieldName)
.setType(field.getType())
.setHidden(field.isHidden())
.build());
});
}
List<Field> fields = fieldBuilder.build();
List<ColumnSchema> columnSchemas = columnSchemaBuilder.build();

Map<ColumnSchema, ViewExpression> masks = accessControl.getColumnMasks(session.toSecurityContext(), name, columnSchemas);

for (ColumnSchema columnSchema : columnSchemas) {
for (Field field : fields) {
ColumnSchema columnSchema = ColumnSchema.builder()
.setName(field.getName().orElseThrow())
.setType(field.getType())
.setHidden(field.isHidden())
.build();
Optional.ofNullable(masks.get(columnSchema)).ifPresent(mask -> {
if (checkCanSelectFromColumn(name, columnSchema.getName())) {
analyzeColumnMask(session.getIdentity().getUser(), table, name, columnSchema, accessControlScope, mask);
analyzeColumnMask(session.getIdentity().getUser(), table, name, field, accessControlScope, mask);
}
});
}
Expand Down Expand Up @@ -5222,9 +5232,9 @@ private void analyzeCheckConstraint(Table table, QualifiedObjectName name, Scope
analysis.addCheckConstraints(table, expression);
}

private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, ColumnSchema columnSchema, Scope scope, ViewExpression mask)
private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObjectName tableName, Field columnSchema, Scope scope, ViewExpression mask)
{
String column = columnSchema.getName();
String column = columnSchema.getName().orElseThrow();
if (analysis.hasColumnMask(tableName, column, currentIdentity)) {
throw new TrinoException(INVALID_COLUMN_MASK, extractLocation(table), format("Column mask for '%s.%s' is recursive", tableName, column), null);
}
Expand Down Expand Up @@ -5284,7 +5294,7 @@ private void analyzeColumnMask(String currentIdentity, Table table, QualifiedObj
analysis.addCoercion(expression, expectedType);
}

analysis.addColumnMask(table, column, expression);
analysis.addColumnMask(table, columnSchema, expression);
}

private List<Expression> descriptorToFields(Scope scope)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ public RelationPlan addCheckConstraints(List<io.trino.sql.tree.Expression> const

private RelationPlan addColumnMasks(Table table, RelationPlan plan)
{
Map<String, io.trino.sql.tree.Expression> columnMasks = analysis.getColumnMasks(table);
List<Analysis.FieldExpression> columnMasks = analysis.getColumnMasks(table);

// A Table can represent a WITH query, which can have anonymous fields. On the other hand,
// it can't have masks. The loop below expects fields to have proper names, so bail out
Expand All @@ -441,10 +441,11 @@ private RelationPlan addColumnMasks(Table table, RelationPlan plan)
for (int i = 0; i < plan.getDescriptor().getAllFieldCount(); i++) {
Field field = plan.getDescriptor().getFieldByIndex(i);

io.trino.sql.tree.Expression mask = columnMasks.get(field.getName().orElseThrow());
Optional<io.trino.sql.tree.Expression> columnMask = Analysis.resolveColumnMask(table.getName(), field.getName().orElseThrow(), columnMasks);
Symbol symbol = plan.getFieldMappings().get(i);
Expression projection = symbol.toSymbolReference();
if (mask != null) {
if (columnMask.isPresent()) {
io.trino.sql.tree.Expression mask = columnMask.get();
planBuilder = subqueryPlanner.handleSubqueries(planBuilder, mask, analysis.getSubqueries(mask));
symbol = symbolAllocator.newSymbol(symbol);
projection = coerceIfNecessary(analysis, mask, planBuilder.rewrite(mask));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ public TestColumnMask()
Optional.of(VIEW_OWNER),
false,
ImmutableList.of());
ConnectorViewDefinition viewWithDifferentCase = new ConnectorViewDefinition(
"SELECT NATIONKEY, NAME FROM local.tiny.nation",
Optional.empty(),
Optional.empty(),
ImmutableList.of(
new ConnectorViewDefinition.ViewColumn("NATIONKEY", BigintType.BIGINT.getTypeId(), Optional.empty()),
new ConnectorViewDefinition.ViewColumn("NAME", VarcharType.createVarcharType(25).getTypeId(), Optional.empty())),
Optional.empty(),
Optional.of(VIEW_OWNER),
false,
ImmutableList.of());

ConnectorViewDefinition viewWithNested = new ConnectorViewDefinition(
"""
Expand Down Expand Up @@ -171,6 +182,7 @@ public TestColumnMask()
})
.withGetViews((s, prefix) -> ImmutableMap.of(
new SchemaTableName("default", "nation_view"), view,
new SchemaTableName("default", "nation_view_uppercase"), viewWithDifferentCase,
new SchemaTableName("default", "view_with_nested"), viewWithNested))
.withGetMaterializedViews((s, prefix) -> ImmutableMap.of(
new SchemaTableName("default", "nation_materialized_view"), materializedView,
Expand Down Expand Up @@ -456,6 +468,23 @@ public void testView()
assertThat(assertions.query("SELECT name FROM mock.default.nation_view WHERE nationkey = 1")).matches("VALUES CAST('ANITNEGRA' AS VARCHAR(25))");
}

@Test
public void testViewWithUppercaseColumnName()
{
accessControl.reset();
accessControl.columnMask(
new QualifiedObjectName(MOCK_CATALOG, "default", "nation_view_uppercase"),
"name",
USER,
ViewExpression.builder()
.identity(USER)
.catalog(LOCAL_CATALOG)
.schema("tiny")
.expression("reverse(name)")
.build());
assertThat(assertions.query("SELECT name FROM mock.default.nation_view_uppercase WHERE nationkey = 1")).matches("VALUES CAST('ANITNEGRA' AS VARCHAR(25))");
}

@Test
public void testTableReferenceInWithClause()
{
Expand Down

0 comments on commit 32dd720

Please sign in to comment.