Skip to content

Commit

Permalink
SQL: Refactor args verification of In & conditionals (#40916)
Browse files Browse the repository at this point in the history
Move verification of arguments for Conditional functions and IN
from `Verifier` to the `resolveType()` method of the functions.
  • Loading branch information
matriv authored Apr 8, 2019
1 parent 71d407f commit 241644a
Show file tree
Hide file tree
Showing 11 changed files with 212 additions and 253 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalFunction;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
import org.elasticsearch.xpack.sql.plan.logical.Filter;
Expand Down Expand Up @@ -228,9 +226,6 @@ Collection<Failure> verify(LogicalPlan plan) {

Set<Failure> localFailures = new LinkedHashSet<>();

validateInExpression(p, localFailures);
validateConditional(p, localFailures);

checkGroupingFunctionInGroupBy(p, localFailures);
checkFilterOnAggs(p, localFailures);
checkFilterOnGrouping(p, localFailures);
Expand Down Expand Up @@ -724,52 +719,4 @@ private static void checkNestedUsedInGroupByOrHaving(LogicalPlan p, Set<Failure>
fail(nested.get(0), "HAVING isn't (yet) compatible with nested fields " + new AttributeSet(nested).names()));
}
}

private static void validateInExpression(LogicalPlan p, Set<Failure> localFailures) {
p.forEachExpressions(e ->
e.forEachUp((In in) -> {
DataType dt = in.value().dataType();
for (Expression value : in.list()) {
if (areTypesCompatible(dt, value.dataType()) == false) {
localFailures.add(fail(value, "expected data type [{}], value provided is of type [{}]",
dt.typeName, value.dataType().typeName));
return;
}
}
},
In.class));
}

private static void validateConditional(LogicalPlan p, Set<Failure> localFailures) {
p.forEachExpressions(e ->
e.forEachUp((ConditionalFunction cf) -> {
DataType dt = DataType.NULL;

for (Expression child : cf.children()) {
if (dt == DataType.NULL) {
if (Expressions.isNull(child) == false) {
dt = child.dataType();
}
} else {
if (areTypesCompatible(dt, child.dataType()) == false) {
localFailures.add(fail(child, "expected data type [{}], value provided is of type [{}]",
dt.typeName, child.dataType().typeName));
return;
}
}
}
},
ConditionalFunction.class));
}

