From 4336c25250231c62c26e4a913cfb401eaa9df924 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Tue, 21 Jul 2020 21:13:10 -0700 Subject: [PATCH] Add FunctionBinding --- .../io/prestosql/metadata/BoundVariables.java | 10 +++ .../prestosql/metadata/FunctionBinding.java | 81 +++++++++++++++++++ .../prestosql/metadata/FunctionRegistry.java | 78 +++++++++--------- .../prestosql/metadata/MetadataManager.java | 17 +++- .../prestosql/metadata/SignatureBinder.java | 13 +++ .../metadata/SpecializedFunctionKey.java | 70 ---------------- 6 files changed, 153 insertions(+), 116 deletions(-) create mode 100644 presto-main/src/main/java/io/prestosql/metadata/FunctionBinding.java delete mode 100644 presto-main/src/main/java/io/prestosql/metadata/SpecializedFunctionKey.java diff --git a/presto-main/src/main/java/io/prestosql/metadata/BoundVariables.java b/presto-main/src/main/java/io/prestosql/metadata/BoundVariables.java index c9868671c415..c2887a4ab929 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/BoundVariables.java +++ b/presto-main/src/main/java/io/prestosql/metadata/BoundVariables.java @@ -43,6 +43,11 @@ public BoundVariables(Map typeVariables, Map longVar .collect(toImmutableSortedMap(CASE_INSENSITIVE_ORDER, Map.Entry::getKey, Map.Entry::getValue)); } + public Map getTypeVariables() + { + return typeVariables; + } + public Type getTypeVariable(String variableName) { return getValue(typeVariables, variableName); @@ -53,6 +58,11 @@ public boolean containsTypeVariable(String variableName) return containsValue(typeVariables, variableName); } + public Map getLongVariables() + { + return longVariables; + } + public Long getLongVariable(String variableName) { return getValue(longVariables, variableName); diff --git a/presto-main/src/main/java/io/prestosql/metadata/FunctionBinding.java b/presto-main/src/main/java/io/prestosql/metadata/FunctionBinding.java new file mode 100644 index 000000000000..209907e3e901 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/metadata/FunctionBinding.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.metadata; + +import com.google.common.collect.ImmutableSortedMap; +import io.prestosql.spi.type.Type; + +import java.util.Map; +import java.util.Objects; + +import static java.lang.String.CASE_INSENSITIVE_ORDER; +import static java.util.Objects.requireNonNull; + +public class FunctionBinding +{ + private final FunctionId functionId; + private final Signature boundSignature; + private final Map typeVariables; + private final Map longVariables; + + public FunctionBinding(FunctionId functionId, Signature boundSignature, Map typeVariables, Map longVariables) + { + this.functionId = requireNonNull(functionId, "functionId is null"); + this.boundSignature = requireNonNull(boundSignature, "boundSignature is null"); + this.typeVariables = ImmutableSortedMap.copyOf(requireNonNull(typeVariables, "typeVariables is null"), CASE_INSENSITIVE_ORDER); + this.longVariables = ImmutableSortedMap.copyOf(requireNonNull(longVariables, "longVariables is null"), CASE_INSENSITIVE_ORDER); + } + + public FunctionId getFunctionId() + { + return functionId; + } + + public Signature getBoundSignature() + { + return boundSignature; + } + + public Map getTypeVariables() + { + return typeVariables; + } + + public Map getLongVariables() + { + return longVariables; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + FunctionBinding that = (FunctionBinding) o; + return Objects.equals(functionId, that.functionId) && + Objects.equals(boundSignature, that.boundSignature) && + Objects.equals(typeVariables, that.typeVariables) && + Objects.equals(longVariables, that.longVariables); + } + + @Override + public int hashCode() + { + return Objects.hash(functionId, boundSignature, typeVariables, longVariables); + } +} diff --git a/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java b/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java index 1cea2ca8d563..1bf6ebb268ee 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java @@ -332,7 +332,6 @@ import static io.prestosql.operator.scalar.ZipFunction.ZIP_FUNCTIONS; import static io.prestosql.operator.scalar.ZipWithFunction.ZIP_WITH_FUNCTION; import static io.prestosql.operator.window.AggregateWindowFunction.supplier; -import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; import static io.prestosql.type.DecimalCasts.BIGINT_TO_DECIMAL_CAST; import static io.prestosql.type.DecimalCasts.BOOLEAN_TO_DECIMAL_CAST; import static io.prestosql.type.DecimalCasts.DECIMAL_TO_BIGINT_CAST; @@ -379,9 +378,9 @@ @ThreadSafe public class FunctionRegistry { - private final Cache specializedScalarCache; - private final Cache specializedAggregationCache; - private final Cache specializedWindowCache; + private final Cache specializedScalarCache; + private final Cache specializedAggregationCache; + private final Cache specializedWindowCache; private volatile FunctionMap functions = new FunctionMap(); public FunctionRegistry(Metadata metadata, FeaturesConfig featuresConfig) @@ -798,28 +797,29 @@ public FunctionMetadata get(FunctionId functionId) return functions.get(functionId).getFunctionMetadata(); } - public AggregationFunctionMetadata getAggregationFunctionMetadata(Metadata metadata, ResolvedFunction resolvedFunction) + public AggregationFunctionMetadata getAggregationFunctionMetadata(Metadata metadata, FunctionBinding functionBinding) { - SqlFunction function = functions.get(resolvedFunction.getFunctionId()); - checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", resolvedFunction); + SqlFunction function = functions.get(functionBinding.getFunctionId()); + checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", functionBinding.getBoundSignature()); SqlAggregationFunction aggregationFunction = (SqlAggregationFunction) function; if (!aggregationFunction.isDecomposable()) { return new AggregationFunctionMetadata(aggregationFunction.isOrderSensitive(), Optional.empty()); } - InternalAggregationFunction implementation = getAggregateFunctionImplementation(metadata, resolvedFunction); + InternalAggregationFunction implementation = getAggregateFunctionImplementation(metadata, functionBinding); return new AggregationFunctionMetadata(aggregationFunction.isOrderSensitive(), Optional.of(implementation.getIntermediateType().getTypeSignature())); } - public WindowFunctionSupplier getWindowFunctionImplementation(Metadata metadata, ResolvedFunction resolvedFunction) + public WindowFunctionSupplier getWindowFunctionImplementation(Metadata metadata, FunctionBinding functionBinding) { - SpecializedFunctionKey key = getSpecializedFunctionKey(metadata, resolvedFunction); + SqlFunction function = functions.get(functionBinding.getFunctionId()); try { - if (key.getFunction() instanceof SqlAggregationFunction) { - return supplier(key.getFunction().getFunctionMetadata().getSignature(), specializedAggregationCache.get(key, () -> specializedAggregation(metadata, key))); + if (function instanceof SqlAggregationFunction) { + InternalAggregationFunction aggregationFunction = specializedAggregationCache.get(functionBinding, () -> specializedAggregation(metadata, functionBinding)); + return supplier(function.getFunctionMetadata().getSignature(), aggregationFunction); } - return specializedWindowCache.get(key, () -> specializeWindow(metadata, key)); + return specializedWindowCache.get(functionBinding, () -> specializeWindow(metadata, functionBinding)); } catch (ExecutionException | UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), PrestoException.class); @@ -827,17 +827,19 @@ public WindowFunctionSupplier getWindowFunctionImplementation(Metadata metadata, } } - private static WindowFunctionSupplier specializeWindow(Metadata metadata, SpecializedFunctionKey key) + private WindowFunctionSupplier specializeWindow(Metadata metadata, FunctionBinding functionBinding) { - return ((SqlWindowFunction) key.getFunction()) - .specialize(key.getBoundVariables(), key.getArity(), metadata); + SqlWindowFunction function = (SqlWindowFunction) functions.get(functionBinding.getFunctionId()); + return function.specialize( + new BoundVariables(functionBinding.getTypeVariables(), functionBinding.getLongVariables()), + functionBinding.getBoundSignature().getArgumentTypes().size(), + metadata); } - public InternalAggregationFunction getAggregateFunctionImplementation(Metadata metadata, ResolvedFunction resolvedFunction) + public InternalAggregationFunction getAggregateFunctionImplementation(Metadata metadata, FunctionBinding functionBinding) { - SpecializedFunctionKey key = getSpecializedFunctionKey(metadata, resolvedFunction); try { - return specializedAggregationCache.get(key, () -> specializedAggregation(metadata, key)); + return specializedAggregationCache.get(functionBinding, () -> specializedAggregation(metadata, functionBinding)); } catch (ExecutionException | UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), PrestoException.class); @@ -845,10 +847,13 @@ public InternalAggregationFunction getAggregateFunctionImplementation(Metadata m } } - private static InternalAggregationFunction specializedAggregation(Metadata metadata, SpecializedFunctionKey key) + private InternalAggregationFunction specializedAggregation(Metadata metadata, FunctionBinding functionBinding) { - SqlAggregationFunction function = (SqlAggregationFunction) key.getFunction(); - InternalAggregationFunction implementation = function.specialize(key.getBoundVariables(), key.getArity(), metadata); + SqlAggregationFunction function = (SqlAggregationFunction) functions.get(functionBinding.getFunctionId()); + InternalAggregationFunction implementation = function.specialize( + new BoundVariables(functionBinding.getTypeVariables(), functionBinding.getLongVariables()), + functionBinding.getBoundSignature().getArgumentTypes().size(), + metadata); checkArgument( function.isOrderSensitive() == implementation.isOrderSensitive(), "implementation order sensitivity doesn't match for: %s", @@ -860,25 +865,27 @@ private static InternalAggregationFunction specializedAggregation(Metadata metad return implementation; } - public FunctionInvoker getScalarFunctionInvoker(Metadata metadata, ResolvedFunction resolvedFunction, InvocationConvention invocationConvention) + public FunctionInvoker getScalarFunctionInvoker(Metadata metadata, FunctionBinding functionBinding, InvocationConvention invocationConvention) { - SpecializedFunctionKey key = getSpecializedFunctionKey(metadata, resolvedFunction); ScalarFunctionImplementation scalarFunctionImplementation; try { - scalarFunctionImplementation = specializedScalarCache.get(key, () -> specializeScalarFunction(metadata, key)); + scalarFunctionImplementation = specializedScalarCache.get(functionBinding, () -> specializeScalarFunction(metadata, functionBinding)); } catch (ExecutionException | UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), PrestoException.class); throw new RuntimeException(e.getCause()); } FunctionInvokerProvider functionInvokerProvider = new FunctionInvokerProvider(metadata); - return functionInvokerProvider.createFunctionInvoker(scalarFunctionImplementation, resolvedFunction.getSignature(), invocationConvention); + return functionInvokerProvider.createFunctionInvoker(scalarFunctionImplementation, functionBinding.getBoundSignature(), invocationConvention); } - private static ScalarFunctionImplementation specializeScalarFunction(Metadata metadata, SpecializedFunctionKey key) + private ScalarFunctionImplementation specializeScalarFunction(Metadata metadata, FunctionBinding functionBinding) { - SqlScalarFunction function = (SqlScalarFunction) key.getFunction(); - ScalarFunctionImplementation specialize = function.specialize(key.getBoundVariables(), key.getArity(), metadata); + SqlScalarFunction function = (SqlScalarFunction) functions.get(functionBinding.getFunctionId()); + ScalarFunctionImplementation specialize = function.specialize( + new BoundVariables(functionBinding.getTypeVariables(), functionBinding.getLongVariables()), + functionBinding.getBoundSignature().getArgumentTypes().size(), + metadata); FunctionMetadata functionMetadata = function.getFunctionMetadata(); for (ScalarImplementationChoice choice : specialize.getAllChoices()) { checkArgument(choice.isNullable() == functionMetadata.isNullable(), "choice nullability doesn't match for: " + functionMetadata.getSignature()); @@ -900,19 +907,6 @@ else if (argumentProperty.getNullConvention() != BLOCK_AND_POSITION) { return specialize; } - private SpecializedFunctionKey getSpecializedFunctionKey(Metadata metadata, ResolvedFunction resolvedFunction) - { - SqlFunction function = functions.get(resolvedFunction.getFunctionId()); - Signature signature = resolvedFunction.getSignature(); - BoundVariables boundVariables = new SignatureBinder(metadata, function.getFunctionMetadata().getSignature(), false) - .bindVariables(fromTypeSignatures(signature.getArgumentTypes()), signature.getReturnType()) - .orElseThrow(() -> new IllegalArgumentException("Could not extract bound variables")); - return new SpecializedFunctionKey( - function, - boundVariables, - signature.getArgumentTypes().size()); - } - private static class FunctionMap { private final Map functions; diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index bf805bf7de6c..e2370c093746 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -1561,26 +1561,26 @@ public FunctionMetadata getFunctionMetadata(ResolvedFunction resolvedFunction) @Override public AggregationFunctionMetadata getAggregationFunctionMetadata(ResolvedFunction resolvedFunction) { - return functions.getAggregationFunctionMetadata(this, resolvedFunction); + return functions.getAggregationFunctionMetadata(this, toFunctionBinding(resolvedFunction)); } @Override public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction) { - return functions.getWindowFunctionImplementation(this, resolvedFunction); + return functions.getWindowFunctionImplementation(this, toFunctionBinding(resolvedFunction)); } @Override public InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFunction resolvedFunction) { - return functions.getAggregateFunctionImplementation(this, resolvedFunction); + return functions.getAggregateFunctionImplementation(this, toFunctionBinding(resolvedFunction)); } @Override public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional invocationConvention) { InvocationConvention expectedConvention = invocationConvention.orElseGet(() -> getDefaultCallingConvention(resolvedFunction)); - return functions.getScalarFunctionInvoker(this, resolvedFunction, expectedConvention); + return functions.getScalarFunctionInvoker(this, toFunctionBinding(resolvedFunction), expectedConvention); } /** @@ -1602,6 +1602,15 @@ private InvocationConvention getDefaultCallingConvention(ResolvedFunction resolv false); } + private FunctionBinding toFunctionBinding(ResolvedFunction resolvedFunction) + { + return SignatureBinder.bindFunction( + this, + resolvedFunction.getFunctionId(), + functions.get(resolvedFunction.getFunctionId()).getSignature(), + resolvedFunction.getSignature()); + } + @Override public ProcedureRegistry getProcedureRegistry() { diff --git a/presto-main/src/main/java/io/prestosql/metadata/SignatureBinder.java b/presto-main/src/main/java/io/prestosql/metadata/SignatureBinder.java index 285eaf8eaba5..f36300f56d3d 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/SignatureBinder.java +++ b/presto-main/src/main/java/io/prestosql/metadata/SignatureBinder.java @@ -48,6 +48,7 @@ import static io.prestosql.metadata.SignatureBinder.RelationshipType.EXPLICIT_COERCION_FROM; import static io.prestosql.metadata.SignatureBinder.RelationshipType.EXPLICIT_COERCION_TO; import static io.prestosql.metadata.SignatureBinder.RelationshipType.IMPLICIT_COERCION; +import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes; import static io.prestosql.type.TypeCalculation.calculateLiteralValue; import static io.prestosql.type.TypeCoercion.isCovariantTypeBase; @@ -185,6 +186,18 @@ public static TypeSignature applyBoundVariables(TypeSignature typeSignature, Bou return new TypeSignature(baseType, parameters); } + public static FunctionBinding bindFunction(Metadata metadata, FunctionId functionId, Signature functionSignature, Signature boundSignature) + { + BoundVariables boundVariables = new SignatureBinder(metadata, functionSignature, false) + .bindVariables(fromTypeSignatures(boundSignature.getArgumentTypes()), boundSignature.getReturnType()) + .orElseThrow(() -> new IllegalArgumentException("Could not extract bound variables")); + return new FunctionBinding( + functionId, + boundSignature, + boundVariables.getTypeVariables(), + boundVariables.getLongVariables()); + } + /** * Example of not allowed literal variable usages across typeSignatures: *

    diff --git a/presto-main/src/main/java/io/prestosql/metadata/SpecializedFunctionKey.java b/presto-main/src/main/java/io/prestosql/metadata/SpecializedFunctionKey.java deleted file mode 100644 index 5011b407da75..000000000000 --- a/presto-main/src/main/java/io/prestosql/metadata/SpecializedFunctionKey.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.prestosql.metadata; - -import java.util.Objects; - -import static java.util.Objects.requireNonNull; - -public class SpecializedFunctionKey -{ - private final SqlFunction function; - private final BoundVariables boundVariables; - private final int arity; - - public SpecializedFunctionKey(SqlFunction function, BoundVariables boundVariables, int arity) - { - this.function = requireNonNull(function, "function is null"); - this.boundVariables = requireNonNull(boundVariables, "boundVariables is null"); - this.arity = arity; - } - - public SqlFunction getFunction() - { - return function; - } - - public BoundVariables getBoundVariables() - { - return boundVariables; - } - - public int getArity() - { - return arity; - } - - @Override - public boolean equals(Object o) - { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - SpecializedFunctionKey that = (SpecializedFunctionKey) o; - - return Objects.equals(boundVariables, that.boundVariables) && - Objects.equals(function.getFunctionMetadata().getFunctionId(), that.function.getFunctionMetadata().getFunctionId()) && - arity == that.arity; - } - - @Override - public int hashCode() - { - return Objects.hash(function.getFunctionMetadata().getFunctionId(), boundVariables, arity); - } -}