Skip to content

Commit

Permalink
SQL: Fix wrong appliance of StackOverflow limit for IN (#36724)
Browse files Browse the repository at this point in the history
Fix grammar so that each element inside the list of values for IN
is a valueExpression and not a more generic expression. Introduce a
mapping for context names as some rules in the grammar are exited with
a different rule from the one they entered.This helps so that the decrement
of depth counts in the Parser's CircuitBreakerListener works correctly.

For the list of values for IN, don't count the
PrimaryExpressionContext as this is not visited on exitRule() due to
the peculiarity in our gramamr with the predicate and predicated.

Fixes: #36592
  • Loading branch information
matriv committed Dec 18, 2018
1 parent d41eb6f commit 7d2e2cf
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 18 deletions.
2 changes: 1 addition & 1 deletion x-pack/plugin/sql/src/main/antlr/SqlBase.g4
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ predicated
// instead the property kind is used to differentiate
predicate
: NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression
| NOT? kind=IN '(' expression (',' expression)* ')'
| NOT? kind=IN '(' valueExpression (',' valueExpression)* ')'
| NOT? kind=IN '(' query ')'
| NOT? kind=LIKE pattern
| NOT? kind=RLIKE regex=string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ public Expression visitPredicated(PredicatedContext ctx) {
if (pCtx.query() != null) {
throw new ParsingException(loc, "IN query not supported yet");
}
e = new In(loc, exp, expressions(pCtx.expression()));
e = new In(loc, exp, expressions(pCtx.valueExpression()));
break;
case SqlBaseParser.LIKE:
e = new Like(loc, exp, visitPattern(pCtx.pattern()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3363,12 +3363,6 @@ public ValueExpressionContext valueExpression(int i) {
return getRuleContext(ValueExpressionContext.class,i);
}
public TerminalNode NOT() { return getToken(SqlBaseParser.NOT, 0); }
public List<ExpressionContext> expression() {
return getRuleContexts(ExpressionContext.class);
}
public ExpressionContext expression(int i) {
return getRuleContext(ExpressionContext.class,i);
}
public TerminalNode IN() { return getToken(SqlBaseParser.IN, 0); }
public QueryContext query() {
return getRuleContext(QueryContext.class,0);
Expand Down Expand Up @@ -3449,7 +3443,7 @@ public final PredicateContext predicate() throws RecognitionException {
setState(502);
match(T__0);
setState(503);
expression();
valueExpression(0);
setState(508);
_errHandler.sync(this);
_la = _input.LA(1);
Expand All @@ -3459,7 +3453,7 @@ public final PredicateContext predicate() throws RecognitionException {
setState(504);
match(T__2);
setState(505);
expression();
valueExpression(0);
}
}
setState(510);
Expand Down Expand Up @@ -6616,7 +6610,7 @@ private boolean valueExpression_sempred(ValueExpressionContext _localctx, int pr
"\u01f0\7\16\2\2\u01f0\u01f1\5<\37\2\u01f1\u01f2\7\n\2\2\u01f2\u01f3\5"+
"<\37\2\u01f3\u021b\3\2\2\2\u01f4\u01f6\7=\2\2\u01f5\u01f4\3\2\2\2\u01f5"+
"\u01f6\3\2\2\2\u01f6\u01f7\3\2\2\2\u01f7\u01f8\7-\2\2\u01f8\u01f9\7\3"+
"\2\2\u01f9\u01fe\5,\27\2\u01fa\u01fb\7\5\2\2\u01fb\u01fd\5,\27\2\u01fc"+
"\2\2\u01f9\u01fe\5<\37\2\u01fa\u01fb\7\5\2\2\u01fb\u01fd\5<\37\2\u01fc"+
"\u01fa\3\2\2\2\u01fd\u0200\3\2\2\2\u01fe\u01fc\3\2\2\2\u01fe\u01ff\3\2"+
"\2\2\u01ff\u0201\3\2\2\2\u0200\u01fe\3\2\2\2\u0201\u0202\7\4\2\2\u0202"+
"\u021b\3\2\2\2\u0203\u0205\7=\2\2\u0204\u0203\3\2\2\2\u0204\u0205\3\2"+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,18 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BackQuotedIdentifierContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.PrimaryExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryPrimaryDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryTermContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QuoteIdentifierContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.UnquoteIdentifierContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionDefaultContext;
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.sql.proto.SqlTypedParamValue;

Expand Down Expand Up @@ -214,10 +226,26 @@ public void exitNonReserved(SqlBaseParser.NonReservedContext context) {
/**
* Used to catch large expressions that can lead to stack overflows
*/
private class CircuitBreakerListener extends SqlBaseBaseListener {
static class CircuitBreakerListener extends SqlBaseBaseListener {

private static final short MAX_RULE_DEPTH = 200;

/**
* Due to the structure of the grammar and our custom handling in {@link ExpressionBuilder}
* some expressions can exit with a different class than they entered:
* e.g.: ValueExpressionContext can exit as ValueExpressionDefaultContext
*/
private static final Map<String, String> ENTER_EXIT_RULE_MAPPING = new HashMap<>();

static {
ENTER_EXIT_RULE_MAPPING.put(StatementDefaultContext.class.getSimpleName(), StatementContext.class.getSimpleName());
ENTER_EXIT_RULE_MAPPING.put(QueryPrimaryDefaultContext.class.getSimpleName(), QueryTermContext.class.getSimpleName());
ENTER_EXIT_RULE_MAPPING.put(BooleanDefaultContext.class.getSimpleName(), BooleanExpressionContext.class.getSimpleName());
ENTER_EXIT_RULE_MAPPING.put(ValueExpressionDefaultContext.class.getSimpleName(), ValueExpressionContext.class.getSimpleName());
}

private boolean insideIn = false;

// Keep current depth for every rule visited.
// The totalDepth alone cannot be used as expressions like: e1 OR e2 OR e3 OR ...
// are processed as e1 OR (e2 OR (e3 OR (... and this results in the totalDepth not growing
Expand All @@ -226,9 +254,18 @@ private class CircuitBreakerListener extends SqlBaseBaseListener {

@Override
public void enterEveryRule(ParserRuleContext ctx) {
if (ctx.getClass() != SqlBaseParser.UnquoteIdentifierContext.class &&
ctx.getClass() != SqlBaseParser.QuoteIdentifierContext.class &&
ctx.getClass() != SqlBaseParser.BackQuotedIdentifierContext.class) {
if (inDetected(ctx)) {
insideIn = true;
}

// Skip PrimaryExpressionContext for IN as it's not visited on exit due to
// the grammar's peculiarity rule with "predicated" and "predicate".
// Also skip the Identifiers as they are "cheap".
if (ctx.getClass() != UnquoteIdentifierContext.class &&
ctx.getClass() != QuoteIdentifierContext.class &&
ctx.getClass() != BackQuotedIdentifierContext.class &&
(insideIn == false || ctx.getClass() != PrimaryExpressionContext.class)) {

int currentDepth = depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 1, (short) 1);
if (currentDepth > MAX_RULE_DEPTH) {
throw new ParsingException(source(ctx), "SQL statement too large; " +
Expand All @@ -240,12 +277,35 @@ public void enterEveryRule(ParserRuleContext ctx) {

@Override
public void exitEveryRule(ParserRuleContext ctx) {
// Avoid having negative numbers
if (depthCounts.containsKey(ctx.getClass().getSimpleName())) {
depthCounts.putOrAdd(ctx.getClass().getSimpleName(), (short) 0, (short) -1);
if (inDetected(ctx)) {
insideIn = false;
}

decrementCounter(ctx);
super.exitEveryRule(ctx);
}

ObjectShortHashMap<String> depthCounts() {
return depthCounts;
}

private void decrementCounter(ParserRuleContext ctx) {
String className = ctx.getClass().getSimpleName();
String classNameToDecrement = ENTER_EXIT_RULE_MAPPING.getOrDefault(className, className);

// Avoid having negative numbers
if (depthCounts.containsKey(classNameToDecrement)) {
depthCounts.putOrAdd(classNameToDecrement, (short) 0, (short) -1);
}
}

private boolean inDetected(ParserRuleContext ctx) {
if (ctx.getParent() != null && ctx.getParent().getClass() == SqlBaseParser.PredicateContext.class) {
SqlBaseParser.PredicateContext pc = (SqlBaseParser.PredicateContext) ctx.getParent();
return pc.kind != null && pc.kind.getType() == SqlBaseParser.IN;
}
return false;
}
}

private static final BaseErrorListener ERROR_LISTENER = new BaseErrorListener() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,19 @@
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MatchQueryPredicate;
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.MultiMatchQueryPredicate;
import org.elasticsearch.xpack.sql.expression.predicate.fulltext.StringQueryPredicate;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.BooleanExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryPrimaryDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.QueryTermContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.StatementDefaultContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionContext;
import org.elasticsearch.xpack.sql.parser.SqlBaseParser.ValueExpressionDefaultContext;
import org.elasticsearch.xpack.sql.plan.logical.Filter;
import org.elasticsearch.xpack.sql.plan.logical.LogicalPlan;
import org.elasticsearch.xpack.sql.plan.logical.OrderBy;
import org.elasticsearch.xpack.sql.plan.logical.Project;
import org.elasticsearch.xpack.sql.plan.logical.With;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -28,6 +37,7 @@
import static org.hamcrest.Matchers.hasEntry;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;

public class SqlParserTests extends ESTestCase {

Expand Down Expand Up @@ -254,6 +264,51 @@ public void testLimitToPreventStackOverflowFromLargeComplexSubselectTree() {
e.getMessage());
}

public void testLimitStackOverflowForInAndLiteralsIsNotApplied() {
int noChildren = 100_000;
LogicalPlan plan = parseStatement("SELECT * FROM t WHERE a IN(" +
Joiner.on(",").join(nCopies(noChildren, "a + b")) + ")");

assertEquals(With.class, plan.getClass());
assertEquals(Project.class, ((With) plan).child().getClass());
assertEquals(Filter.class, ((Project) ((With) plan).child()).child().getClass());
Filter filter = (Filter) ((Project) ((With) plan).child()).child();
assertEquals(In.class, filter.condition().getClass());
In in = (In) filter.condition();
assertEquals("?a", in.value().toString());
assertEquals(noChildren, in.list().size());
assertThat(in.list().get(0).toString(), startsWith("(a) + (b)#"));
}

public void testDecrementOfDepthCounter() {
SqlParser.CircuitBreakerListener cbl = new SqlParser.CircuitBreakerListener();
StatementContext sc = new StatementContext();
QueryTermContext qtc = new QueryTermContext();
ValueExpressionContext vec = new ValueExpressionContext();
BooleanExpressionContext bec = new BooleanExpressionContext();

cbl.enterEveryRule(sc);
cbl.enterEveryRule(sc);
cbl.enterEveryRule(qtc);
cbl.enterEveryRule(qtc);
cbl.enterEveryRule(qtc);
cbl.enterEveryRule(vec);
cbl.enterEveryRule(bec);
cbl.enterEveryRule(bec);

cbl.exitEveryRule(new StatementDefaultContext(sc));
cbl.exitEveryRule(new StatementDefaultContext(sc));
cbl.exitEveryRule(new QueryPrimaryDefaultContext(qtc));
cbl.exitEveryRule(new QueryPrimaryDefaultContext(qtc));
cbl.exitEveryRule(new ValueExpressionDefaultContext(vec));
cbl.exitEveryRule(new SqlBaseParser.BooleanDefaultContext(bec));

assertEquals(0, cbl.depthCounts().get(SqlBaseParser.StatementContext.class.getSimpleName()));
assertEquals(1, cbl.depthCounts().get(SqlBaseParser.QueryTermContext.class.getSimpleName()));
assertEquals(0, cbl.depthCounts().get(SqlBaseParser.ValueExpressionContext.class.getSimpleName()));
assertEquals(1, cbl.depthCounts().get(SqlBaseParser.BooleanExpressionContext.class.getSimpleName()));
}

private LogicalPlan parseStatement(String sql) {
return new SqlParser().createStatement(sql);
}
Expand Down

0 comments on commit 7d2e2cf

Please sign in to comment.