Skip to content

Commit

Permalink
Add utility to resolve failure function
Browse files Browse the repository at this point in the history
  • Loading branch information
electrum committed Jul 21, 2022
1 parent d1feda8 commit 855ee03
Show file tree
Hide file tree
Showing 10 changed files with 56 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import io.trino.metadata.TableHandle;
import io.trino.metadata.TableLayout;
import io.trino.metadata.TableMetadata;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
Expand Down Expand Up @@ -86,15 +86,13 @@
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.Insert;
import io.trino.sql.tree.LambdaArgumentDeclaration;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.Query;
import io.trino.sql.tree.RefreshMaterializedView;
import io.trino.sql.tree.Row;
import io.trino.sql.tree.Statement;
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.Table;
import io.trino.sql.tree.TableExecute;
import io.trino.sql.tree.Update;
Expand All @@ -121,6 +119,7 @@
import static com.google.common.collect.Streams.zip;
import static io.trino.SystemSessionProperties.isCollectPlanStatisticsForAllQueries;
import static io.trino.metadata.MetadataUtil.createQualifiedObjectName;
import static io.trino.spi.StandardErrorCode.INVALID_CAST_ARGUMENT;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.StandardErrorCode.PERMISSION_DENIED;
import static io.trino.spi.statistics.TableStatisticType.ROW_COUNT;
Expand Down Expand Up @@ -479,7 +478,7 @@ private RelationPlan getInsertPlan(
plan = planner.addRowFilters(
table,
plan,
failIfPredicateIsNotMeet(metadata, session, PERMISSION_DENIED, AccessDeniedException.PREFIX + "Cannot insert row that does not match to a row filter"),
failIfPredicateIsNotMet(metadata, session, PERMISSION_DENIED, AccessDeniedException.PREFIX + "Cannot insert row that does not match to a row filter"),
node -> {
Scope accessControlScope = analysis.getAccessControlScope(table);
// hidden fields are not accessible in insert
Expand Down Expand Up @@ -523,19 +522,19 @@ private RelationPlan getInsertPlan(
statisticsMetadata);
}

private static Function<Expression, Expression> failIfPredicateIsNotMeet(Metadata metadata, Session session, StandardErrorCode errorCode, String errorMessage)
private static Function<Expression, Expression> failIfPredicateIsNotMet(Metadata metadata, Session session, ErrorCodeSupplier errorCode, String errorMessage)
{
ResolvedFunction fail = metadata.resolveFunction(session, QualifiedName.of("fail"), fromTypes(INTEGER, VARCHAR));
return predicate -> new IfExpression(
predicate,
TRUE_LITERAL,
new Cast(
new FunctionCall(
fail.toQualifiedName(),
ImmutableList.of(
new Cast(new LongLiteral(Long.toString(errorCode.toErrorCode().getCode())), toSqlType(INTEGER)),
new Cast(new StringLiteral(errorMessage), toSqlType(VARCHAR)))),
toSqlType(BOOLEAN)));
FunctionCall fail = failFunction(metadata, session, errorCode, errorMessage);
return predicate -> new IfExpression(predicate, TRUE_LITERAL, new Cast(fail, toSqlType(BOOLEAN)));
}

public static FunctionCall failFunction(Metadata metadata, Session session, ErrorCodeSupplier errorCode, String errorMessage)
{
return FunctionCallBuilder.resolve(session, metadata)
.setName(QualifiedName.of("fail"))
.addArgument(INTEGER, new GenericLiteral("INTEGER", Integer.toString(errorCode.toErrorCode().getCode())))
.addArgument(VARCHAR, new GenericLiteral("VARCHAR", errorMessage))
.build();
}

private RelationPlan createInsertPlan(Analysis analysis, Insert insertStatement)
Expand Down Expand Up @@ -700,7 +699,6 @@ private Expression noTruncationCast(Expression expression, Type fromType, Type t

checkState(fromType instanceof VarcharType || fromType instanceof CharType, "inserting non-character value to column of character type");
ResolvedFunction spaceTrimmedLength = metadata.resolveFunction(session, QualifiedName.of("$space_trimmed_length"), fromTypes(VARCHAR));
ResolvedFunction fail = metadata.resolveFunction(session, QualifiedName.of("fail"), fromTypes(VARCHAR));

return new IfExpression(
// check if the trimmed value fits in the target type
Expand All @@ -714,14 +712,10 @@ private Expression noTruncationCast(Expression expression, Type fromType, Type t
new GenericLiteral("BIGINT", "0"))),
new Cast(expression, toSqlType(toType)),
new Cast(
new FunctionCall(
fail.toQualifiedName(),
ImmutableList.of(new Cast(
new StringLiteral(format(
"Cannot truncate non-space characters when casting from %s to %s on INSERT",
fromType.getDisplayName(),
toType.getDisplayName())),
toSqlType(VARCHAR)))),
failFunction(metadata, session, INVALID_CAST_ARGUMENT, format(
"Cannot truncate non-space characters when casting from %s to %s on INSERT",
fromType.getDisplayName(),
toType.getDisplayName())),
toSqlType(toType)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@
import io.trino.sql.tree.QuerySpecification;
import io.trino.sql.tree.Relation;
import io.trino.sql.tree.SortItem;
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.Table;
import io.trino.sql.tree.Union;
import io.trino.sql.tree.Update;
Expand Down Expand Up @@ -115,15 +114,16 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static io.trino.SystemSessionProperties.getMaxRecursionDepth;
import static io.trino.SystemSessionProperties.isSkipRedundantSort;
import static io.trino.spi.StandardErrorCode.INVALID_WINDOW_FRAME;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarbinaryType.VARBINARY;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.NodeUtils.getSortItemsFromOrderBy;
import static io.trino.sql.analyzer.ExpressionAnalyzer.isNumericType;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.GroupingOperationRewriter.rewriteGroupingOperation;
import static io.trino.sql.planner.LogicalPlanner.failFunction;
import static io.trino.sql.planner.OrderingScheme.sortItemToSortOrder;
import static io.trino.sql.planner.PlanBuilder.newPlanBuilder;
import static io.trino.sql.planner.ScopeAware.scopeAwareKey;
Expand Down Expand Up @@ -296,17 +296,14 @@ public RelationPlan planExpand(Query query)
0);

// 2. append filter to fail on non-empty result
ResolvedFunction fail = plannerContext.getMetadata().resolveFunction(session, QualifiedName.of("fail"), fromTypes(VARCHAR));
String recursionLimitExceededMessage = format("Recursion depth limit exceeded (%s). Use 'max_recursion_depth' session property to modify the limit.", maxRecursionDepth);
Expression predicate = new IfExpression(
new ComparisonExpression(
GREATER_THAN_OR_EQUAL,
countSymbol.toSymbolReference(),
new GenericLiteral("BIGINT", "0")),
new Cast(
new FunctionCall(
fail.toQualifiedName(),
ImmutableList.of(new Cast(new StringLiteral(recursionLimitExceededMessage), toSqlType(VARCHAR)))),
failFunction(plannerContext.getMetadata(), session, NOT_SUPPORTED, recursionLimitExceededMessage),
toSqlType(BOOLEAN)),
TRUE_LITERAL);
FilterNode filterNode = new FilterNode(idAllocator.getNextId(), windowNode, predicate);
Expand Down Expand Up @@ -1052,17 +1049,14 @@ private FrameBoundPlanAndSymbols planFrameBound(PlanBuilder subPlan, PlanAndMapp
// First, append filter to validate offset values. They mustn't be negative or null.
Symbol offsetSymbol = coercions.get(frameOffset.get());
Expression zeroOffset = zeroOfType(symbolAllocator.getTypes().get(offsetSymbol));
ResolvedFunction fail = plannerContext.getMetadata().resolveFunction(session, QualifiedName.of("fail"), fromTypes(VARCHAR));
Expression predicate = new IfExpression(
new ComparisonExpression(
GREATER_THAN_OR_EQUAL,
offsetSymbol.toSymbolReference(),
zeroOffset),
TRUE_LITERAL,
new Cast(
new FunctionCall(
fail.toQualifiedName(),
ImmutableList.of(new Cast(new StringLiteral("Window frame offset value must not be negative or null"), toSqlType(VARCHAR)))),
failFunction(plannerContext.getMetadata(), session, INVALID_WINDOW_FRAME, "Window frame offset value must not be negative or null"),
toSqlType(BOOLEAN)));
subPlan = subPlan.withNewRoot(new FilterNode(
idAllocator.getNextId(),
Expand Down Expand Up @@ -1158,14 +1152,11 @@ private FrameOffsetPlanAndSymbol planFrameOffset(PlanBuilder subPlan, Optional<S

// Append filter to validate offset values. They mustn't be negative or null.
Expression zeroOffset = zeroOfType(offsetType);
ResolvedFunction fail = plannerContext.getMetadata().resolveFunction(session, QualifiedName.of("fail"), fromTypes(VARCHAR));
Expression predicate = new IfExpression(
new ComparisonExpression(GREATER_THAN_OR_EQUAL, offsetSymbol.toSymbolReference(), zeroOffset),
TRUE_LITERAL,
new Cast(
new FunctionCall(
fail.toQualifiedName(),
ImmutableList.of(new Cast(new StringLiteral("Window frame offset value must not be negative or null"), toSqlType(VARCHAR)))),
failFunction(plannerContext.getMetadata(), session, INVALID_WINDOW_FRAME, "Window frame offset value must not be negative or null"),
toSqlType(BOOLEAN)));
subPlan = subPlan.withNewRoot(new FilterNode(
idAllocator.getNextId(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.metadata.ResolvedFunction;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
Expand All @@ -46,24 +45,22 @@
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.IfExpression;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.NullLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.StringLiteral;

import java.util.List;
import java.util.Optional;

import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.matching.Pattern.nonEmpty;
import static io.trino.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.LogicalPlanner.failFunction;
import static io.trino.sql.planner.iterative.rule.ImplementLimitWithTies.rewriteLimitWithTiesWithPartitioning;
import static io.trino.sql.planner.iterative.rule.Util.restrictOutputs;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isScalar;
Expand Down Expand Up @@ -410,16 +407,13 @@ public RewriteResult visitEnforceSingleRow(EnforceSingleRowNode node, Void conte
Optional.of(2),
Optional.empty());
}
ResolvedFunction fail = metadata.resolveFunction(session, QualifiedName.of("fail"), fromTypes(VARCHAR));
Expression predicate = new IfExpression(
new ComparisonExpression(
GREATER_THAN,
rowNumberSymbol.toSymbolReference(),
new GenericLiteral("BIGINT", "1")),
new Cast(
new FunctionCall(
fail.toQualifiedName(),
ImmutableList.of(new Cast(new StringLiteral("Scalar sub-query has returned multiple rows"), toSqlType(VARCHAR)))),
failFunction(metadata, session, SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"),
toSqlType(BOOLEAN)),
TRUE_LITERAL);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.FunctionCallBuilder;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.AssignUniqueId;
Expand All @@ -31,20 +30,16 @@
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.SimpleCaseExpression;
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.WhenClause;

import java.util.Optional;

import static io.trino.matching.Pattern.nonEmpty;
import static io.trino.spi.StandardErrorCode.SUBQUERY_MULTIPLE_ROWS;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType;
import static io.trino.sql.planner.LogicalPlanner.failFunction;
import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.extractCardinality;
import static io.trino.sql.planner.plan.CorrelatedJoinNode.Type.LEFT;
Expand Down Expand Up @@ -163,11 +158,7 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
ImmutableList.of(
new WhenClause(TRUE_LITERAL, TRUE_LITERAL)),
Optional.of(new Cast(
FunctionCallBuilder.resolve(context.getSession(), metadata)
.setName(QualifiedName.of("fail"))
.addArgument(INTEGER, new LongLiteral(Integer.toString(SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode())))
.addArgument(VARCHAR, new StringLiteral("Scalar sub-query has returned multiple rows"))
.build(),
failFunction(metadata, context.getSession(), SUBQUERY_MULTIPLE_ROWS, "Scalar sub-query has returned multiple rows"),
toSqlType(BOOLEAN)))));

return Result.ofPlanNode(new ProjectNode(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ public void testCorrelatedScalarSubqueryInSelect()
assertDistributedPlan("SELECT name, (SELECT name FROM region WHERE regionkey = nation.regionkey) FROM nation",
noJoinReordering(),
anyTree(
filter(format("CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(%s, 'Scalar sub-query has returned multiple rows') AS boolean) END", SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()),
filter(format("CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(%d, VARCHAR 'Scalar sub-query has returned multiple rows') AS boolean) END", SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()),
project(
markDistinct("is_distinct", ImmutableList.of("unique"),
join(LEFT, ImmutableList.of(equiJoinClause("n_regionkey", "r_regionkey")),
Expand All @@ -807,7 +807,7 @@ public void testCorrelatedScalarSubqueryInSelect()
assertDistributedPlan("SELECT name, (SELECT name FROM region WHERE regionkey = nation.regionkey) FROM nation",
automaticJoinDistribution(),
anyTree(
filter(format("CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(%s, 'Scalar sub-query has returned multiple rows') AS boolean) END", SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()),
filter(format("CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(%d, VARCHAR 'Scalar sub-query has returned multiple rows') AS boolean) END", SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()),
project(
markDistinct("is_distinct", ImmutableList.of("unique"),
join(LEFT, ImmutableList.of(equiJoinClause("n_regionkey", "r_regionkey")),
Expand Down Expand Up @@ -1028,13 +1028,13 @@ public void testCorrelatedDistinctAggregationRewriteToLeftOuterJoin()
}

@Test
public void testCorrelatedDistinctGropuedAggregationRewriteToLeftOuterJoin()
public void testCorrelatedDistinctGroupedAggregationRewriteToLeftOuterJoin()
{
assertPlan(
"SELECT (SELECT count(DISTINCT o.orderkey) FROM orders o WHERE c.custkey = o.custkey GROUP BY o.orderstatus), c.custkey FROM customer c",
output(
project(filter(
"(CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(28, 'Scalar sub-query has returned multiple rows') AS boolean) END)",
format("CASE \"is_distinct\" WHEN true THEN true ELSE CAST(fail(%d, VARCHAR 'Scalar sub-query has returned multiple rows') AS boolean) END", SUBQUERY_MULTIPLE_ROWS.toErrorCode().getCode()),
project(markDistinct(
"is_distinct",
ImmutableList.of("unique"),
Expand Down
Loading

0 comments on commit 855ee03

Please sign in to comment.