Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL: Refactor args verification of In & conditionals #40916

Merged
merged 4 commits into from
Apr 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunction;
import org.elasticsearch.xpack.sql.expression.function.grouping.GroupingFunctionAttribute;
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalFunction;
import org.elasticsearch.xpack.sql.expression.predicate.operator.comparison.In;
import org.elasticsearch.xpack.sql.plan.logical.Aggregate;
import org.elasticsearch.xpack.sql.plan.logical.Distinct;
import org.elasticsearch.xpack.sql.plan.logical.Filter;
Expand Down Expand Up @@ -228,9 +226,6 @@ Collection<Failure> verify(LogicalPlan plan) {

Set<Failure> localFailures = new LinkedHashSet<>();

validateInExpression(p, localFailures);
validateConditional(p, localFailures);

checkGroupingFunctionInGroupBy(p, localFailures);
checkFilterOnAggs(p, localFailures);
checkFilterOnGrouping(p, localFailures);
Expand Down Expand Up @@ -724,52 +719,4 @@ private static void checkNestedUsedInGroupByOrHaving(LogicalPlan p, Set<Failure>
fail(nested.get(0), "HAVING isn't (yet) compatible with nested fields " + new AttributeSet(nested).names()));
}
}

private static void validateInExpression(LogicalPlan p, Set<Failure> localFailures) {
p.forEachExpressions(e ->
e.forEachUp((In in) -> {
DataType dt = in.value().dataType();
for (Expression value : in.list()) {
if (areTypesCompatible(dt, value.dataType()) == false) {
localFailures.add(fail(value, "expected data type [{}], value provided is of type [{}]",
dt.typeName, value.dataType().typeName));
return;
}
}
},
In.class));
}

private static void validateConditional(LogicalPlan p, Set<Failure> localFailures) {
p.forEachExpressions(e ->
e.forEachUp((ConditionalFunction cf) -> {
DataType dt = DataType.NULL;

for (Expression child : cf.children()) {
if (dt == DataType.NULL) {
if (Expressions.isNull(child) == false) {
dt = child.dataType();
}
} else {
if (areTypesCompatible(dt, child.dataType()) == false) {
localFailures.add(fail(child, "expected data type [{}], value provided is of type [{}]",
dt.typeName, child.dataType().typeName));
return;
}
}
}
},
ConditionalFunction.class));
}

