Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support check constraints for MERGE #18137

Merged
merged 3 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3414,15 +3414,12 @@ protected Scope visitMerge(Merge merge, Optional<Scope> scope)
if (!accessControl.getRowFilters(session.toSecurityContext(), tableName).isEmpty()) {
throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with row filters");
}
if (!tableSchema.getTableSchema().getCheckConstraints().isEmpty()) {
// TODO https://github.com/trinodb/trino/issues/15411 Add support for CHECK constraint to MERGE statement
throw semanticException(NOT_SUPPORTED, merge, "Cannot merge into a table with check constraints");
}

Scope mergeScope = createScope(scope);
Scope targetTableScope = analyzer.analyzeForUpdate(relation, Optional.of(mergeScope), UpdateKind.MERGE);
Scope sourceTableScope = process(merge.getSource(), mergeScope);
Scope joinScope = createAndAssignScope(merge, Optional.of(mergeScope), targetTableScope.getRelationType().joinWith(sourceTableScope.getRelationType()));
analyzeCheckConstraints(table, tableName, targetTableScope, tableSchema.getTableSchema().getCheckConstraints());
analysis.registerTable(table, redirection.tableHandle(), tableName, session.getIdentity().getUser(), targetTableScope);

for (ColumnSchema column : dataColumnSchemas) {
if (accessControl.getColumnMask(session.toSecurityContext(), tableName, column.getName(), column.getType()).isPresent()) {
Expand Down
pajaks marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -683,23 +683,7 @@ public PlanNode plan(Update node)
idAllocator.getNextId(),
subPlanBuilder.getRoot(),
assignments.build()));

PlanBuilder constraintBuilder = subPlanBuilder.appendProjections(constraints, symbolAllocator, idAllocator);

List<Expression> predicates = new ArrayList<>();
for (Expression constraint : constraints) {
Expression symbol = constraintBuilder.translate(constraint).toSymbolReference();

Expression predicate = new IfExpression(
// When predicate evaluates to UNKNOWN (e.g. NULL > 100), it should not violate the check constraint.
new CoalesceExpression(coerceIfNecessary(analysis, symbol, symbol), TRUE_LITERAL),
TRUE_LITERAL,
new Cast(failFunction(plannerContext.getMetadata(), session, CONSTRAINT_VIOLATION, "Check constraint violation: " + constraint), toSqlType(BOOLEAN)));

predicates.add(predicate);
}

subPlanBuilder = subPlanBuilder.withNewRoot(new FilterNode(idAllocator.getNextId(), constraintBuilder.getRoot(), and(predicates)));
subPlanBuilder = addCheckConstraints(constraints, subPlanBuilder);
}

// Build the page, containing:
Expand Down Expand Up @@ -734,6 +718,26 @@ public PlanNode plan(Update node)
return createMergePipeline(table, relationPlan, projectNode, rowIdSymbol, mergeRowSymbol);
}

private PlanBuilder addCheckConstraints(List<Expression> constraints, PlanBuilder subPlanBuilder)
{
PlanBuilder constraintBuilder = subPlanBuilder.appendProjections(constraints, symbolAllocator, idAllocator);

List<Expression> predicates = new ArrayList<>();
for (Expression constraint : constraints) {
Expression symbol = constraintBuilder.translate(constraint).toSymbolReference();

Expression predicate = new IfExpression(
// When predicate evaluates to UNKNOWN (e.g. NULL > 100), it should not violate the check constraint.
new CoalesceExpression(coerceIfNecessary(analysis, symbol, symbol), TRUE_LITERAL),
TRUE_LITERAL,
new Cast(failFunction(plannerContext.getMetadata(), session, CONSTRAINT_VIOLATION, "Check constraint violation: " + constraint), toSqlType(BOOLEAN)));

predicates.add(predicate);
}

return subPlanBuilder.withNewRoot(new FilterNode(idAllocator.getNextId(), constraintBuilder.getRoot(), and(predicates)));
}

public MergeWriterNode plan(Merge merge)
{
MergeAnalysis mergeAnalysis = analysis.getMergeAnalysis().orElseThrow(() -> new IllegalArgumentException("analysis.getMergeAnalysis() isn't present"));
Expand Down Expand Up @@ -773,6 +777,9 @@ public MergeWriterNode plan(Merge merge)

PlanBuilder subPlan = newPlanBuilder(joinPlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext);

FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable());
Symbol rowIdSymbol = planWithPresentColumn.getFieldMappings().get(rowIdReference.getFieldIndex());

