Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL: Resolve attributes recursively for improved subquery support #69765

Merged
merged 9 commits into from
Mar 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
*/
package org.elasticsearch.xpack.ql.expression;

import org.elasticsearch.xpack.ql.QlIllegalArgumentException;

import java.util.AbstractSet;
import java.util.Collection;
import java.util.Iterator;
Expand Down Expand Up @@ -148,6 +150,8 @@ public static final <E> AttributeMap<E> emptyAttributeMap() {
return EMPTY;
}

private static final Object NOT_FOUND = new Object();

private final Map<AttributeWrapper, E> delegate;
private Set<Attribute> keySet = null;
private Collection<E> values = null;
Expand Down Expand Up @@ -238,14 +242,14 @@ public boolean isEmpty() {
@Override
public boolean containsKey(Object key) {
if (key instanceof NamedExpression) {
return delegate.keySet().contains(new AttributeWrapper(((NamedExpression) key).toAttribute()));
return delegate.containsKey(new AttributeWrapper(((NamedExpression) key).toAttribute()));
}
return false;
}

@Override
public boolean containsValue(Object value) {
return delegate.values().contains(value);
return delegate.containsValue(value);
}

@Override
Expand All @@ -258,10 +262,32 @@ public E get(Object key) {

@Override
public E getOrDefault(Object key, E defaultValue) {
E e;
return (((e = get(key)) != null) || containsKey(key))
? e
: defaultValue;
if (key instanceof NamedExpression) {
return delegate.getOrDefault(new AttributeWrapper(((NamedExpression) key).toAttribute()), defaultValue);
}
return defaultValue;
}

public E resolve(Object key) {
return resolve(key, null);
}

public E resolve(Object key, E defaultValue) {
E value = defaultValue;
E candidate = null;
int allowedLookups = 1000;
while ((candidate = get(key)) != null || containsKey(key)) {
// instead of circling around, return
if (candidate == key) {
return candidate;
}
if (--allowedLookups == 0) {
throw new QlIllegalArgumentException("Potential cycle detected");
}
key = candidate;
value = candidate;
}
return value;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.ql.expression;

import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.ql.QlIllegalArgumentException;
import org.elasticsearch.xpack.ql.tree.Source;
import org.elasticsearch.xpack.ql.type.DataTypes;

Expand All @@ -17,6 +18,8 @@
import java.util.Set;

import static java.util.stream.Collectors.toList;
import static org.elasticsearch.xpack.ql.TestUtils.fieldAttribute;
import static org.elasticsearch.xpack.ql.TestUtils.of;
import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.arrayWithSize;
import static org.hamcrest.Matchers.contains;
Expand Down Expand Up @@ -61,6 +64,64 @@ public void testAttributeMapWithSameAliasesCanResolveAttributes() {
assertTrue(newAttributeMap.get(param2.toAttribute()) == param2.child());
}

public void testResolve() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd like to see more tests for dealing with:
a. same key, value association - (e,e)
b. handling of default value across association: resolveOrDefault(k, default) for a map with (k, default), (default, value) and also resolveOrDefault(k, k)
c. handling of cycles: (a, b), (b, c), (c,a ) with a mixture of default value in place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

✓ Done.

AttributeMap.Builder<Object> builder = AttributeMap.builder();
Attribute one = a("one");
Attribute two = fieldAttribute("two", DataTypes.INTEGER);
Attribute three = fieldAttribute("three", DataTypes.INTEGER);
Alias threeAlias = new Alias(Source.EMPTY, "three_alias", three);
Alias threeAliasAlias = new Alias(Source.EMPTY, "three_alias_alias", threeAlias);
builder.put(one, of("one"));
builder.put(two, "two");
builder.put(three, of("three"));
builder.put(threeAlias.toAttribute(), threeAlias.child());
builder.put(threeAliasAlias.toAttribute(), threeAliasAlias.child());
AttributeMap<Object> map = builder.build();

assertEquals(of("one"), map.resolve(one));
assertEquals("two", map.resolve(two));
assertEquals(of("three"), map.resolve(three));
assertEquals(of("three"), map.resolve(threeAlias));
assertEquals(of("three"), map.resolve(threeAliasAlias));
assertEquals(of("three"), map.resolve(threeAliasAlias, threeAlias));
Attribute four = a("four");
assertEquals("not found", map.resolve(four, "not found"));
assertNull(map.resolve(four));
assertEquals(four, map.resolve(four, four));
}

public void testResolveOneHopCycle() {
AttributeMap.Builder<Object> builder = AttributeMap.builder();
Attribute a = fieldAttribute("a", DataTypes.INTEGER);
Attribute b = fieldAttribute("b", DataTypes.INTEGER);
builder.put(a, a);
builder.put(b, a);
AttributeMap<Object> map = builder.build();

assertEquals(a, map.resolve(a, "default"));
assertEquals(a, map.resolve(b, "default"));
assertEquals("default", map.resolve("non-existing-key", "default"));
}

public void testResolveMultiHopCycle() {
AttributeMap.Builder<Object> builder = AttributeMap.builder();
Attribute a = fieldAttribute("a", DataTypes.INTEGER);
Attribute b = fieldAttribute("b", DataTypes.INTEGER);
Attribute c = fieldAttribute("c", DataTypes.INTEGER);
Attribute d = fieldAttribute("d", DataTypes.INTEGER);
builder.put(a, b);
builder.put(b, c);
builder.put(c, d);
builder.put(d, a);
AttributeMap<Object> map = builder.build();

// note: multi hop cycles should not happen, unless we have a
// bug in the code that populates the AttributeMaps
expectThrows(QlIllegalArgumentException.class, () -> {
assertEquals(a, map.resolve(a, c));
});
}

private Alias createIntParameterAlias(int index, int value) {
Source source = new Source(1, index * 5, "?");
Literal literal = new Literal(source, value, DataTypes.INTEGER);
Expand Down
28 changes: 26 additions & 2 deletions x-pack/plugin/sql/qa/server/src/main/resources/subselect.sql-spec
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,34 @@ basicGroupBy
SELECT gender FROM (SELECT first_name AS f, last_name, gender FROM test_emp) GROUP BY gender ORDER BY gender ASC;
basicGroupByAlias
SELECT g FROM (SELECT first_name AS f, last_name, gender AS g FROM test_emp) GROUP BY g ORDER BY g ASC;
// @AwaitsFix(bugUrl = "follow-up to https://github.com/elastic/elasticsearch/pull/67216")
basicGroupByWithFilterByAlias-Ignore
basicGroupByWithFilterByAlias
SELECT g FROM (SELECT first_name AS f, last_name, gender AS g FROM test_emp) WHERE g IS NOT NULL GROUP BY g ORDER BY g ASC;
basicGroupByRealiased
SELECT g AS h FROM (SELECT first_name AS f, last_name, gender AS g FROM test_emp) GROUP BY g ORDER BY g DESC NULLS last;
basicGroupByRealiasedTwice
SELECT g AS h FROM (SELECT first_name AS f, last_name, gender AS g FROM test_emp) GROUP BY h ORDER BY h DESC NULLS last;
basicOrderByRealiasedField
SELECT g AS h FROM (SELECT first_name AS f, last_name, gender AS g FROM test_emp) ORDER BY g DESC NULLS first;

groupAndOrderByRealiasedExpression
SELECT emp_group AS e, min_high_salary AS s
FROM (
SELECT emp_no % 2 AS emp_group, MIN(salary) AS min_high_salary
FROM test_emp
WHERE salary > 50000
GROUP BY emp_group
)
ORDER BY e DESC;

multiLevelSelectStar
SELECT * FROM (SELECT * FROM ( SELECT * FROM test_emp ));

multiLevelSelectStarWithAlias
SELECT * FROM (SELECT * FROM ( SELECT * FROM test_emp ) b) c;

// AwaitsFix: https://github.com/elastic/elasticsearch/issues/69758
filterAfterGroupBy-Ignore
SELECT s2 AS s3 FROM (SELECT s AS s2 FROM ( SELECT salary AS s FROM test_emp) GROUP BY s2) WHERE s2 < 5 ORDER BY s3 DESC NULLS last;

countAndComplexCondition
SELECT COUNT(*) as c FROM (SELECT * FROM test_emp WHERE gender IS NOT NULL) WHERE ABS(salary) > 0 GROUP BY gender ORDER BY gender;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,12 +638,12 @@ protected LogicalPlan doRule(LogicalPlan plan) {
AttributeMap.Builder<Expression> builder = AttributeMap.builder();
// collect aliases
child.forEachUp(p -> p.forEachExpressionUp(Alias.class, a -> builder.put(a.toAttribute(), a.child())));
final Map<Attribute, Expression> collectRefs = builder.build();
final AttributeMap<Expression> collectRefs = builder.build();

referencesStream = referencesStream.filter(r -> {
for (Attribute attr : child.outputSet()) {
if (attr instanceof ReferenceAttribute) {
Expression source = collectRefs.getOrDefault(attr, attr);
Expression source = collectRefs.resolve(attr, attr);
// found a match, no need to resolve it further
// so filter it out
if (source.equals(r.child())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,10 +316,11 @@ private static boolean checkGroupByOrder(LogicalPlan p, Set<Failure> localFailur
Map<Expression, Node<?>> missing = new LinkedHashMap<>();

o.order().forEach(oe -> {
Expression e = oe.child();
final Expression e = oe.child();
final Expression resolvedE = attributeRefs.resolve(e, e);

// aggregates are allowed
if (Functions.isAggregate(attributeRefs.getOrDefault(e, e))) {
if (Functions.isAggregate(resolvedE)) {
return;
}

Expand All @@ -340,8 +341,12 @@ private static boolean checkGroupByOrder(LogicalPlan p, Set<Failure> localFailur
// e.g.: if "GROUP BY f2(f1(field))" you can "ORDER BY f4(f3(f2(f1(field))))"
//
// Also, make sure to compare attributes directly
if (e.anyMatch(expression -> Expressions.anyMatch(groupingAndMatchingAggregatesAliases,
g -> expression.semanticEquals(expression instanceof Attribute ? Expressions.attribute(g) : g)))) {
if (resolvedE.anyMatch(expression -> Expressions.anyMatch(groupingAndMatchingAggregatesAliases,
g -> {
Expression resolvedG = attributeRefs.resolve(g, g);
resolvedG = expression instanceof Attribute ? Expressions.attribute(resolvedG) : resolvedG;
return expression.semanticEquals(resolvedG);
}))) {
return;
}

Expand Down Expand Up @@ -406,7 +411,7 @@ private static boolean checkGroupByHavingHasOnlyAggs(Expression e, Set<Expressio

// resolve FunctionAttribute to backing functions
if (e instanceof ReferenceAttribute) {
e = attributeRefs.get(e);
e = attributeRefs.resolve(e);
}

// scalar functions can be a binary tree
Expand Down Expand Up @@ -484,7 +489,7 @@ private static boolean onlyRawFields(Iterable<? extends Expression> expressions,

expressions.forEach(e -> e.forEachDown(c -> {
if (c instanceof ReferenceAttribute) {
c = attributeRefs.getOrDefault(c, c);
c = attributeRefs.resolve(c, c);
}
if (c instanceof Function) {
localFailures.add(fail(c, "No functions allowed (yet); encountered [{}]", c.sourceText()));
Expand Down Expand Up @@ -579,7 +584,7 @@ private static boolean checkGroupMatch(Expression e, Node<?> source, List<Expres

// resolve FunctionAttribute to backing functions
if (e instanceof ReferenceAttribute) {
e = attributeRefs.get(e);
e = attributeRefs.resolve(e);
}

// scalar functions can be a binary tree
Expand Down Expand Up @@ -668,7 +673,7 @@ private static void checkFilterOnAggs(LogicalPlan p, Set<Failure> localFailures,
LogicalPlan filterChild = filter.child();
if (filterChild instanceof Aggregate == false) {
filter.condition().forEachDown(Expression.class, e -> {
if (Functions.isAggregate(attributeRefs.getOrDefault(e, e))) {
if (Functions.isAggregate(attributeRefs.resolve(e, e))) {
if (filterChild instanceof Project) {
filter.condition().forEachDown(FieldAttribute.class,
f -> localFailures.add(fail(e, "[{}] field must appear in the GROUP BY clause or in an aggregate function",
Expand All @@ -690,7 +695,7 @@ private static void checkFilterOnGrouping(LogicalPlan p, Set<Failure> localFailu
if (p instanceof Filter) {
Filter filter = (Filter) p;
filter.condition().forEachDown(Expression.class, e -> {
if (Functions.isGrouping(attributeRefs.getOrDefault(e, e))) {
if (Functions.isGrouping(attributeRefs.resolve(e, e))) {
localFailures
.add(fail(e, "Cannot filter on grouping function [{}], use its argument instead", Expressions.name(e)));
}
Expand All @@ -717,7 +722,7 @@ private static void checkNestedUsedInGroupByOrHavingOrWhereOrOrderBy(LogicalPlan
}
};
Consumer<Expression> checkForNested = e ->
attributeRefs.getOrDefault(e, e).forEachUp(FieldAttribute.class, matchNested);
attributeRefs.resolve(e, e).forEachUp(FieldAttribute.class, matchNested);
Consumer<ScalarFunction> checkForNestedInFunction = f -> f.arguments().forEach(
arg -> arg.forEachUp(FieldAttribute.class, matchNested));

Expand All @@ -739,7 +744,7 @@ private static void checkNestedUsedInGroupByOrHavingOrWhereOrOrderBy(LogicalPlan

// check in where (scalars not allowed)
p.forEachDown(Filter.class, f -> f.condition().forEachUp(e ->
attributeRefs.getOrDefault(e, e).forEachUp(ScalarFunction.class, sf -> {
attributeRefs.resolve(e, e).forEachUp(ScalarFunction.class, sf -> {
if (sf instanceof BinaryComparison == false &&
sf instanceof IsNull == false &&
sf instanceof IsNotNull == false &&
Expand All @@ -758,7 +763,7 @@ private static void checkNestedUsedInGroupByOrHavingOrWhereOrOrderBy(LogicalPlan

// check in order by (scalars not allowed)
p.forEachDown(OrderBy.class, ob -> ob.order().forEach(o -> o.forEachUp(e ->
attributeRefs.getOrDefault(e, e).forEachUp(ScalarFunction.class, checkForNestedInFunction)
attributeRefs.resolve(e, e).forEachUp(ScalarFunction.class, checkForNestedInFunction)
)));
if (nested.isEmpty() == false) {
localFailures.add(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@ public LogicalPlan apply(LogicalPlan plan) {
AttributeMap.Builder<Expression> builder = AttributeMap.builder();
// collect aliases
plan.forEachExpressionUp(Alias.class, a -> builder.put(a.toAttribute(), a.child()));
final Map<Attribute, Expression> collectRefs = builder.build();
java.util.function.Function<ReferenceAttribute, Expression> replaceReference = r -> collectRefs.getOrDefault(r, r);
final AttributeMap<Expression> collectRefs = builder.build();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noise.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

java.util.function.Function<ReferenceAttribute, Expression> replaceReference = r -> collectRefs.resolve(r, r);

plan = plan.transformUp(p -> {
// non attribute defining plans get their references removed
Expand Down Expand Up @@ -300,7 +300,7 @@ static class PruneOrderByNestedFields extends OptimizerRule<Project> {
private void findNested(Expression exp, AttributeMap<Function> functions, Consumer<FieldAttribute> onFind) {
exp.forEachUp(e -> {
if (e instanceof ReferenceAttribute) {
Function f = functions.get(e);
Function f = functions.resolve(e);
if (f != null) {
findNested(f, functions, onFind);
}
Expand Down Expand Up @@ -578,7 +578,7 @@ private List<NamedExpression> combineProjections(List<? extends NamedExpression>
// replace any matching attribute with a lower alias (if there's a match)
// but clean-up non-top aliases at the end
for (NamedExpression ne : upper) {
NamedExpression replacedExp = (NamedExpression) ne.transformUp(Attribute.class, a -> aliases.getOrDefault(a, a));
NamedExpression replacedExp = (NamedExpression) ne.transformUp(Attribute.class, a -> aliases.resolve(a, a));
replaced.add((NamedExpression) CleanAliases.trimNonTopLevelAliases(replacedExp));
}
return replaced;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,10 +600,13 @@ else if (target.foldable()) {
else {
GroupByKey matchingGroup = null;
if (groupingContext != null) {
target = queryC.aliases().resolve(target, target);
id = Expressions.id(target);
matchingGroup = groupingContext.groupFor(target);
Check.notNull(matchingGroup, "Cannot find group [{}]", Expressions.name(ne));

queryC = queryC.addColumn(new GroupByRef(matchingGroup.id(), null, isDateBased(ne.dataType())), id);
queryC = queryC.addColumn(
new GroupByRef(matchingGroup.id(), null, isDateBased(ne.dataType())), id);
}
// fallback
else {
Expand Down Expand Up @@ -707,7 +710,7 @@ protected PhysicalPlan rule(OrderExec plan) {

// if it's a reference, get the target expression
if (orderExpression instanceof ReferenceAttribute) {
orderExpression = qContainer.aliases().get(orderExpression);
orderExpression = qContainer.aliases().resolve(orderExpression);
}
String lookup = Expressions.id(orderExpression);
GroupByKey group = qContainer.findGroupForAgg(lookup);
Expand Down
Loading