diff --git a/x-pack/plugin/sql/src/main/antlr/SqlBase.g4 b/x-pack/plugin/sql/src/main/antlr/SqlBase.g4 index 84e98da4852ad..3bed074b03a2f 100644 --- a/x-pack/plugin/sql/src/main/antlr/SqlBase.g4 +++ b/x-pack/plugin/sql/src/main/antlr/SqlBase.g4 @@ -186,7 +186,7 @@ predicated // instead the property kind is used to differentiate predicate : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression - | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? kind=IN '(' valueExpression (',' valueExpression)* ')' | NOT? kind=IN '(' query ')' | NOT? kind=LIKE pattern | NOT? kind=RLIKE regex=string diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/ExpressionBuilder.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/ExpressionBuilder.java index f7d659a2933da..a75ad78521f7f 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/ExpressionBuilder.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/ExpressionBuilder.java @@ -226,7 +226,7 @@ public Expression visitPredicated(PredicatedContext ctx) { if (pCtx.query() != null) { throw new ParsingException(loc, "IN query not supported yet"); } - e = new In(loc, exp, expressions(pCtx.expression())); + e = new In(loc, exp, expressions(pCtx.valueExpression())); break; case SqlBaseParser.LIKE: e = new Like(loc, exp, visitPattern(pCtx.pattern())); diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlBaseParser.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlBaseParser.java index 56996e4c4c2e4..323aeea30eaf2 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlBaseParser.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlBaseParser.java @@ -3363,12 +3363,6 @@ public ValueExpressionContext valueExpression(int i) { return getRuleContext(ValueExpressionContext.class,i); } public TerminalNode NOT() { return 
getToken(SqlBaseParser.NOT, 0); } - public List<ExpressionContext> expression() { - return getRuleContexts(ExpressionContext.class); - } - public ExpressionContext expression(int i) { - return getRuleContext(ExpressionContext.class,i); - } public TerminalNode IN() { return getToken(SqlBaseParser.IN, 0); } public QueryContext query() { return getRuleContext(QueryContext.class,0); @@ -3449,7 +3443,7 @@ public final PredicateContext predicate() throws RecognitionException { setState(502); match(T__0); setState(503); - expression(); + valueExpression(0); setState(508); _errHandler.sync(this); _la = _input.LA(1); @@ -3459,7 +3453,7 @@ public final PredicateContext predicate() throws RecognitionException { setState(504); match(T__2); setState(505); - expression(); + valueExpression(0); } } setState(510); @@ -6616,7 +6610,7 @@ private boolean valueExpression_sempred(ValueExpressionContext _localctx, int pr "\u01f0\7\16\2\2\u01f0\u01f1\5<\37\2\u01f1\u01f2\7\n\2\2\u01f2\u01f3\5"+ "<\37\2\u01f3\u021b\3\2\2\2\u01f4\u01f6\7=\2\2\u01f5\u01f4\3\2\2\2\u01f5"+ "\u01f6\3\2\2\2\u01f6\u01f7\3\2\2\2\u01f7\u01f8\7-\2\2\u01f8\u01f9\7\3"+ - "\2\2\u01f9\u01fe\5,\27\2\u01fa\u01fb\7\5\2\2\u01fb\u01fd\5,\27\2\u01fc"+ + "\2\2\u01f9\u01fe\5<\37\2\u01fa\u01fb\7\5\2\2\u01fb\u01fd\5<\37\2\u01fc"+ "\u01fa\3\2\2\2\u01fd\u0200\3\2\2\2\u01fe\u01fc\3\2\2\2\u01fe\u01ff\3\2"+ "\2\2\u01ff\u0201\3\2\2\2\u0200\u01fe\3\2\2\2\u0201\u0202\7\4\2\2\u0202"+ "\u021b\3\2\2\2\u0203\u0205\7=\2\2\u0204\u0203\3\2\2\2\u0204\u0205\3\2"+ diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java index fcc0d50bd0515..6835dc43cdc77 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/parser/SqlParser.java @@ -26,6 +26,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import 
org.elasticsearch.xpack.sql.expression.Expression; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanDefaultContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanExpressionContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryPrimaryDefaultContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryTermContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementDefaultContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionDefaultContext; import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue; @@ -214,10 +222,26 @@ public void exitNonReserved(SqlBaseParser.NonReservedContext context) { /** * Used to catch large expressions that can lead to stack overflows */ - private class CircuitBreakerListener extends SqlBaseBaseListener { + static class CircuitBreakerListener extends SqlBaseBaseListener { private static final short MAX_RULE_DEPTH = 200; + /** + * Due to the structure of the grammar and our custom handling in {@link ExpressionBuilder} + * some expressions can exit with a different class than they entered: + * e.g.: ValueExpressionContext can exit as ValueExpressionDefaultContext + */ + private static final Map<String, String> ENTER_EXIT_RULE_MAPPING = new HashMap<>(); + + static { + ENTER_EXIT_RULE_MAPPING.put(StatementDefaultContext.class.getSimpleName(), StatementContext.class.getSimpleName()); + ENTER_EXIT_RULE_MAPPING.put(QueryPrimaryDefaultContext.class.getSimpleName(), QueryTermContext.class.getSimpleName()); + ENTER_EXIT_RULE_MAPPING.put(BooleanDefaultContext.class.getSimpleName(), BooleanExpressionContext.class.getSimpleName()); + ENTER_EXIT_RULE_MAPPING.put(ValueExpressionDefaultContext.class.getSimpleName(), 
ValueExpressionContext.class.getSimpleName()); + } + + private boolean insideIn = false; + // Keep current depth for every rule visited. // The totalDepth alone cannot be used as expressions like: e1 OR e2 OR e3 OR ... // are processed as e1 OR (e2 OR (e3 OR (... and this results in the totalDepth not growing @@ -226,9 +250,18 @@ private class CircuitBreakerListener extends SqlBaseBaseListener { @Override public void enterEveryRule(ParserRuleContext ctx) { + if (inDetected(ctx)) { + insideIn = true; + } + + // Skip PrimaryExpressionContext for IN as it's not visited on exit due to + // the grammar's peculiarity rule with "predicated" and "predicate". + // Also skip the Identifiers as they are "cheap". if (ctx.getClass() != SqlBaseParser.UnquoteIdentifierContext.class && ctx.getClass() != SqlBaseParser.QuoteIdentifierContext.class && - ctx.getClass() != SqlBaseParser.BackQuotedIdentifierContext.class) { + ctx.getClass() != SqlBaseParser.BackQuotedIdentifierContext.class && + (insideIn == false || ctx.getClass() != SqlBaseParser.PrimaryExpressionContext.class)) { + int currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 1, (short) 1); if (currentDepth > MAX_RULE_DEPTH) { throw new ParsingException(source(ctx), "SQL statement too large; " + @@ -240,12 +273,35 @@ public void enterEveryRule(ParserRuleContext ctx) { @Override public void exitEveryRule(ParserRuleContext ctx) { - // Avoid having negative numbers - if (depthCounts.containsKey(ctx.getClass().getSimpleName())) { - depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 0, (short) -1); + if (inDetected(ctx)) { + insideIn = false; } + + decrementCounter(ctx); super.exitEveryRule(ctx); } + + ObjectShortHashMap<String> depthCounts() { + return depthCounts; + } + + private void decrementCounter(ParserRuleContext ctx) { + String className = ctx.getClass().getSimpleName(); + String classNameToDecrement = ENTER_EXIT_RULE_MAPPING.getOrDefault(className, className); + + // Avoid having negative 
numbers + if (depthCounts.containsKey(classNameToDecrement)) { + depthCounts.putOrAdd(classNameToDecrement, (short) 0, (short) -1); + } + } + + private boolean inDetected(ParserRuleContext ctx) { + if (ctx.getParent() != null && ctx.getParent().getClass() == SqlBaseParser.PredicateContext.class) { + SqlBaseParser.PredicateContext pc = (SqlBaseParser.PredicateContext) ctx.getParent(); + return pc.kind != null && pc.kind.getType() == SqlBaseParser.IN; + } + return false; + } } private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() { diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java index 2794481dc8075..04966ff2fe879 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/parser/SqlParserTests.java @@ -15,6 +15,13 @@ import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate; import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MultiMatchQueryPredicate; import org.elasticsearch.xpack.sql.expression.predicate.fulltext.StringQueryPredicate; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanExpressionContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryPrimaryDefaultContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryTermContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementDefaultContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionContext; +import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionDefaultContext; import org.elasticsearch.xpack.sql.plan.logical.Filter; import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan; import 
org.elasticsearch.xpack.sql.plan.logical.OrderBy; @@ -254,6 +261,40 @@ public void testLimitToPreventStackOverflowFromLargeComplexSubselectTree() { e.getMessage()); } + public void testLimitStackOverflowForInAndLiteralsIsNotApplied() { + new SqlParser().createStatement("SELECT * FROM t WHERE a IN(" + + Joiner.on(",").join(nCopies(100_000, "a + b")) + ")"); + } + + public void testDecrementOfDepthCounter() { + SqlParser.CircuitBreakerListener cbl = new SqlParser.CircuitBreakerListener(); + StatementContext sc = new StatementContext(); + QueryTermContext qtc = new QueryTermContext(); + ValueExpressionContext vec = new ValueExpressionContext(); + BooleanExpressionContext bec = new BooleanExpressionContext(); + + cbl.enterEveryRule(sc); + cbl.enterEveryRule(sc); + cbl.enterEveryRule(qtc); + cbl.enterEveryRule(qtc); + cbl.enterEveryRule(qtc); + cbl.enterEveryRule(vec); + cbl.enterEveryRule(bec); + cbl.enterEveryRule(bec); + + cbl.exitEveryRule(new StatementDefaultContext(sc)); + cbl.exitEveryRule(new StatementDefaultContext(sc)); + cbl.exitEveryRule(new QueryPrimaryDefaultContext(qtc)); + cbl.exitEveryRule(new QueryPrimaryDefaultContext(qtc)); + cbl.exitEveryRule(new ValueExpressionDefaultContext(vec)); + cbl.exitEveryRule(new SqlBaseParser.BooleanDefaultContext(bec)); + + assertEquals(0, cbl.depthCounts().get(SqlBaseParser.StatementContext.class.getSimpleName())); + assertEquals(1, cbl.depthCounts().get(SqlBaseParser.QueryTermContext.class.getSimpleName())); + assertEquals(0, cbl.depthCounts().get(SqlBaseParser.ValueExpressionContext.class.getSimpleName())); + assertEquals(1, cbl.depthCounts().get(SqlBaseParser.BooleanExpressionContext.class.getSimpleName())); + } + private LogicalPlan parseStatement(String sql) { return new SqlParser().createStatement(sql); }