private static boolean areTypesCompatible(DataType left, DataType right) {
if (left == right) {
return true;
} else {
return
(left == DataType.NULL || right == DataType.NULL) ||
(left.isString() && right.isString()) ||
(left.isNumeric() && right.isNumeric());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ void addToMap(FunctionDefinition...functions) {
for (String alias : f.aliases()) {
Object old = batchMap.put(alias, f);
if (old != null || defs.containsKey(alias)) {
throw new IllegalArgumentException("alias [" + alias + "] is used by "
throw new SqlIllegalArgumentException("alias [" + alias + "] is used by "
+ "[" + (old != null ? old : defs.get(alias).name()) + "] and [" + f.name() + "]");
}
aliases.put(alias, f.name());
Expand Down Expand Up @@ -321,10 +321,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
java.util.function.Function<Source, T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (false == children.isEmpty()) {
throw new IllegalArgumentException("expects no arguments");
throw new SqlIllegalArgumentException("expects no arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.apply(source);
};
Expand All @@ -341,10 +341,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
ConfigurationAwareFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (false == children.isEmpty()) {
throw new IllegalArgumentException("expects no arguments");
throw new SqlIllegalArgumentException("expects no arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, cfg);
};
Expand All @@ -365,10 +365,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
UnaryConfigurationAwareFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() > 1) {
throw new IllegalArgumentException("expects exactly one argument");
throw new SqlIllegalArgumentException("expects exactly one argument");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
Expression ex = children.size() == 1 ? children.get(0) : null;
return ctorRef.build(source, ex, cfg);
Expand All @@ -390,10 +390,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
BiFunction<Source, Expression, T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 1) {
throw new IllegalArgumentException("expects exactly one argument");
throw new SqlIllegalArgumentException("expects exactly one argument");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.apply(source, children.get(0));
};
Expand All @@ -409,7 +409,7 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
MultiFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children);
};
Expand All @@ -429,7 +429,7 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
DistinctAwareUnaryFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 1) {
throw new IllegalArgumentException("expects exactly one argument");
throw new SqlIllegalArgumentException("expects exactly one argument");
}
return ctorRef.build(source, children.get(0), distinct);
};
Expand All @@ -449,10 +449,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
DatetimeUnaryFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 1) {
throw new IllegalArgumentException("expects exactly one argument");
throw new SqlIllegalArgumentException("expects exactly one argument");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), cfg.zoneId());
};
Expand All @@ -471,10 +471,10 @@ interface DatetimeUnaryFunctionBuilder<T> {
static <T extends Function> FunctionDefinition def(Class<T> function, DatetimeBinaryFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 2) {
throw new IllegalArgumentException("expects exactly two arguments");
throw new SqlIllegalArgumentException("expects exactly two arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), children.get(1), cfg.zoneId());
};
Expand All @@ -496,13 +496,13 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
boolean isBinaryOptionalParamFunction = function.isAssignableFrom(Round.class) || function.isAssignableFrom(Truncate.class)
|| TopHits.class.isAssignableFrom(function);
if (isBinaryOptionalParamFunction && (children.size() > 2 || children.size() < 1)) {
throw new IllegalArgumentException("expects one or two arguments");
throw new SqlIllegalArgumentException("expects one or two arguments");
} else if (!isBinaryOptionalParamFunction && children.size() != 2) {
throw new IllegalArgumentException("expects exactly two arguments");
throw new SqlIllegalArgumentException("expects exactly two arguments");
}

if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), children.size() == 2 ? children.get(1) : null);
};
Expand All @@ -527,7 +527,7 @@ private static FunctionDefinition def(Class<? extends Function> function, Functi
FunctionDefinition.Builder realBuilder = (uf, distinct, cfg) -> {
try {
return builder.build(uf.source(), uf.children(), distinct, cfg);
} catch (IllegalArgumentException e) {
} catch (SqlIllegalArgumentException e) {
throw new ParsingException(uf.source(), "error building [" + primaryName + "]: " + e.getMessage(), e);
}
};
Expand All @@ -544,12 +544,12 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
FunctionBuilder builder = (source, children, distinct, cfg) -> {
boolean isLocateFunction = function.isAssignableFrom(Locate.class);
if (isLocateFunction && (children.size() > 3 || children.size() < 2)) {
throw new IllegalArgumentException("expects two or three arguments");
throw new SqlIllegalArgumentException("expects two or three arguments");
} else if (!isLocateFunction && children.size() != 3) {
throw new IllegalArgumentException("expects exactly three arguments");
throw new SqlIllegalArgumentException("expects exactly three arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), children.get(1), children.size() == 3 ? children.get(2) : null);
};
Expand All @@ -565,10 +565,10 @@ static <T extends Function> FunctionDefinition def(Class<T> function,
FourParametersFunctionBuilder<T> ctorRef, String... names) {
FunctionBuilder builder = (source, children, distinct, cfg) -> {
if (children.size() != 4) {
throw new IllegalArgumentException("expects exactly four arguments");
throw new SqlIllegalArgumentException("expects exactly four arguments");
}
if (distinct) {
throw new IllegalArgumentException("does not support DISTINCT yet it was specified");
throw new SqlIllegalArgumentException("does not support DISTINCT yet it was specified");
}
return ctorRef.build(source, children.get(0), children.get(1), children.get(2), children.get(3));
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.elasticsearch.xpack.sql.expression.gen.script.ScriptTemplate;
import org.elasticsearch.xpack.sql.expression.predicate.conditional.ConditionalProcessor.ConditionalOperation;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataTypeConversion;

import java.util.ArrayList;
import java.util.List;
Expand All @@ -33,14 +32,6 @@ public abstract class ArbitraryConditionalFunction extends ConditionalFunction {
this.operation = operation;
}

@Override
protected TypeResolution resolveType() {
for (Expression e : children()) {
dataType = DataTypeConversion.commonType(dataType, e.dataType());
}
return TypeResolution.TYPE_RESOLVED;
}

@Override
protected Pipe makePipe() {
return new ConditionalPipe(source(), this, Expressions.pipe(children()), operation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,14 @@
import org.elasticsearch.xpack.sql.expression.function.scalar.ScalarFunction;
import org.elasticsearch.xpack.sql.tree.Source;
import org.elasticsearch.xpack.sql.type.DataType;
import org.elasticsearch.xpack.sql.type.DataTypeConversion;

import java.util.List;

import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.sql.type.DataTypes.areTypesCompatible;
import static org.elasticsearch.xpack.sql.util.StringUtils.ordinal;

/**
* Base class for conditional predicates.
*/
Expand All @@ -36,6 +41,31 @@ public boolean foldable() {
return Expressions.foldable(children());
}

@Override
protected TypeResolution resolveType() {
DataType dt = DataType.NULL;

for (int i = 0; i < children().size(); i++) {
Expression child = children().get(i);
if (dt == DataType.NULL) {
if (Expressions.isNull(child) == false) {
dt = child.dataType();
}
} else {
if (areTypesCompatible(dt, child.dataType()) == false) {
return new TypeResolution(format(null, "{} argument of [{}] must be [{}], found value [{}] type [{}]",
ordinal(i + 1),
sourceText(),
dt.typeName,
Expressions.name(child),
child.dataType().typeName));
}
}
dataType = DataTypeConversion.commonType(dataType, child.dataType());
}
return TypeResolution.TYPE_RESOLVED;
}

@Override
public Nullability nullable() {
return Nullability.UNKNOWN;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ public Expression replaceChildren(List<Expression> newChildren) {
return new NullIf(source(), newChildren.get(0), newChildren.get(1));
}

@Override
protected TypeResolution resolveType() {
dataType = children().get(0).dataType();
return TypeResolution.TYPE_RESOLVED;
}

@Override
public Object fold() {
return NullIfProcessor.apply(children().get(0).fold(), children().get(1).fold());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

import static org.elasticsearch.common.logging.LoggerMessageFormat.format;
import static org.elasticsearch.xpack.sql.expression.gen.script.ParamsBuilder.paramsBuilder;
import static org.elasticsearch.xpack.sql.type.DataTypes.areTypesCompatible;
import static org.elasticsearch.xpack.sql.util.StringUtils.ordinal;

public class In extends ScalarFunction {

Expand Down Expand Up @@ -109,7 +111,7 @@ protected Pipe makePipe() {
@Override
protected TypeResolution resolveType() {
TypeResolution resolution = TypeResolutions.isExact(value, functionName(), Expressions.ParamOrdinal.DEFAULT);
if (resolution != TypeResolution.TYPE_RESOLVED) {
if (resolution.unresolved()) {
return resolution;
}

Expand All @@ -120,6 +122,20 @@ protected TypeResolution resolveType() {
name()));
}
}

DataType dt = value.dataType();
for (int i = 0; i < list.size(); i++) {
Expression listValue = list.get(i);
if (areTypesCompatible(dt, listValue.dataType()) == false) {
return new TypeResolution(format(null, "{} argument of [{}] must be [{}], found value [{}] type [{}]",
ordinal(i + 1),
sourceText(),
dt.typeName,
Expressions.name(listValue),
listValue.dataType().typeName));
}
}

return super.resolveType();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,4 +230,15 @@ public static Integer precision(DataType t) {
}
return t.displaySize;
}

public static boolean areTypesCompatible(DataType left, DataType right) {
if (left == right) {
return true;
} else {
return
(left == DataType.NULL || right == DataType.NULL) ||
(left.isString() && right.isString()) ||
(left.isNumeric() && right.isNumeric());
}
}
}
Loading