// Build the SearchedCaseExpression that creates the project merge_row
Metadata metadata = plannerContext.getMetadata();
List<ColumnSchema> dataColumnSchemas = mergeAnalysis.getDataColumnSchemas();
Expand All @@ -790,25 +797,28 @@ public MergeWriterNode plan(Merge merge)
}

ImmutableList.Builder<Expression> rowBuilder = ImmutableList.builder();
Assignments.Builder assignments = Assignments.builder();
List<ColumnHandle> mergeCaseSetColumns = mergeCaseColumnsHandles.get(caseNumber);
for (ColumnHandle dataColumnHandle : mergeAnalysis.getDataColumnHandles()) {
int index = mergeCaseSetColumns.indexOf(dataColumnHandle);
int fieldNumber = mergeAnalysis.getColumnHandleFieldNumbers().get(dataColumnHandle);
Symbol field = planWithPresentColumn.getFieldMappings().get(fieldNumber);
if (index >= 0) {
Expression setExpression = mergeCase.getSetExpressions().get(index);
subPlan = subqueryPlanner.handleSubqueries(subPlan, setExpression, analysis.getSubqueries(merge));
Expression rewritten = subPlan.rewrite(setExpression);
rewritten = coerceIfNecessary(analysis, setExpression, rewritten);
if (nonNullableColumnHandles.contains(dataColumnHandle)) {
int fieldIndex = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(dataColumnHandle), "Could not find fieldIndex for non nullable column");
ColumnSchema columnSchema = dataColumnSchemas.get(fieldIndex);
ColumnSchema columnSchema = dataColumnSchemas.get(fieldNumber);
String columnName = columnSchema.getName();
rewritten = new CoalesceExpression(rewritten, new Cast(failFunction(metadata, session, INVALID_ARGUMENTS, "Assigning NULL to non-null MERGE target table column " + columnName), toSqlType(columnSchema.getType())));
}
rowBuilder.add(rewritten);
assignments.put(field, rewritten);
}
else {
Integer fieldNumber = requireNonNull(mergeAnalysis.getColumnHandleFieldNumbers().get(dataColumnHandle), "Field number for ColumnHandle is null");
rowBuilder.add(planWithPresentColumn.getFieldMappings().get(fieldNumber).toSymbolReference());
rowBuilder.add(field.toSymbolReference());
assignments.putIdentity(field);
}
}

Expand All @@ -834,6 +844,19 @@ public MergeWriterNode plan(Merge merge)
}

whenClauses.add(new WhenClause(condition, new Row(rowBuilder.build())));

List<Expression> constraints = analysis.getCheckConstraints(mergeAnalysis.getTargetTable());
if (!constraints.isEmpty()) {
assignments.putIdentity(uniqueIdSymbol);
assignments.putIdentity(presentColumn);
assignments.putIdentity(rowIdSymbol);
assignments.putIdentities(source.getFieldMappings());
subPlan = subPlan.withNewRoot(new ProjectNode(
idAllocator.getNextId(),
subPlan.getRoot(),
assignments.build()));
subPlan = addCheckConstraints(constraints, subPlan.withScope(targetTablePlan.getScope(), targetTablePlan.getFieldMappings()));
}
}

// Build the "else" clause for the SearchedCaseExpression
Expand All @@ -848,8 +871,6 @@ public MergeWriterNode plan(Merge merge)

SearchedCaseExpression caseExpression = new SearchedCaseExpression(whenClauses.build(), Optional.of(new Row(rowBuilder.build())));

FieldReference rowIdReference = analysis.getRowIdField(mergeAnalysis.getTargetTable());
Symbol rowIdSymbol = planWithPresentColumn.getFieldMappings().get(rowIdReference.getFieldIndex());
Symbol mergeRowSymbol = symbolAllocator.newSymbol("merge_row", mergeAnalysis.getMergeRowType());
Symbol caseNumberSymbol = symbolAllocator.newSymbol("case_number", INTEGER);

Expand Down
Loading