private static boolean areTypesCompatible(DataType left, DataType right) {
if (left == right) {
return true;
} else {
return
(left == DataType.NULL || right == DataType.NULL) ||
(left.isString() && right.isString()) ||
(left.isNumeric() && right.isNumeric());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ void addToMap(FunctionDefinition...functions) {
for (String alias : f.aliases()) {
Object old = batchMap.put(alias, f);
if (old != null || defs.containsKey(alias)) {
throw new IllegalArgumentException("alias [" + alias + "] is used by "
throw new SqlIllegalArgumentException("alias [" + alias + "] is used by "
+ "[" + (old != null ? old : defs.get(alias).name()) + "] and [" + f.name() + "]");
}
aliases.put(alias, f.name());
Expand Down Expand Up @@ -321,10 +321,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
java.util.function.Function<Source, T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (false == children.isEmpty()) {
throw new IllegalArgumentException("expects no arguments");
throw new SqlIllegalArgumentException("expects no arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.apply(source);
};
Expand All @@ -341,10 +341,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
ConfigurationAwareFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (false == children.isEmpty()) {
throw new IllegalArgumentException("expects no arguments");
throw new SqlIllegalArgumentException("expects no arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, cfg);
};
Expand All @@ -365,10 +365,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
UnaryConfigurationAwareFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() > 1) {
throw new IllegalArgumentException("expects exactly one argument");
throw new SqlIllegalArgumentException("expects exactly one argument");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
Expression ex = children.size() == 1 ? children.get(0) : null;
return ctorRef.build(source, ex, cfg);
Expand All @@ -390,10 +390,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
BiFunction<Source, Expression, T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 1) {
throw new IllegalArgumentException("expects exactly one argument");
throw new SqlIllegalArgumentException("expects exactly one argument");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.apply(source, children.get(0));
};
Expand All @@ -409,7 +409,7 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
MultiFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children);
};
Expand All @@ -429,7 +429,7 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
DistinctAwareUnaryFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 1) {
throw new IllegalArgumentException("expects exactly one argument");
throw new SqlIllegalArgumentException("expects exactly one argument");
}
return ctorRef.build(source, children.get(0), distinct);
};
Expand All @@ -449,10 +449,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
DatetimeUnaryFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 1) {
throw new IllegalArgumentException("expects exactly one argument");
throw new SqlIllegalArgumentException("expects exactly one argument");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), cfg.zoneId());
};
Expand All @@ -471,10 +471,10 @@ interface DatetimeUnaryFunctionBuilder<T> {
static <T extends Function> FunctionDefinition def(Class<T> function, DatetimeBinaryFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 2) {
throw new IllegalArgumentException("expects exactly two arguments");
throw new SqlIllegalArgumentException("expects exactly two arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), children.get(1), cfg.zoneId());
};
Expand All @@ -496,13 +496,13 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
boolean isBinaryOptionalParamFunction = function.isAssignableFrom(Round.class) || function.isAssignableFrom(Truncate.class)
|| TopHits.class.isAssignableFrom(function);
if (isBinaryOptionalParamFunction && (children.size() > 2 || children.size() < 1)) {
throw new IllegalArgumentException("expects one or two arguments");
throw new SqlIllegalArgumentException("expects one or two arguments");
} else if (!isBinaryOptionalParamFunction && children.size() != 2) {
throw new IllegalArgumentException("expects exactly two arguments");
throw new SqlIllegalArgumentException("expects exactly two arguments");
}

if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), children.size() == 2 ? children.get(1) : null);
};
Expand All @@ -527,7 +527,7 @@ private static FunctionDefinition def(Class<? extends Function> function, Functi
FunctionDefinition.Builder realBuilder = (uf, distinct, cfg) -> {
try {
return builder.build(uf.source(), uf.children(), distinct, cfg);
} catch (IllegalArgumentException e) {
} catch (SqlIllegalArgumentException e) {
throw new ParsingException(uf.source(), "error building [" + primaryName + "]: " + e.getMessage(), e);
}
};
Expand All @@ -544,12 +544,12 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
FunctionBuilder builder = (source, children, distinct, cfg) -> {
boolean isLocateFunction = function.isAssignableFrom(Locate.class);
if (isLocateFunction && (children.size() > 3 || children.size() < 2)) {
throw new IllegalArgumentException("expects two or three arguments");
throw new SqlIllegalArgumentException("expects two or three arguments");
} else if (!isLocateFunction && children.size() != 3) {
throw new IllegalArgumentException("expects exactly three arguments");
throw new SqlIllegalArgumentException("expects exactly three arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), children.get(1), children.size() == 3 ? children.get(2) : null);
};
Expand All @@ -565,10 +565,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
FourParametersFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 4) {
throw new IllegalArgumentException("expects exactly four arguments");
throw new SqlIllegalArgumentException("expects exactly four arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), children.get(1), children.get(2), children.get(3));
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataTypeConversion;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -33,14 +32,6 @@ public abstract class ArbitraryConditionalFunction extends ConditionalFunction {
this.operation = operation;
}

@Override
protected TypeResolution resolveType() {
for (Expression e : children()) {
dataType = DataTypeConversion.commonType(dataType, e.dataType());
}
return TypeResolution.TYPE_RESOLVED;
}

@Override
protected Pipe makePipe() {
return new ConditionalPipe(source(), this, Expressions.pipe(children()), operation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypeConversion;

import java.util.List;

import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.sql.type.DataTypes.areTypesCompatible;
import static org.elasticsearch.xpack.sql.util.StringUtils.ordinal;

/**
* Base class for conditional predicates.
*/
Expand All @@ -36,6 +41,31 @@ public boolean foldable() {
return Expressions.foldable(children());
}

@Override
protected TypeResolution resolveType() {
DataType dt = DataType.NULL;

for (int i = 0; i < children().size(); i++) {
Expression child = children().get(i);
if (dt == DataType.NULL) {
if (Expressions.isNull(child) == false) {
dt = child.dataType();
}
} else {
if (areTypesCompatible(dt, child.dataType()) == false) {
return new TypeResolution(format(null, "{} argument of [{}] must be [{}], found value [{}] type [{}]",
ordinal(i + 1),
sourceText(),
dt.typeName,
Expressions.name(child),
child.dataType().typeName));
}
}
dataType = DataTypeConversion.commonType(dataType, child.dataType());
}
return TypeResolution.TYPE_RESOLVED;
}

@Override
public Nullability nullable() {
return Nullability.UNKNOWN;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ public Expression replaceChildren(List<Expression> newChildren) {
return new NullIf(source(), newChildren.get(0), newChildren.get(1));
}

@Override
protected TypeResolution resolveType() {
dataType = children().get(0).dataType();
return TypeResolution.TYPE_RESOLVED;
}

@Override
public Object fold() {
return NullIfProcessor.apply(children().get(0).fold(), children().get(1).fold());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;
import static org.elasticsearch.xpack.sql.type.DataTypes.areTypesCompatible;
import static org.elasticsearch.xpack.sql.util.StringUtils.ordinal;

public class In extends ScalarFunction {

Expand Down Expand Up @@ -109,7 +111,7 @@ protected Pipe makePipe() {
@Override
protected TypeResolution resolveType() {
TypeResolution resolution = TypeResolutions.isExact(value, functionName(), Expressions.ParamOrdinal.DEFAULT);
if (resolution != TypeResolution.TYPE_RESOLVED) {
if (resolution.unresolved()) {
return resolution;
}

Expand All @@ -120,6 +122,20 @@ protected TypeResolution resolveType() {
name()));
}
}

DataType dt = value.dataType();
for (int i = 0; i < list.size(); i++) {
Expression listValue = list.get(i);
if (areTypesCompatible(dt, listValue.dataType()) == false) {
return new TypeResolution(format(null, "{} argument of [{}] must be [{}], found value [{}] type [{}]",
ordinal(i + 1),
sourceText(),
dt.typeName,
Expressions.name(listValue),
listValue.dataType().typeName));
}
}

return super.resolveType();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,15 @@ public static Integer precision(DataType t) {
}
return t.displaySize;
}

public static boolean areTypesCompatible(DataType left, DataType right) {
if (left == right) {
return true;
} else {
return
(left == DataType.NULL || right == DataType.NULL) ||
(left.isString() && right.isString()) ||
(left.isNumeric() && right.isNumeric());
}
}
}
Loading

0 comments on commit 241644a

Please sign in to comment.