From cf24a49aeab92737a893ae4ad4832bd123392b97 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Sun, 26 Jul 2020 23:10:04 +0200 Subject: [PATCH] Support parameter in FETCH FIRST clause --- .../sql/analyzer/StatementAnalyzer.java | 40 +++++----- .../execution/TestQueryPreparer.java | 12 +++ .../prestosql/sql/analyzer/TestAnalyzer.java | 2 - .../antlr4/io/prestosql/sql/parser/SqlBase.g4 | 2 +- .../java/io/prestosql/sql/SqlFormatter.java | 2 +- .../io/prestosql/sql/parser/AstBuilder.java | 12 ++- .../sql/tree/DefaultTraversalVisitor.java | 8 ++ .../io/prestosql/sql/tree/FetchFirst.java | 23 +++--- .../prestosql/sql/parser/TestSqlParser.java | 26 ++++++- .../tests/AbstractTestEngineOnlyQueries.java | 76 +++++++++++++++++++ 10 files changed, 168 insertions(+), 35 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java index f1b6c4ca0b9f..fa28fccea3b6 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java @@ -2683,31 +2683,33 @@ private boolean analyzeLimit(Node node, Scope scope) node instanceof FetchFirst || node instanceof Limit, "Invalid limit node type. Expected: FetchFirst or Limit. Actual: %s", node.getClass().getName()); if (node instanceof FetchFirst) { - return analyzeLimit((FetchFirst) node); + return analyzeLimit((FetchFirst) node, scope); } else { return analyzeLimit((Limit) node, scope); } } - private boolean analyzeLimit(FetchFirst node) + private boolean analyzeLimit(FetchFirst node, Scope scope) { - if (node.getRowCount().isEmpty()) { - analysis.setLimit(node, 1); - } - else { - long rowCount; - try { - rowCount = Long.parseLong(node.getRowCount().get()); - } - catch (NumberFormatException e) { - throw semanticException(TYPE_MISMATCH, node, "Invalid FETCH FIRST row count: %s", node.getRowCount().get()); + long rowCount = 1; + if (node.getRowCount().isPresent()) { + Expression count = node.getRowCount().get(); + if (count instanceof LongLiteral) { + rowCount = ((LongLiteral) count).getValue(); } - if (rowCount <= 0) { - throw semanticException(NUMERIC_VALUE_OUT_OF_RANGE, node, "FETCH FIRST row count must be positive (actual value: %s)", rowCount); + else { + checkState(count instanceof Parameter, "unexpected FETCH FIRST rowCount: " + count.getClass().getSimpleName()); + OptionalLong providedValue = analyzeParameterAsRowCount((Parameter) count, scope, "FETCH FIRST"); + if (providedValue.isPresent()) { + rowCount = providedValue.getAsLong(); + } } - analysis.setLimit(node, rowCount); } + if (rowCount <= 0) { + throw semanticException(NUMERIC_VALUE_OUT_OF_RANGE, node, "FETCH FIRST row count must be positive (actual value: %s)", rowCount); + } + analysis.setLimit(node, rowCount); return node.isWithTies(); } @@ -2723,7 +2725,7 @@ else if (node.getRowCount() instanceof LongLiteral) { } else { checkState(node.getRowCount() instanceof Parameter, "unexpected LIMIT rowCount: " + node.getRowCount().getClass().getSimpleName()); - rowCount = analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope); + rowCount = analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, "LIMIT"); } rowCount.ifPresent(count -> { if (count < 0) { @@ -2736,7 +2738,7 @@ else if (node.getRowCount() instanceof LongLiteral) { return false; } - private OptionalLong analyzeParameterAsRowCount(Parameter parameter, Scope scope) + private OptionalLong analyzeParameterAsRowCount(Parameter parameter, Scope scope, String context) { if (analysis.isDescribe()) { analyzeExpression(parameter, scope); @@ -2758,10 +2760,10 @@ private OptionalLong analyzeParameterAsRowCount(Parameter parameter, Scope scope analysis.getParameters()); } catch (VerifyException e) { - throw semanticException(INVALID_ARGUMENTS, parameter, "Non constant parameter value for LIMIT: %s", providedValue); + throw semanticException(INVALID_ARGUMENTS, parameter, "Non constant parameter value for %s: %s", context, providedValue); } if (value == null) { - throw semanticException(INVALID_ARGUMENTS, parameter, "Parameter value provided for LIMIT is NULL: %s", providedValue); + throw semanticException(INVALID_ARGUMENTS, parameter, "Parameter value provided for %s is NULL: %s", context, providedValue); } return OptionalLong.of((long) value); } diff --git a/presto-main/src/test/java/io/prestosql/execution/TestQueryPreparer.java b/presto-main/src/test/java/io/prestosql/execution/TestQueryPreparer.java index 34a05455b8ef..6a78e917da1a 100644 --- a/presto-main/src/test/java/io/prestosql/execution/TestQueryPreparer.java +++ b/presto-main/src/test/java/io/prestosql/execution/TestQueryPreparer.java @@ -92,4 +92,16 @@ public void testParameterMismatchWithLimit() assertPrestoExceptionThrownBy(() -> QUERY_PREPARER.prepareQuery(session, "EXECUTE my_query USING 1, 2, 3, 4, 5, 6")) .hasErrorCode(INVALID_PARAMETER_USAGE); } + + @Test + public void testParameterMismatchWithFetchFirst() + { + Session session = testSessionBuilder() + .addPreparedStatement("my_query", "SELECT ? FROM foo FETCH FIRST ? ROWS ONLY") + .build(); + assertPrestoExceptionThrownBy(() -> QUERY_PREPARER.prepareQuery(session, "EXECUTE my_query USING 1")) + .hasErrorCode(INVALID_PARAMETER_USAGE); + assertPrestoExceptionThrownBy(() -> QUERY_PREPARER.prepareQuery(session, "EXECUTE my_query USING 1, 2, 3, 4, 5, 6")) + .hasErrorCode(INVALID_PARAMETER_USAGE); + } } diff --git a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java index 3a66b77109bf..e6dcfa969fd2 100644 --- a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java @@ -496,8 +496,6 @@ public void testOffsetInvalidRowCount() @Test public void testFetchFirstInvalidRowCount() { - assertFails("SELECT * FROM t1 FETCH FIRST 987654321098765432109876543210 ROWS ONLY") - .hasErrorCode(TYPE_MISMATCH); assertFails("SELECT * FROM t1 FETCH FIRST 0 ROWS ONLY") .hasErrorCode(NUMERIC_VALUE_OUT_OF_RANGE); } diff --git a/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 index 7bb2a738250f..5cfde68b230c 100644 --- a/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 @@ -164,7 +164,7 @@ queryNoWith: queryTerm (ORDER BY sortItem (',' sortItem)*)? (OFFSET offset=INTEGER_VALUE (ROW | ROWS)?)? - ((LIMIT limitRowCount) | (FETCH (FIRST | NEXT) (fetchFirst=INTEGER_VALUE)? (ROW | ROWS) (ONLY | WITH TIES)))? + ((LIMIT limitRowCount) | (FETCH (FIRST | NEXT) (fetchFirst=rowCount)? (ROW | ROWS) (ONLY | WITH TIES)))? ; limitRowCount diff --git a/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java b/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java index 4487c14f45bc..f9835b5eae9f 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java @@ -342,7 +342,7 @@ protected Void visitOffset(Offset node, Integer indent) @Override protected Void visitFetchFirst(FetchFirst node, Integer indent) { - append(indent, "FETCH FIRST " + node.getRowCount().map(c -> c + " ROWS ").orElse("ROW ")) + append(indent, "FETCH FIRST " + node.getRowCount().map(count -> formatExpression(count) + " ROWS ").orElse("ROW ")) .append(node.isWithTies() ? "WITH TIES" : "ONLY") .append('\n'); return null; diff --git a/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java b/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java index 57828ce01d05..7982a6558cbe 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java @@ -647,7 +647,17 @@ public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context) Optional limit = Optional.empty(); if (context.FETCH() != null) { - limit = Optional.of(new FetchFirst(Optional.of(getLocation(context.FETCH())), getTextIfPresent(context.fetchFirst), context.TIES() != null)); + Optional rowCount = Optional.empty(); + if (context.fetchFirst != null) { + if (context.fetchFirst.INTEGER_VALUE() != null) { + rowCount = Optional.of(new LongLiteral(getLocation(context.fetchFirst.INTEGER_VALUE()), context.fetchFirst.getText())); + } + else { + rowCount = Optional.of(new Parameter(getLocation(context.fetchFirst.PARAMETER()), parameterPosition)); + parameterPosition++; + } + } + limit = Optional.of(new FetchFirst(Optional.of(getLocation(context.FETCH())), rowCount, context.TIES() != null)); } else if (context.LIMIT() != null) { if (context.limitRowCount() == null) { diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/DefaultTraversalVisitor.java b/presto-parser/src/main/java/io/prestosql/sql/tree/DefaultTraversalVisitor.java index 7b3e7f9e1003..e7508aed2a12 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/DefaultTraversalVisitor.java +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/DefaultTraversalVisitor.java @@ -270,6 +270,14 @@ protected Void visitLimit(Limit node, C context) return null; } + @Override + protected Void visitFetchFirst(FetchFirst node, C context) + { + node.getRowCount().ifPresent(this::process); + + return null; + } + @Override protected Void visitSimpleCaseExpression(SimpleCaseExpression node, C context) { diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/FetchFirst.java b/presto-parser/src/main/java/io/prestosql/sql/tree/FetchFirst.java index 424869efee1c..541fef1144d0 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/FetchFirst.java +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/FetchFirst.java @@ -20,41 +20,46 @@ import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; public class FetchFirst extends Node { - private final Optional rowCount; + private final Optional rowCount; private final boolean withTies; - public FetchFirst(String rowCount) + public FetchFirst(Expression rowCount) { this(Optional.empty(), Optional.of(rowCount), false); } - public FetchFirst(String rowCount, boolean withTies) + public FetchFirst(Expression rowCount, boolean withTies) { this(Optional.empty(), Optional.of(rowCount), withTies); } - public FetchFirst(Optional rowCount) + public FetchFirst(Optional rowCount) { this(Optional.empty(), rowCount, false); } - public FetchFirst(Optional rowCount, boolean withTies) + public FetchFirst(Optional rowCount, boolean withTies) { this(Optional.empty(), rowCount, withTies); } - public FetchFirst(Optional location, Optional rowCount, boolean withTies) + public FetchFirst(Optional location, Optional rowCount, boolean withTies) { super(location); + rowCount.ifPresent(count -> checkArgument( + count instanceof LongLiteral || count instanceof Parameter, + "unexpected rowCount class: %s", + rowCount.getClass().getSimpleName())); this.rowCount = rowCount; this.withTies = withTies; } - public Optional getRowCount() + public Optional getRowCount() { return rowCount; } @@ -73,7 +78,7 @@ public R accept(AstVisitor visitor, C context) @Override public List getChildren() { - return ImmutableList.of(); + return rowCount.map(ImmutableList::of).orElse(ImmutableList.of()); } @Override @@ -115,6 +120,6 @@ public boolean shallowEquals(Node other) FetchFirst otherNode = (FetchFirst) other; - return withTies == otherNode.withTies && rowCount.equals(otherNode.rowCount); + return withTies == otherNode.withTies; } } diff --git a/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java index 9d1cfce06036..2f716047bdf9 100644 --- a/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java @@ -940,7 +940,7 @@ public void testSelectWithFetch() Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of(new FetchFirst("2")))); + Optional.of(new FetchFirst(new LongLiteral("2"))))); assertStatement("SELECT * FROM table1 FETCH NEXT ROW ONLY", simpleQuery( @@ -988,7 +988,7 @@ public void testSelectWithFetch() Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of(new FetchFirst("2", true)))); + Optional.of(new FetchFirst(new LongLiteral("2"), true)))); assertStatement("SELECT * FROM table1 FETCH NEXT ROW WITH TIES", simpleQuery( @@ -2128,6 +2128,28 @@ public void testPrepareWithParameters() Optional.empty(), Optional.empty(), Optional.of(new Limit(new Parameter(2)))))); + + assertStatement("PREPARE myquery FROM SELECT ? FROM foo FETCH FIRST ? ROWS ONLY", + new Prepare(identifier("myquery"), simpleQuery( + selectList(new Parameter(0)), + table(QualifiedName.of("foo")), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new FetchFirst(new Parameter(1)))))); + + assertStatement("PREPARE myquery FROM SELECT ?, ? FROM foo FETCH NEXT ? ROWS WITH TIES", + new Prepare(identifier("myquery"), simpleQuery( + selectList(new Parameter(0), new Parameter(1)), + table(QualifiedName.of("foo")), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new FetchFirst(new Parameter(2), true))))); } @Test diff --git a/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java b/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java index ba232127c0bc..0bc5ae5768dd 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java @@ -744,6 +744,70 @@ public void testExecuteWithParametersInLimit() "\\Qline 1:1: Invalid numeric literal: 99999999999999999999\\E"); } + @Test + public void testExecuteWithParametersInFetchFirst() + { + String query = "SELECT a FROM (VALUES 1, 2, 2, 3) t(a) where a = ? FETCH FIRST ? ROW ONLY"; + Session session = Session.builder(getSession()) + .addPreparedStatement("my_query", query) + .build(); + + assertQuery( + session, + "EXECUTE my_query USING 2, 1", + "SELECT 2"); + + assertQuery( + session, + "EXECUTE my_query USING 2, 4 - 3", + "SELECT 2"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 'one'", + "\\Qline 1:27: Cannot cast type varchar(3) to bigint\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 1.0", + "\\Qline 1:27: Cannot cast type decimal(2,1) to bigint\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 1 + t.a", + "\\Qline 1:29: Constant expression cannot contain column references\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, null", + "\\Qline 1:64: Parameter value provided for FETCH FIRST is NULL: null\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 1 + null", + "\\Qline 1:64: Parameter value provided for FETCH FIRST is NULL: (1 + null)\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, ?", + "\\Qline 1:27: No value provided for parameter\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, -2", + "\\Qline 1:52: FETCH FIRST row count must be positive (actual value: -2)\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 0", + "\\Qline 1:52: FETCH FIRST row count must be positive (actual value: 0)\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 99999999999999999999", + "\\Qline 1:1: Invalid numeric literal: 99999999999999999999\\E"); + } + @Test public void testExecuteUsingWithWithClause() { @@ -819,6 +883,18 @@ public void testDescribeInput() .build(); assertEqualsIgnoreOrder(actual, expected); + session = Session.builder(getSession()) + .addPreparedStatement("my_query", "SELECT ? FROM nation WHERE nationkey = ? and name < ? FETCH FIRST ? ROWS ONLY") + .build(); + actual = computeActual(session, "DESCRIBE INPUT my_query"); + expected = resultBuilder(session, BIGINT, VARCHAR) + .row(0, "unknown") + .row(1, "bigint") + .row(2, "varchar(25)") + .row(3, "bigint") + .build(); + assertEqualsIgnoreOrder(actual, expected); + session = Session.builder(getSession()) .setSystemProperty("omit_datetime_type_precision", "false") .addPreparedStatement(