-
Notifications
You must be signed in to change notification settings - Fork 24.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SQL: Improve painless script generated from IN
#35055
Changes from 7 commits
782b6fe
a0cd723
6bd1dc5
fb61b36
8d8e0cb
c09027d
9f0b212
fb97b51
dbd9b62
8beefbc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,4 +87,4 @@ public static ScriptTemplate binaryMethod(String methodName, ScriptTemplate left | |
.build(), | ||
dataType); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,20 +3,17 @@ | |
* or more contributor license agreements. Licensed under the Elastic License; | ||
* you may not use this file except in compliance with the Elastic License. | ||
*/ | ||
package org.elasticsearch.xpack.sql.expression.predicate; | ||
package org.elasticsearch.xpack.sql.expression.predicate.operator.comparison; | ||
|
||
import org.elasticsearch.xpack.sql.expression.Attribute; | ||
import org.elasticsearch.xpack.sql.expression.Expression; | ||
import org.elasticsearch.xpack.sql.expression.Expressions; | ||
import org.elasticsearch.xpack.sql.expression.Foldables; | ||
import org.elasticsearch.xpack.sql.expression.NamedExpression; | ||
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute; | ||
import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; | ||
import org.elasticsearch.xpack.sql.expression.gen.script.Params; | ||
import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder; | ||
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; | ||
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver; | ||
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons; | ||
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.InPipe; | ||
import org.elasticsearch.xpack.sql.tree.Location; | ||
import org.elasticsearch.xpack.sql.tree.NodeInfo; | ||
import org.elasticsearch.xpack.sql.type.DataType; | ||
|
@@ -30,7 +27,6 @@ | |
import java.util.StringJoiner; | ||
import java.util.stream.Collectors; | ||
|
||
import static java.lang.String.format; | ||
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder; | ||
|
||
public class In extends NamedExpression implements ScriptWeaver { | ||
|
@@ -84,17 +80,21 @@ public boolean foldable() { | |
|
||
@Override | ||
public Boolean fold() { | ||
// Optimization for early return and Query folding to LocalExec | ||
if (value.dataType() == DataType.NULL) { | ||
return null; | ||
} | ||
if (list.size() == 1 && list.get(0).dataType() == DataType.NULL) { | ||
return false; | ||
return null; | ||
} | ||
|
||
Object foldedLeftValue = value.fold(); | ||
return doFold(value.fold(), Foldables.valuesOf(list, value.dataType())); | ||
} | ||
|
||
public static Boolean doFold(Object value, List<Object> values) { | ||
Boolean result = false; | ||
for (Expression rightValue : list) { | ||
Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold()); | ||
for (Object v : values) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this method to the processor please. |
||
Boolean compResult = Comparisons.eq(value, v); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's be explicit about boxing. |
||
if (compResult == null) { | ||
result = null; | ||
} else if (compResult) { | ||
|
@@ -122,34 +122,18 @@ public Attribute toAttribute() { | |
|
||
@Override | ||
public ScriptTemplate asScript() { | ||
StringJoiner sj = new StringJoiner(" || "); | ||
ScriptTemplate leftScript = asScript(value); | ||
List<Params> rightParams = new ArrayList<>(); | ||
String scriptPrefix = leftScript + "=="; | ||
LinkedHashSet<Object> values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new)); | ||
for (Object valueFromList : values) { | ||
// if checked against null => false | ||
if (valueFromList != null) { | ||
if (valueFromList instanceof Expression) { | ||
ScriptTemplate rightScript = asScript((Expression) valueFromList); | ||
sj.add(scriptPrefix + rightScript.template()); | ||
rightParams.add(rightScript.params()); | ||
} else { | ||
if (valueFromList instanceof String) { | ||
sj.add(scriptPrefix + '"' + valueFromList + '"'); | ||
} else { | ||
sj.add(scriptPrefix + valueFromList.toString()); | ||
} | ||
} | ||
} | ||
} | ||
|
||
ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params()); | ||
for (Params p : rightParams) { | ||
paramsBuilder = paramsBuilder.script(p); | ||
} | ||
|
||
return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType()); | ||
// remove duplicates | ||
List<Object> values = new ArrayList<>( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new))); | ||
|
||
return new ScriptTemplate( | ||
formatTemplate(String.format(Locale.ROOT, "{sql}.in(%s, {})", leftScript.template())), | ||
paramsBuilder() | ||
.script(leftScript.params()) | ||
.variable(values) | ||
.build(), | ||
dataType()); | ||
} | ||
|
||
@Override | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
import java.util.Map; | ||
import java.util.TimeZone; | ||
|
||
import static org.hamcrest.Matchers.endsWith; | ||
import static org.hamcrest.core.StringStartsWith.startsWith; | ||
|
||
public class QueryTranslatorTests extends ESTestCase { | ||
|
@@ -208,12 +209,11 @@ public void testTranslateInExpression_WhereClause_Painless() { | |
QueryTranslation translation = QueryTranslator.toQuery(condition, false); | ||
assertNull(translation.aggFilter); | ||
assertTrue(translation.query instanceof ScriptQuery); | ||
ScriptQuery sq = (ScriptQuery) translation.query; | ||
assertEquals("InternalSqlScriptUtils.nullSafeFilter(" + | ||
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1)==10 || " + | ||
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1)==20)", | ||
sq.script().toString()); | ||
assertEquals("[{v=int}, {v=2}]", sq.script().params().toString()); | ||
ScriptQuery sc = (ScriptQuery) translation.query; | ||
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(" + | ||
"InternalSqlScriptUtils.power(InternalSqlScriptUtils.docValue(doc,params.v0),params.v1), params.v2))", | ||
sc.script().toString()); | ||
assertEquals("[{v=int}, {v=2}, {v=[10, null, 20]}]", sc.script().params().toString()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 👍 |
||
} | ||
|
||
public void testTranslateInExpression_HavingClause_Painless() { | ||
|
@@ -225,9 +225,10 @@ public void testTranslateInExpression_HavingClause_Painless() { | |
QueryTranslation translation = QueryTranslator.toQuery(condition, true); | ||
assertNull(translation.query); | ||
AggFilter aggFilter = translation.aggFilter; | ||
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)", | ||
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", | ||
aggFilter.scriptTemplate().toString()); | ||
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->")); | ||
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10, 20]}]")); | ||
} | ||
|
||
public void testTranslateInExpression_HavingClause_PainlessOneArg() { | ||
|
@@ -239,9 +240,10 @@ public void testTranslateInExpression_HavingClause_PainlessOneArg() { | |
QueryTranslation translation = QueryTranslator.toQuery(condition, true); | ||
assertNull(translation.query); | ||
AggFilter aggFilter = translation.aggFilter; | ||
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10)", | ||
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", | ||
aggFilter.scriptTemplate().toString()); | ||
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->")); | ||
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10]}]")); | ||
|
||
} | ||
|
||
|
@@ -254,8 +256,9 @@ public void testTranslateInExpression_HavingClause_PainlessAndNullHandling() { | |
QueryTranslation translation = QueryTranslator.toQuery(condition, true); | ||
assertNull(translation.query); | ||
AggFilter aggFilter = translation.aggFilter; | ||
assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20 || params.a0==30)", | ||
assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", | ||
aggFilter.scriptTemplate().toString()); | ||
assertThat(aggFilter.scriptTemplate().params().toString(), startsWith("[{a=MAX(int){a->")); | ||
assertThat(aggFilter.scriptTemplate().params().toString(), endsWith(", {v=[10, null, 20, 30]}]")); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The naming convention is to call the method on the processor (typically
InProcessor.apply
)