diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java index 9aabb3f10ecdc..a4ddd5f74ae8b 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/function/scalar/whitelist/InternalSqlScriptUtils.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.sql.expression.function.scalar.string.ReplaceFunctionProcessor; import org.elasticsearch.xpack.sql.expression.function.scalar.string.StringProcessor.StringOperation; import org.elasticsearch.xpack.sql.expression.function.scalar.string.SubstringFunctionProcessor; +import org.elasticsearch.xpack.sql.expression.predicate.In; import org.elasticsearch.xpack.sql.expression.predicate.IsNotNullProcessor; import org.elasticsearch.xpack.sql.expression.predicate.logical.BinaryLogicProcessor.BinaryLogicOperation; import org.elasticsearch.xpack.sql.expression.predicate.logical.NotProcessor; @@ -31,6 +32,7 @@ import org.elasticsearch.xpack.sql.util.StringUtils; import java.time.ZonedDateTime; +import java.util.List; import java.util.Map; /** @@ -113,6 +115,10 @@ public static Boolean notNull(Object expression) { return IsNotNullProcessor.apply(expression); } + public static Boolean in(Object value, List values) { + return In.doFold(value, values); + } + // // Regex // @@ -375,4 +381,4 @@ public static String substring(String s, Number start, Number length) { public static String ucase(String s) { return (String) StringOperation.UCASE.apply(s); } -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Scripts.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Scripts.java index f9e2588a9c035..21ac12e51da89 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Scripts.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/gen/script/Scripts.java @@ -87,4 +87,4 @@ public static ScriptTemplate binaryMethod(String methodName, ScriptTemplate left .build(), dataType); } -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java index 9b16b77511ca7..1f3d666ee6157 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/expression/predicate/In.java @@ -8,11 +8,10 @@ import org.elasticsearch.xpack.sql.expression.Attribute; import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Expressions; +import org.elasticsearch.xpack.sql.expression.Foldables; import org.elasticsearch.xpack.sql.expression.NamedExpression; import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunctionAttribute; import org.elasticsearch.xpack.sql.expression.gen.pipeline.Pipe; -import org.elasticsearch.xpack.sql.expression.gen.script.Params; -import org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate; import org.elasticsearch.xpack.sql.expression.gen.script.ScriptWeaver; import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.Comparisons; @@ -30,7 +29,6 @@ import java.util.StringJoiner; import java.util.stream.Collectors; -import static java.lang.String.format; import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder; public class In extends NamedExpression implements ScriptWeaver { @@ -84,17 +82,21 @@ public boolean foldable() { @Override public Boolean fold() { + // Optimization for early return and Query folding to LocalExec if (value.dataType() == DataType.NULL) { return null; } if (list.size() == 1 && list.get(0).dataType() == DataType.NULL) { - return false; + return null; } - Object foldedLeftValue = value.fold(); + return doFold(value.fold(), Foldables.valuesOf(list, value.dataType())); + } + + public static Boolean doFold(Object value, List values) { Boolean result = false; - for (Expression rightValue : list) { - Boolean compResult = Comparisons.eq(foldedLeftValue, rightValue.fold()); + for (Object v : values) { + Boolean compResult = Comparisons.eq(value, v); if (compResult == null) { result = null; } else if (compResult) { @@ -122,34 +124,18 @@ public Attribute toAttribute() { @Override public ScriptTemplate asScript() { - StringJoiner sj = new StringJoiner(" || "); ScriptTemplate leftScript = asScript(value); - List rightParams = new ArrayList<>(); - String scriptPrefix = leftScript + "=="; - LinkedHashSet values = list.stream().map(Expression::fold).collect(Collectors.toCollection(LinkedHashSet::new)); - for (Object valueFromList : values) { - // if checked against null => false - if (valueFromList != null) { - if (valueFromList instanceof Expression) { - ScriptTemplate rightScript = asScript((Expression) valueFromList); - sj.add(scriptPrefix + rightScript.template()); - rightParams.add(rightScript.params()); - } else { - if (valueFromList instanceof String) { - sj.add(scriptPrefix + '"' + valueFromList + '"'); - } else { - sj.add(scriptPrefix + valueFromList.toString()); - } - } - } - } - - ParamsBuilder paramsBuilder = paramsBuilder().script(leftScript.params()); - for (Params p : rightParams) { - paramsBuilder = paramsBuilder.script(p); - } - - return new ScriptTemplate(format(Locale.ROOT, "%s", sj.toString()), paramsBuilder.build(), dataType()); + // remove duplicates + // TODO: Don't exclude nulls, painless script should handle them + List values = new ArrayList<>( + list.stream().map(Expression::fold).filter(Objects::nonNull).collect(Collectors.toCollection(LinkedHashSet::new))); + + return new ScriptTemplate(String.format(Locale.ROOT, formatTemplate("{sql}.in(%s, {})"), leftScript.template(), "%s"), + paramsBuilder() + .script(leftScript.params()) + .variable(values) + .build(), + dataType()); } @Override diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java index 8443358a12cb2..2d18a687a301e 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/optimizer/Optimizer.java @@ -682,8 +682,7 @@ protected LogicalPlan rule(Filter filter) { if (TRUE.equals(filter.condition())) { return filter.child(); } - // TODO: add comparison with null as well - if (FALSE.equals(filter.condition())) { + if (FALSE.equals(filter.condition()) || ((Literal) filter.condition()).value() == null) { return new LocalRelation(filter.location(), new EmptyExecutable(filter.output())); } } diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/TermsQuery.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/TermsQuery.java index 91ea49a8a3ce3..66d206f829a32 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/TermsQuery.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/querydsl/query/TermsQuery.java @@ -9,7 +9,7 @@ import org.elasticsearch.xpack.sql.expression.Expression; import org.elasticsearch.xpack.sql.expression.Foldables; import org.elasticsearch.xpack.sql.tree.Location; -import org.elasticsearch.xpack.sql.type.DataType; +import org.elasticsearch.xpack.sql.type.DataTypes; import java.util.Collections; import java.util.LinkedHashSet; @@ -27,7 +27,7 @@ public class TermsQuery extends LeafQuery { public TermsQuery(Location location, String term, List values) { super(location); this.term = term; - values.removeIf(e -> e.dataType() == DataType.NULL); + values.removeIf(e -> DataTypes.isNull(e.dataType())); if (values.isEmpty()) { this.values = Collections.emptySet(); } else { diff --git a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/DataTypeConversion.java b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/DataTypeConversion.java index 3312f449ec622..53f7e6b1ab16d 100644 --- a/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/DataTypeConversion.java +++ b/x-pack/plugin/sql/src/main/java/org/elasticsearch/xpack/sql/type/DataTypeConversion.java @@ -105,6 +105,9 @@ public static Conversion conversionFor(DataType from, DataType to) { if (to == DataType.NULL) { return Conversion.NULL; } + if (from == DataType.NULL) { + return Conversion.NULL; + } Conversion conversion = conversion(from, to); if (conversion == null) { diff --git a/x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt b/x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt index 998dab84783f0..827947424b08e 100644 --- a/x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt +++ b/x-pack/plugin/sql/src/main/resources/org/elasticsearch/xpack/sql/plugin/sql_whitelist.txt @@ -24,6 +24,7 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS Boolean lte(Object, Object) Boolean gt(Object, Object) Boolean gte(Object, Object) + Boolean in(Object, java.util.List) # # Logical @@ -107,4 +108,4 @@ class org.elasticsearch.xpack.sql.expression.function.scalar.whitelist.InternalS String space(Number) String substring(String, Number, Number) String ucase(String) -} \ No newline at end of file +} diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java index 0246499f7f9c0..62cb60ac98fa6 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/optimizer/OptimizerTests.java @@ -340,7 +340,7 @@ public void testConstantFoldingIn_LeftValueNotFoldable() { public void testConstantFoldingIn_RightValueIsNull() { In in = new In(EMPTY, getFieldAttribute(), Arrays.asList(NULL, NULL)); Literal result= (Literal) new ConstantFolding().rule(in); - assertEquals(false, result.value()); + assertNull(result.value()); } public void testConstantFoldingIn_LeftValueIsNull() { diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java index c1e5a0d2dafad..58d16af81fbeb 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/planner/QueryTranslatorTests.java @@ -33,6 +33,7 @@ import java.util.Map; import java.util.TimeZone; +import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.core.StringStartsWith.startsWith; public class QueryTranslatorTests extends AbstractBuilderTestCase { @@ -161,7 +162,7 @@ public void testLikeConstructsNotSupported() { } public void testTranslateInExpression_WhereClause() throws IOException { - LogicalPlan p = plan("SELECT * FROM test WHERE keyword IN ('foo', 'bar', 'lala', 'foo', concat('la', 'la'))"); + LogicalPlan p = plan("SELECT * FROM test WHERE keyword IN ('foo', 'bar', 'l''ala', 'foo', concat('l''a', 'la'))"); assertTrue(p instanceof Project); assertTrue(p.children().get(0) instanceof Filter); Expression condition = ((Filter) p.children().get(0)).condition(); @@ -170,7 +171,7 @@ public void testTranslateInExpression_WhereClause() throws IOException { Query query = translation.query; assertTrue(query instanceof TermsQuery); TermsQuery tq = (TermsQuery) query; - assertEquals("keyword:(bar foo lala)", tq.asBuilder().toQuery(createShardContext()).toString()); + assertEquals("keyword:(bar foo l'ala)", tq.asBuilder().toQuery(createShardContext()).toString()); } public void testTranslateInExpression_WhereClauseAndNullHandling() throws IOException { @@ -206,12 +207,14 @@ public void testTranslateInExpression_HavingClause_Painless() { QueryTranslation translation = QueryTranslator.toQuery(condition, false); assertTrue(translation.query instanceof ScriptQuery); ScriptQuery sq = (ScriptQuery) translation.query; - assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)", sq.script().toString()); + assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", + sq.script().toString()); assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->")); + assertThat(sq.script().params().toString(), endsWith(", {v=[10, 20]}]")); } - public void testTranslateInExpression_HavingClauseAndNullHandling_Painless() { - LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, null, 20, null, 30 - 10)"); + public void testTranslateInExpression_HavingClause_PainlessOneArg() { + LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, 30 - 20)"); assertTrue(p instanceof Project); assertTrue(p.children().get(0) instanceof Filter); Expression condition = ((Filter) p.children().get(0)).condition(); @@ -219,7 +222,24 @@ public void testTranslateInExpression_HavingClauseAndNullHandling_Painless() { QueryTranslation translation = QueryTranslator.toQuery(condition, false); assertTrue(translation.query instanceof ScriptQuery); ScriptQuery sq = (ScriptQuery) translation.query; - assertEquals("InternalSqlScriptUtils.nullSafeFilter(params.a0==10 || params.a0==20)", sq.script().toString()); + assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", sq.script().toString()); assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->")); + assertThat(sq.script().params().toString(), endsWith(", {v=[10]}]")); + + } + + public void testTranslateInExpression_HavingClause_PainlessAndNullHandling() { + LogicalPlan p = plan("SELECT keyword, max(int) FROM test GROUP BY keyword HAVING max(int) in (10, null, 20, 30, null, 30 - 10)"); + assertTrue(p instanceof Project); + assertTrue(p.children().get(0) instanceof Filter); + Expression condition = ((Filter) p.children().get(0)).condition(); + assertFalse(condition.foldable()); + QueryTranslation translation = QueryTranslator.toQuery(condition, false); + assertTrue(translation.query instanceof ScriptQuery); + ScriptQuery sq = (ScriptQuery) translation.query; + assertEquals("InternalSqlScriptUtils.nullSafeFilter(InternalSqlScriptUtils.in(params.a0, params.v0))", + sq.script().toString()); + assertThat(sq.script().params().toString(), startsWith("[{a=MAX(int){a->")); + assertThat(sq.script().params().toString(), endsWith(", {v=[10, 20, 30]}]")); } } diff --git a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/type/DataTypeConversionTests.java b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/type/DataTypeConversionTests.java index b191646a9cdee..7a04139430e33 100644 --- a/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/type/DataTypeConversionTests.java +++ b/x-pack/plugin/sql/src/test/java/org/elasticsearch/xpack/sql/type/DataTypeConversionTests.java @@ -224,6 +224,12 @@ public void testConversionToNull() { assertNull(conversion.convert(10.0)); } + public void testConversionFromNull() { + Conversion conversion = DataTypeConversion.conversionFor(DataType.NULL, DataType.INTEGER); + assertNull(conversion.convert(null)); + assertNull(conversion.convert(10)); + } + public void testConversionToIdentity() { Conversion conversion = DataTypeConversion.conversionFor(DataType.INTEGER, DataType.INTEGER); assertNull(conversion.convert(null));