Skip to content

Commit

Permalink
Require session for coercion resolution
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Oct 15, 2021
1 parent 86fd8f6 commit 4826d05
Show file tree
Hide file tree
Showing 30 changed files with 146 additions and 138 deletions.
2 changes: 1 addition & 1 deletion core/trino-main/src/main/java/io/trino/cost/StatsUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static OptionalDouble toStatsRepresentation(Metadata metadata, Session session,
{
if (convertibleToDoubleWithCast(type)) {
InterpretedFunctionInvoker functionInvoker = new InterpretedFunctionInvoker(metadata);
ResolvedFunction castFunction = metadata.getCoercion(type, DOUBLE);
ResolvedFunction castFunction = metadata.getCoercion(session, type, DOUBLE);
return OptionalDouble.of((double) functionInvoker.invoke(castFunction, session.toConnectorSession(), singletonList(value)));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Ordering;
import io.trino.Session;
import io.trino.spi.TrinoException;
import io.trino.spi.type.Type;
import io.trino.sql.analyzer.TypeSignatureProvider;
Expand Down Expand Up @@ -50,13 +51,13 @@ public FunctionResolver(Metadata metadata)
this.metadata = metadata;
}

FunctionBinding resolveCoercion(Collection<FunctionMetadata> allCandidates, Signature signature)
FunctionBinding resolveCoercion(Session session, Collection<FunctionMetadata> allCandidates, Signature signature)
{
List<FunctionMetadata> exactCandidates = allCandidates.stream()
.filter(function -> possibleExactCastMatch(signature, function.getSignature()))
.collect(Collectors.toList());
for (FunctionMetadata candidate : exactCandidates) {
if (canBindSignature(candidate.getSignature(), signature)) {
if (canBindSignature(session, candidate.getSignature(), signature)) {
return toFunctionBinding(candidate, signature);
}
}
Expand All @@ -66,17 +67,17 @@ FunctionBinding resolveCoercion(Collection<FunctionMetadata> allCandidates, Sign
.filter(function -> !function.getSignature().getTypeVariableConstraints().isEmpty())
.collect(Collectors.toList());
for (FunctionMetadata candidate : genericCandidates) {
if (canBindSignature(candidate.getSignature(), signature)) {
if (canBindSignature(session, candidate.getSignature(), signature)) {
return toFunctionBinding(candidate, signature);
}
}

throw new TrinoException(FUNCTION_IMPLEMENTATION_MISSING, format("%s not found", signature));
}

private boolean canBindSignature(Signature declaredSignature, Signature actualSignature)
private boolean canBindSignature(Session session, Signature declaredSignature, Signature actualSignature)
{
return new SignatureBinder(metadata, declaredSignature, false)
return new SignatureBinder(session, metadata, declaredSignature, false)
.canBind(fromTypeSignatures(actualSignature.getArgumentTypes()), actualSignature.getReturnType());
}

Expand Down Expand Up @@ -108,7 +109,7 @@ private static boolean possibleExactCastMatch(Signature signature, Signature dec
return true;
}

FunctionBinding resolveFunction(Collection<FunctionMetadata> allCandidates, QualifiedName name, List<TypeSignatureProvider> parameterTypes)
FunctionBinding resolveFunction(Session session, Collection<FunctionMetadata> allCandidates, QualifiedName name, List<TypeSignatureProvider> parameterTypes)
{
if (allCandidates.isEmpty()) {
throw new TrinoException(FUNCTION_NOT_FOUND, format("Function '%s' not registered", name));
Expand All @@ -118,7 +119,7 @@ FunctionBinding resolveFunction(Collection<FunctionMetadata> allCandidates, Qual
.filter(function -> function.getSignature().getTypeVariableConstraints().isEmpty())
.collect(toImmutableList());

Optional<FunctionBinding> match = matchFunctionExact(exactCandidates, parameterTypes);
Optional<FunctionBinding> match = matchFunctionExact(session, exactCandidates, parameterTypes);
if (match.isPresent()) {
return match.get();
}
Expand All @@ -127,12 +128,12 @@ FunctionBinding resolveFunction(Collection<FunctionMetadata> allCandidates, Qual
.filter(function -> !function.getSignature().getTypeVariableConstraints().isEmpty())
.collect(toImmutableList());

match = matchFunctionExact(genericCandidates, parameterTypes);
match = matchFunctionExact(session, genericCandidates, parameterTypes);
if (match.isPresent()) {
return match.get();
}

match = matchFunctionWithCoercion(allCandidates, parameterTypes);
match = matchFunctionWithCoercion(session, allCandidates, parameterTypes);
if (match.isPresent()) {
return match.get();
}
Expand All @@ -151,25 +152,25 @@ FunctionBinding resolveFunction(Collection<FunctionMetadata> allCandidates, Qual
throw new TrinoException(FUNCTION_NOT_FOUND, message);
}

private Optional<FunctionBinding> matchFunctionExact(List<FunctionMetadata> candidates, List<TypeSignatureProvider> actualParameters)
private Optional<FunctionBinding> matchFunctionExact(Session session, List<FunctionMetadata> candidates, List<TypeSignatureProvider> actualParameters)
{
return matchFunction(candidates, actualParameters, false);
return matchFunction(session, candidates, actualParameters, false);
}

private Optional<FunctionBinding> matchFunctionWithCoercion(Collection<FunctionMetadata> candidates, List<TypeSignatureProvider> actualParameters)
private Optional<FunctionBinding> matchFunctionWithCoercion(Session session, Collection<FunctionMetadata> candidates, List<TypeSignatureProvider> actualParameters)
{
return matchFunction(candidates, actualParameters, true);
return matchFunction(session, candidates, actualParameters, true);
}

private Optional<FunctionBinding> matchFunction(Collection<FunctionMetadata> candidates, List<TypeSignatureProvider> parameters, boolean coercionAllowed)
private Optional<FunctionBinding> matchFunction(Session session, Collection<FunctionMetadata> candidates, List<TypeSignatureProvider> parameters, boolean coercionAllowed)
{
List<ApplicableFunction> applicableFunctions = identifyApplicableFunctions(candidates, parameters, coercionAllowed);
List<ApplicableFunction> applicableFunctions = identifyApplicableFunctions(session, candidates, parameters, coercionAllowed);
if (applicableFunctions.isEmpty()) {
return Optional.empty();
}

if (coercionAllowed) {
applicableFunctions = selectMostSpecificFunctions(applicableFunctions, parameters);
applicableFunctions = selectMostSpecificFunctions(session, applicableFunctions, parameters);
checkState(!applicableFunctions.isEmpty(), "at least single function must be left");
}

Expand All @@ -189,22 +190,22 @@ private Optional<FunctionBinding> matchFunction(Collection<FunctionMetadata> can
throw new TrinoException(AMBIGUOUS_FUNCTION_CALL, errorMessageBuilder.toString());
}

private List<ApplicableFunction> identifyApplicableFunctions(Collection<FunctionMetadata> candidates, List<TypeSignatureProvider> actualParameters, boolean allowCoercion)
private List<ApplicableFunction> identifyApplicableFunctions(Session session, Collection<FunctionMetadata> candidates, List<TypeSignatureProvider> actualParameters, boolean allowCoercion)
{
ImmutableList.Builder<ApplicableFunction> applicableFunctions = ImmutableList.builder();
for (FunctionMetadata function : candidates) {
new SignatureBinder(metadata, function.getSignature(), allowCoercion)
new SignatureBinder(session, metadata, function.getSignature(), allowCoercion)
.bind(actualParameters)
.ifPresent(signature -> applicableFunctions.add(new ApplicableFunction(function, signature)));
}
return applicableFunctions.build();
}

private List<ApplicableFunction> selectMostSpecificFunctions(List<ApplicableFunction> applicableFunctions, List<TypeSignatureProvider> parameters)
private List<ApplicableFunction> selectMostSpecificFunctions(Session session, List<ApplicableFunction> applicableFunctions, List<TypeSignatureProvider> parameters)
{
checkArgument(!applicableFunctions.isEmpty());

List<ApplicableFunction> mostSpecificFunctions = selectMostSpecificFunctions(applicableFunctions);
List<ApplicableFunction> mostSpecificFunctions = selectMostSpecificFunctions(session, applicableFunctions);
if (mostSpecificFunctions.size() <= 1) {
return mostSpecificFunctions;
}
Expand Down Expand Up @@ -244,18 +245,18 @@ private List<ApplicableFunction> selectMostSpecificFunctions(List<ApplicableFunc
return mostSpecificFunctions;
}

private List<ApplicableFunction> selectMostSpecificFunctions(List<ApplicableFunction> candidates)
private List<ApplicableFunction> selectMostSpecificFunctions(Session session, List<ApplicableFunction> candidates)
{
List<ApplicableFunction> representatives = new ArrayList<>();

for (ApplicableFunction current : candidates) {
boolean found = false;
for (int i = 0; i < representatives.size(); i++) {
ApplicableFunction representative = representatives.get(i);
if (isMoreSpecificThan(current, representative)) {
if (isMoreSpecificThan(session, current, representative)) {
representatives.set(i, current);
}
if (isMoreSpecificThan(current, representative) || isMoreSpecificThan(representative, current)) {
if (isMoreSpecificThan(session, current, representative) || isMoreSpecificThan(session, representative, current)) {
found = true;
break;
}
Expand Down Expand Up @@ -342,10 +343,10 @@ private Optional<List<Type>> toTypes(List<TypeSignatureProvider> typeSignaturePr
/**
* One method is more specific than another if invocation handled by the first method could be passed on to the other one
*/
private boolean isMoreSpecificThan(ApplicableFunction left, ApplicableFunction right)
private boolean isMoreSpecificThan(Session session, ApplicableFunction left, ApplicableFunction right)
{
List<TypeSignatureProvider> resolvedTypes = fromTypeSignatures(left.getBoundSignature().getArgumentTypes());
return new SignatureBinder(metadata, right.getDeclaredSignature(), true)
return new SignatureBinder(session, metadata, right.getDeclaredSignature(), true)
.canBind(resolvedTypes);
}

Expand Down
8 changes: 4 additions & 4 deletions core/trino-main/src/main/java/io/trino/metadata/Metadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -616,14 +616,14 @@ default Type getParameterizedType(String baseTypeName, List<TypeSignatureParamet
ResolvedFunction resolveOperator(Session session, OperatorType operatorType, List<? extends Type> argumentTypes)
throws OperatorNotFoundException;

default ResolvedFunction getCoercion(Type fromType, Type toType)
default ResolvedFunction getCoercion(Session session, Type fromType, Type toType)
{
return getCoercion(CAST, fromType, toType);
return getCoercion(session, CAST, fromType, toType);
}

ResolvedFunction getCoercion(OperatorType operatorType, Type fromType, Type toType);
ResolvedFunction getCoercion(Session session, OperatorType operatorType, Type fromType, Type toType);

ResolvedFunction getCoercion(QualifiedName name, Type fromType, Type toType);
ResolvedFunction getCoercion(Session session, QualifiedName name, Type fromType, Type toType);

/**
* Is the named function an aggregation function? This does not need type parameters
Expand Down
Loading

0 comments on commit 4826d05

Please sign in to comment.