Skip to content

Commit

Permalink
SQL: "Polish" painless script generated from IN
Browse files Browse the repository at this point in the history
Replace standard `||` and `==` painless operators with
`or` and `eq` null safe alternatives from `InternalSqlScriptUtils`.

Follow up to elastic#34750
  • Loading branch information
matriv committed Oct 30, 2018
1 parent 7ef65de commit 782b6fe
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 48 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.string.ReplaceFunctionProcessor;
import org.elasticsearch.xpack.sql.expression.function.scalar.string.StringProcessor.StringOperation;
import org.elasticsearch.xpack.sql.expression.function.scalar.string.SubstringFunctionProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.In;
import org.elasticsearch.xpack.sql.expression.predicate.IsNotNullProcessor;
import org.elasticsearch.xpack.sql.expression.predicate.logical.BinaryLogicProcessor.BinaryLogicOperation;
import org.elasticsearch.xpack.sql.expression.predicate.logical.NotProcessor;
Expand All @@ -31,6 +32,7 @@
import org.elasticsearch.xpack.sql.util.StringUtils;

import java.time.ZonedDateTime;
import java.util.List;
import java.util.Map;

/**
Expand Down Expand Up @@ -113,6 +115,10 @@ public static Boolean notNull(Object expression) {
return IsNotNullProcessor.apply(expression);
}

public static Boolean in(Object value, List<Object> values) {
return In.doFold(value, values);
}

//
// Regex
//
Expand Down Expand Up @@ -375,4 +381,4 @@ public static String substring(String s, Number start, Number length) {
public static String ucase(String s) {
return (String) StringOperation.UCASE.apply(s);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,4 @@ public static ScriptTemplate binaryMethod(String methodName, ScriptTemplate left
.build(),
dataType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@
import org.elasticsearch.xpack.sql.expression.Attribute;
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Expressions;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.expression.NamedExpression;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe;
import org.elasticsearch.xpack.sql.expression.gen.script.Params;
import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons;
Expand All @@ -30,7 +29,6 @@
import java.util.StringJoiner;
import java.util.stream.Collectors;

import static java.lang.String.format;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;

public class In extends NamedExpression implements ScriptWeaver {
Expand Down Expand Up @@ -84,17 +82,21 @@ public boolean foldable() {

@Override
public Boolean fold() {
// Optimization for early return and Query folding to LocalExec
if (value.dataType() == DataType.NULL) {
return null;
}
if (list.size() == 1 && list.get(0).dataType() == DataType.NULL) {
return false;
return null;
}

Object foldedLeftValue = value.fold();
return doFold(value.fold(), Foldables.valuesOf(list, value.dataType()));
}

public static Boolean doFold(Object value, List<Object> values) {
Boolean result = false;
for (Expression rightValue : list) {
Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold());
for (Object v : values) {
Boolean compResult = Comparisons.eq(value, v);
if (compResult == null) {
result = null;
} else if (compResult) {
Expand Down Expand Up @@ -122,34 +124,18 @@ public Attribute toAttribute() {

@Override
public ScriptTemplate asScript() {
StringJoiner sj = new StringJoiner(" || ");
ScriptTemplate leftScript = asScript(value);
List<Params> rightParams = new ArrayList<>();
String scriptPrefix = leftScript + "==";
LinkedHashSet<Object> values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new));
for (Object valueFromList : values) {
// if checked against null => false
if (valueFromList != null) {
if (valueFromList instanceof Expression) {
ScriptTemplate rightScript = asScript((Expression) valueFromList);
sj.add(scriptPrefix + rightScript.template());
rightParams.add(rightScript.params());
} else {
if (valueFromList instanceof String) {
sj.add(scriptPrefix + '"' + valueFromList + '"');
} else {
sj.add(scriptPrefix + valueFromList.toString());
}
}
}
}

ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params());
for (Params p : rightParams) {
paramsBuilder = paramsBuilder.script(p);
}

return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType());
// remove duplicates
// TODO: Don't exclude nulls, painless script should handle them
List<Object> values = new ArrayList<>(
list.stream().map(Expression::fold).filter(Objects::nonNull).collect(Collectors.toCollection(LinkedHashSet::new)));

return new ScriptTemplate(String.format(Locale.ROOT, formatTemplate("{sql}.in(%s, {})"), leftScript.template(), "%s"),
paramsBuilder()
.script(leftScript.params())
.variable(values)
.build(),
dataType());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -682,8 +682,7 @@ protected LogicalPlan rule(Filter filter) {
if (TRUE.equals(filter.condition())) {
return filter.child();
}
// TODO: add comparison with null as well
if (FALSE.equals(filter.condition())) {
if (FALSE.equals(filter.condition()) || ((Literal) filter.condition()).value() == null) {
return new LocalRelation(filter.location(), new EmptyExecutable(filter.output()));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import org.elasticsearch.xpack.sql.expression.Expression;
import org.elasticsearch.xpack.sql.expression.Foldables;
import org.elasticsearch.xpack.sql.tree.Location;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypes;

import java.util.Collections;
import java.util.LinkedHashSet;
Expand All @@ -27,7 +27,7 @@ public class TermsQuery extends LeafQuery {
public TermsQuery(Location location, String term, List<Expression> values) {
super(location);
this.term = term;
values.removeIf(e -> e.dataType() == DataType.NULL);
values.removeIf(e -> DataTypes.isNull(e.dataType()));
if (values.isEmpty()) {
this.values = Collections.emptySet();
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ public static Conversion conversionFor(DataType from, DataType to) {
if (to == DataType.NULL) {
return Conversion.NULL;
}
if (from == DataType.NULL) {
return Conversion.NULL;
}

Conversion conversion = conversion(from, to);
if (conversion == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
Boolean lte(Object, Object)
Boolean gt(Object, Object)
Boolean gte(Object, Object)
Boolean in(Object, java.util.List)

#
# Logical
Expand Down Expand Up @@ -107,4 +108,4 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS
String space(Number)
String substring(String, Number, Number)
String ucase(String)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ public void testConstantFoldingIn_LeftValueNotFoldable() {
public void testConstantFoldingIn_RightValueIsNull() {
In in = new In(EMPTY, getFieldAttribute(), Arrays.asList(NULL, NULL));
Literal result= (Literal) new ConstantFolding().rule(in);
assertEquals(false, result.value());
assertNull(result.value());
}

public void testConstantFoldingIn_LeftValueIsNull() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import java.util.Map;
import java.util.TimeZone;

import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.core.StringStartsWith.startsWith;

public class QueryTranslatorTests extends AbstractBuilderTestCase {
Expand Down Expand Up @@ -161,7 +162,7 @@ public void testLikeConstructsNotSupported() {
}

public void testTranslateInExpression_WhereClause() throws IOException {
LogicalPlan p = plan("SELECT * FROM test WHERE keyword IN ('foo', 'bar', 'lala', 'foo', concat('la', 'la'))");
LogicalPlan p = plan("SELECT * FROM test WHERE keyword IN ('foo', 'bar', 'l''ala', 'foo', concat('l''a', 'la'))");
assertTrue(p instanceof Project);
assertTrue(p.children().get(0) instanceof Filter);
Expression condition = ((Filter) p.children().get(0)).condition();
Expand All @@ -170,7 +171,7 @@ public void testTranslateInExpression_WhereClause() throws IOException {
Query query = translation.query;
assertTrue(query instanceof TermsQuery);
TermsQuery tq = (TermsQuery) query;
assertEquals("keyword:(bar foo lala)", tq.asBuilder().toQuery(createShardContext()).toString());
assertEquals("keyword:(bar foo l'ala)", tq.asBuilder().toQuery(createShardContext()).toString());
}

public void testTranslateInExpression_WhereClauseAndNullHandling() throws IOException {
Expand Down Expand Up @@ -206,20 +207,39 @@ public void testTranslateInExpression_HavingClause_Painless() {
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sq = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)", sq.script().toString());
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
sq.script().toString());
assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(sq.script().params().toString(), endsWith(", {v=[10, 20]}]"));
}

public void testTranslateInExpression_HavingClauseAndNullHandling_Painless() {
LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, null, 20, null, 30 - 10)");
public void testTranslateInExpression_HavingClause_PainlessOneArg() {
LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, 30 - 20)");
assertTrue(p instanceof Project);
assertTrue(p.children().get(0) instanceof Filter);
Expression condition = ((Filter) p.children().get(0)).condition();
assertFalse(condition.foldable());
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sq = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)", sq.script().toString());
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", sq.script().toString());
assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(sq.script().params().toString(), endsWith(", {v=[10]}]"));

}

public void testTranslateInExpression_HavingClause_PainlessAndNullHandling() {
LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, null, 20, 30, null, 30 - 10)");
assertTrue(p instanceof Project);
assertTrue(p.children().get(0) instanceof Filter);
Expression condition = ((Filter) p.children().get(0)).condition();
assertFalse(condition.foldable());
QueryTranslation translation = QueryTranslator.toQuery(condition, false);
assertTrue(translation.query instanceof ScriptQuery);
ScriptQuery sq = (ScriptQuery) translation.query;
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))",
sq.script().toString());
assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->"));
assertThat(sq.script().params().toString(), endsWith(", {v=[10, 20, 30]}]"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,12 @@ public void testConversionToNull() {
assertNull(conversion.convert(10.0));
}

public void testConversionFromNull() {
Conversion conversion = DataTypeConversion.conversionFor(DataType.NULL, DataType.INTEGER);
assertNull(conversion.convert(null));
assertNull(conversion.convert(10));
}

public void testConversionToIdentity() {
Conversion conversion = DataTypeConversion.conversionFor(DataType.INTEGER, DataType.INTEGER);
assertNull(conversion.convert(null));
Expand Down

0 comments on commit 782b6fe

Please sign in to comment.