From 5dcd46266d2182ad3a5c5868c189ffe7cea83b33 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Mon, 27 Jul 2020 01:14:43 +0200 Subject: [PATCH] Support parameter in OFFSET clause --- .../sql/analyzer/StatementAnalyzer.java | 16 ++-- .../execution/TestQueryPreparer.java | 12 +++ .../prestosql/sql/analyzer/TestAnalyzer.java | 7 -- .../antlr4/io/prestosql/sql/parser/SqlBase.g4 | 2 +- .../java/io/prestosql/sql/SqlFormatter.java | 5 +- .../io/prestosql/sql/parser/AstBuilder.java | 18 +++- .../sql/tree/DefaultTraversalVisitor.java | 14 +++ .../java/io/prestosql/sql/tree/Offset.java | 22 ++--- .../prestosql/sql/parser/TestSqlParser.java | 41 ++++++++- .../tests/AbstractTestEngineOnlyQueries.java | 92 +++++++++++++++++++ 10 files changed, 192 insertions(+), 37 deletions(-) diff --git a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java index fa28fccea3b6..df0a5124ca04 100644 --- a/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java +++ b/presto-main/src/main/java/io/prestosql/sql/analyzer/StatementAnalyzer.java @@ -933,7 +933,7 @@ protected Scope visitQuery(Query node, Optional scope) analysis.setOrderByExpressions(node, orderByExpressions); if (node.getOffset().isPresent()) { - analyzeOffset(node.getOffset().get()); + analyzeOffset(node.getOffset().get(), queryBodyScope); } if (node.getLimit().isPresent()) { @@ -1311,7 +1311,7 @@ protected Scope visitQuerySpecification(QuerySpecification node, Optional analysis.setOrderByExpressions(node, orderByExpressions); if (node.getOffset().isPresent()) { - analyzeOffset(node.getOffset().get()); + analyzeOffset(node.getOffset().get(), outputScope); } if (node.getLimit().isPresent()) { @@ -2658,14 +2658,16 @@ private List analyzeOrderBy(Node node, List sortItems, Sco return orderByFieldsBuilder.build(); } - private void analyzeOffset(Offset node) + private void analyzeOffset(Offset node, Scope scope) { long rowCount; - try { - rowCount = Long.parseLong(node.getRowCount()); + if (node.getRowCount() instanceof LongLiteral) { + rowCount = ((LongLiteral) node.getRowCount()).getValue(); } - catch (NumberFormatException e) { - throw semanticException(TYPE_MISMATCH, node, "Invalid OFFSET row count: %s", node.getRowCount()); + else { + checkState(node.getRowCount() instanceof Parameter, "unexpected OFFSET rowCount: " + node.getRowCount().getClass().getSimpleName()); + OptionalLong providedValue = analyzeParameterAsRowCount((Parameter) node.getRowCount(), scope, "OFFSET"); + rowCount = providedValue.orElse(0); } if (rowCount < 0) { throw semanticException(NUMERIC_VALUE_OUT_OF_RANGE, node, "OFFSET row count must be greater or equal to 0 (actual value: %s)", rowCount); diff --git a/presto-main/src/test/java/io/prestosql/execution/TestQueryPreparer.java b/presto-main/src/test/java/io/prestosql/execution/TestQueryPreparer.java index 6a78e917da1a..f17195148502 100644 --- a/presto-main/src/test/java/io/prestosql/execution/TestQueryPreparer.java +++ b/presto-main/src/test/java/io/prestosql/execution/TestQueryPreparer.java @@ -81,6 +81,18 @@ public void testTooFewParameters() .hasErrorCode(INVALID_PARAMETER_USAGE); } + @Test + public void testParameterMismatchWithOffset() + { + Session session = testSessionBuilder() + .addPreparedStatement("my_query", "SELECT ? FROM foo OFFSET ? ROWS") + .build(); + assertPrestoExceptionThrownBy(() -> QUERY_PREPARER.prepareQuery(session, "EXECUTE my_query USING 1")) + .hasErrorCode(INVALID_PARAMETER_USAGE); + assertPrestoExceptionThrownBy(() -> QUERY_PREPARER.prepareQuery(session, "EXECUTE my_query USING 1, 2, 3, 4, 5, 6")) + .hasErrorCode(INVALID_PARAMETER_USAGE); + } + @Test public void testParameterMismatchWithLimit() { diff --git a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java index e6dcfa969fd2..81be675f0c9e 100644 --- a/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java +++ b/presto-main/src/test/java/io/prestosql/sql/analyzer/TestAnalyzer.java @@ -486,13 +486,6 @@ public void testOrderByNonComparable() .hasErrorCode(TYPE_MISMATCH); } - @Test - public void testOffsetInvalidRowCount() - { - assertFails("SELECT * FROM t1 OFFSET 987654321098765432109876543210 ROWS") - .hasErrorCode(TYPE_MISMATCH); - } - @Test public void testFetchFirstInvalidRowCount() { diff --git a/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 b/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 index 5cfde68b230c..69a5d22cd55a 100644 --- a/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 +++ b/presto-parser/src/main/antlr4/io/prestosql/sql/parser/SqlBase.g4 @@ -163,7 +163,7 @@ property queryNoWith: queryTerm (ORDER BY sortItem (',' sortItem)*)? - (OFFSET offset=INTEGER_VALUE (ROW | ROWS)?)? + (OFFSET offset=rowCount (ROW | ROWS)?)? ((LIMIT limitRowCount) | (FETCH (FIRST | NEXT) (fetchFirst=rowCount)? (ROW | ROWS) (ONLY | WITH TIES)))? ; diff --git a/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java b/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java index f9835b5eae9f..fb32c21bbf6f 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java +++ b/presto-parser/src/main/java/io/prestosql/sql/SqlFormatter.java @@ -334,8 +334,9 @@ protected Void visitOrderBy(OrderBy node, Integer indent) @Override protected Void visitOffset(Offset node, Integer indent) { - append(indent, "OFFSET " + node.getRowCount() + " ROWS") - .append('\n'); + append(indent, "OFFSET ") + .append(formatExpression(node.getRowCount())) + .append(" ROWS\n"); return null; } diff --git a/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java b/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java index 7982a6558cbe..79af2b6334e4 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java +++ b/presto-parser/src/main/java/io/prestosql/sql/parser/AstBuilder.java @@ -645,6 +645,19 @@ public Node visitQueryNoWith(SqlBaseParser.QueryNoWithContext context) orderBy = Optional.of(new OrderBy(getLocation(context.ORDER()), visit(context.sortItem(), SortItem.class))); } + Optional offset = Optional.empty(); + if (context.OFFSET() != null) { + Expression rowCount; + if (context.offset.INTEGER_VALUE() != null) { + rowCount = new LongLiteral(getLocation(context.offset.INTEGER_VALUE()), context.offset.getText()); + } + else { + rowCount = new Parameter(getLocation(context.offset.PARAMETER()), parameterPosition); + parameterPosition++; + } + offset = Optional.of(new Offset(Optional.of(getLocation(context.OFFSET())), rowCount)); + } + Optional limit = Optional.empty(); if (context.FETCH() != null) { Optional rowCount = Optional.empty(); @@ -678,11 +691,6 @@ else if (context.limitRowCount().rowCount().INTEGER_VALUE() != null) { limit = Optional.of(new Limit(Optional.of(getLocation(context.LIMIT())), rowCount)); } - Optional offset = Optional.empty(); - if (context.OFFSET() != null) { - offset = Optional.of(new Offset(Optional.of(getLocation(context.OFFSET())), getTextIfPresent(context.offset).orElseThrow(() -> new IllegalStateException("Missing OFFSET row count")))); - } - if (term instanceof QuerySpecification) { // When we have a simple query specification // followed by order by, offset, limit or fetch, diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/DefaultTraversalVisitor.java b/presto-parser/src/main/java/io/prestosql/sql/tree/DefaultTraversalVisitor.java index e7508aed2a12..1765ad81ff3d 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/DefaultTraversalVisitor.java +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/DefaultTraversalVisitor.java @@ -116,6 +116,9 @@ protected Void visitQuery(Query node, C context) if (node.getOrderBy().isPresent()) { process(node.getOrderBy().get(), context); } + if (node.getOffset().isPresent()) { + process(node.getOffset().get(), context); + } if (node.getLimit().isPresent()) { process(node.getLimit().get(), context); } @@ -262,6 +265,14 @@ public Void visitFrameBound(FrameBound node, C context) return null; } + @Override + protected Void visitOffset(Offset node, C context) + { + process(node.getRowCount()); + + return null; + } + @Override protected Void visitLimit(Limit node, C context) { @@ -442,6 +453,9 @@ protected Void visitQuerySpecification(QuerySpecification node, C context) if (node.getOrderBy().isPresent()) { process(node.getOrderBy().get(), context); } + if (node.getOffset().isPresent()) { + process(node.getOffset().get(), context); + } if (node.getLimit().isPresent()) { process(node.getLimit().get(), context); } diff --git a/presto-parser/src/main/java/io/prestosql/sql/tree/Offset.java b/presto-parser/src/main/java/io/prestosql/sql/tree/Offset.java index b0aa82c39c4a..e4e62609696e 100644 --- a/presto-parser/src/main/java/io/prestosql/sql/tree/Offset.java +++ b/presto-parser/src/main/java/io/prestosql/sql/tree/Offset.java @@ -20,29 +20,33 @@ import java.util.Optional; import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkArgument; public class Offset extends Node { - private final String rowCount; + private final Expression rowCount; - public Offset(String rowCount) + public Offset(Expression rowCount) { this(Optional.empty(), rowCount); } - public Offset(NodeLocation location, String rowCount) + public Offset(NodeLocation location, Expression rowCount) { this(Optional.of(location), rowCount); } - public Offset(Optional location, String rowCount) + public Offset(Optional location, Expression rowCount) { super(location); + checkArgument(rowCount instanceof LongLiteral || rowCount instanceof Parameter, + "unexpected rowCount class: %s", + rowCount.getClass().getSimpleName()); this.rowCount = rowCount; } - public String getRowCount() + public Expression getRowCount() { return rowCount; } @@ -56,7 +60,7 @@ public R accept(AstVisitor visitor, C context) @Override public List getChildren() { - return ImmutableList.of(); + return ImmutableList.of(rowCount); } @Override @@ -89,10 +93,6 @@ public String toString() @Override public boolean shallowEquals(Node other) { - if (!sameClass(this, other)) { - return false; - } - - return Objects.equals(rowCount, ((Offset) other).rowCount); + return sameClass(this, other); } } diff --git a/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java b/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java index 2f716047bdf9..a718854f3618 100644 --- a/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java +++ b/presto-parser/src/test/java/io/prestosql/sql/parser/TestSqlParser.java @@ -889,7 +889,7 @@ public void testSelectWithOffset() Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of(new Offset("2")), + Optional.of(new Offset(new LongLiteral("2"))), Optional.empty())); assertStatement("SELECT * FROM table1 OFFSET 2", @@ -900,7 +900,7 @@ public void testSelectWithOffset() Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of(new Offset("2")), + Optional.of(new Offset(new LongLiteral("2"))), Optional.empty())); Query valuesQuery = query(values( @@ -914,7 +914,7 @@ public void testSelectWithOffset() Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of(new Offset("2")), + Optional.of(new Offset(new LongLiteral("2"))), Optional.empty())); assertStatement("SELECT * FROM (VALUES (1, '1'), (2, '2')) OFFSET 2", @@ -924,7 +924,7 @@ public void testSelectWithOffset() Optional.empty(), Optional.empty(), Optional.empty(), - Optional.of(new Offset("2")), + Optional.of(new Offset(new LongLiteral("2"))), Optional.empty())); } @@ -2150,6 +2150,39 @@ public void testPrepareWithParameters() Optional.empty(), Optional.empty(), Optional.of(new FetchFirst(new Parameter(2), true))))); + + assertStatement("PREPARE myquery FROM SELECT ?, ? FROM foo OFFSET ? ROWS", + new Prepare(identifier("myquery"), simpleQuery( + selectList(new Parameter(0), new Parameter(1)), + table(QualifiedName.of("foo")), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new Offset(new Parameter(2))), + Optional.empty()))); + + assertStatement("PREPARE myquery FROM SELECT ? FROM foo OFFSET ? ROWS LIMIT ?", + new Prepare(identifier("myquery"), simpleQuery( + selectList(new Parameter(0)), + table(QualifiedName.of("foo")), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new Offset(new Parameter(1))), + Optional.of(new Limit(new Parameter(2)))))); + + assertStatement("PREPARE myquery FROM SELECT ? FROM foo OFFSET ? ROWS FETCH FIRST ? ROWS WITH TIES", + new Prepare(identifier("myquery"), simpleQuery( + selectList(new Parameter(0)), + table(QualifiedName.of("foo")), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.of(new Offset(new Parameter(1))), + Optional.of(new FetchFirst(new Parameter(2), true))))); } @Test diff --git a/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java b/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java index 0bc5ae5768dd..e9c95b23c7d9 100644 --- a/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java +++ b/presto-tests/src/test/java/io/prestosql/tests/AbstractTestEngineOnlyQueries.java @@ -808,6 +808,86 @@ public void testExecuteWithParametersInFetchFirst() "\\Qline 1:1: Invalid numeric literal: 99999999999999999999\\E"); } + @Test + public void testExecuteWithParametersInOffset() + { + String query = "SELECT a FROM (VALUES 1, 2, 2, 3) t(a) where a = ? OFFSET ? ROWS"; + Session session = Session.builder(getSession()) + .addPreparedStatement("my_query", query) + .build(); + + assertQuery( + session, + "EXECUTE my_query USING 2, 1", + "SELECT 2"); + + assertQuery( + session, + "EXECUTE my_query USING 2, 4 - 3", + "SELECT 2"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 'one'", + "\\Qline 1:27: Cannot cast type varchar(3) to bigint\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 1.0", + "\\Qline 1:27: Cannot cast type decimal(2,1) to bigint\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 1 + t.a", + "\\Qline 1:29: Constant expression cannot contain column references\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, null", + "\\Qline 1:59: Parameter value provided for OFFSET is NULL: null\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 1 + null", + "\\Qline 1:59: Parameter value provided for OFFSET is NULL: (1 + null)\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, ?", + "\\Qline 1:27: No value provided for parameter\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, -2", + "\\Qline 1:52: OFFSET row count must be greater or equal to 0 (actual value: -2)\\E"); + + assertQueryFails( + session, + "EXECUTE my_query USING 2, 99999999999999999999", + "\\Qline 1:1: Invalid numeric literal: 99999999999999999999\\E"); + } + + @Test + public void testExecuteWithParametersInDifferentClauses() + { + String query1 = "SELECT a FROM (VALUES 1, 2, 2, 2, 2, 2) t(a) where a = ? OFFSET ? ROWS LIMIT ?"; + String query2 = "SELECT a FROM (VALUES 1, 2, 2, 2, 2, 2) t(a) where a = ? ORDER BY a OFFSET ? ROWS FETCH FIRST ? ROWS WITH TIES"; + Session session = Session.builder(getSession()) + .addPreparedStatement("my_query_1", query1) + .addPreparedStatement("my_query_2", query2) + .build(); + + assertQuery( + session, + "EXECUTE my_query_1 USING 2, 1, 3", + "VALUES 2, 2, 2"); + + assertQuery( + session, + "EXECUTE my_query_2 USING 2, 1, 3", + "VALUES 2, 2, 2, 2"); + } + @Test public void testExecuteUsingWithWithClause() { @@ -871,6 +951,18 @@ public void testDescribeInput() .build(); assertEqualsIgnoreOrder(actual, expected); + session = Session.builder(getSession()) + .addPreparedStatement("my_query", "SELECT ? FROM nation WHERE nationkey = ? and name < ? OFFSET ?") + .build(); + actual = computeActual(session, "DESCRIBE INPUT my_query"); + expected = resultBuilder(session, BIGINT, VARCHAR) + .row(0, "unknown") + .row(1, "bigint") + .row(2, "varchar(25)") + .row(3, "bigint") + .build(); + assertEqualsIgnoreOrder(actual, expected); + session = Session.builder(getSession()) .addPreparedStatement("my_query", "SELECT ? FROM nation WHERE nationkey = ? and name < ? LIMIT ?") .build();