diff --git a/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java b/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java index 79290843cbd5..21254d4c8018 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java +++ b/presto-main/src/main/java/io/prestosql/metadata/FunctionRegistry.java @@ -793,7 +793,7 @@ public FunctionMetadata get(FunctionId functionId) return functions.get(functionId).getFunctionMetadata(); } - public AggregationFunctionMetadata getAggregationFunctionMetadata(Metadata metadata, FunctionBinding functionBinding) + public AggregationFunctionMetadata getAggregationFunctionMetadata(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { SqlFunction function = functions.get(functionBinding.getFunctionId()); checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", functionBinding.getBoundSignature()); @@ -803,19 +803,19 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(Metadata metad return new AggregationFunctionMetadata(aggregationFunction.isOrderSensitive(), Optional.empty()); } - InternalAggregationFunction implementation = getAggregateFunctionImplementation(metadata, functionBinding); + InternalAggregationFunction implementation = getAggregateFunctionImplementation(functionBinding, functionDependencies); return new AggregationFunctionMetadata(aggregationFunction.isOrderSensitive(), Optional.of(implementation.getIntermediateType().getTypeSignature())); } - public WindowFunctionSupplier getWindowFunctionImplementation(Metadata metadata, FunctionBinding functionBinding) + public WindowFunctionSupplier getWindowFunctionImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { SqlFunction function = functions.get(functionBinding.getFunctionId()); try { if (function instanceof SqlAggregationFunction) { - InternalAggregationFunction aggregationFunction = specializedAggregationCache.get(functionBinding, () -> specializedAggregation(metadata, functionBinding)); + InternalAggregationFunction aggregationFunction = specializedAggregationCache.get(functionBinding, () -> specializedAggregation(functionBinding, functionDependencies)); return supplier(function.getFunctionMetadata().getSignature(), aggregationFunction); } - return specializedWindowCache.get(functionBinding, () -> specializeWindow(functionBinding)); + return specializedWindowCache.get(functionBinding, () -> specializeWindow(functionBinding, functionDependencies)); } catch (ExecutionException | UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), PrestoException.class); @@ -823,16 +823,16 @@ public WindowFunctionSupplier getWindowFunctionImplementation(Metadata metadata, } } - private WindowFunctionSupplier specializeWindow(FunctionBinding functionBinding) + private WindowFunctionSupplier specializeWindow(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { SqlWindowFunction function = (SqlWindowFunction) functions.get(functionBinding.getFunctionId()); - return function.specialize(functionBinding); + return function.specialize(functionBinding, functionDependencies); } - public InternalAggregationFunction getAggregateFunctionImplementation(Metadata metadata, FunctionBinding functionBinding) + public InternalAggregationFunction getAggregateFunctionImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { try { - return specializedAggregationCache.get(functionBinding, () -> specializedAggregation(metadata, functionBinding)); + return specializedAggregationCache.get(functionBinding, () -> specializedAggregation(functionBinding, functionDependencies)); } catch (ExecutionException | UncheckedExecutionException e) { throwIfInstanceOf(e.getCause(), PrestoException.class); @@ -840,10 +840,10 @@ public InternalAggregationFunction getAggregateFunctionImplementation(Metadata m } } - private InternalAggregationFunction specializedAggregation(Metadata metadata, FunctionBinding functionBinding) + private InternalAggregationFunction specializedAggregation(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { SqlAggregationFunction function = (SqlAggregationFunction) functions.get(functionBinding.getFunctionId()); - return function.specialize(functionBinding, metadata); + return function.specialize(functionBinding, functionDependencies); } public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding) diff --git a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java index 9c5438cd4b5a..f320a815d5e0 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java +++ b/presto-main/src/main/java/io/prestosql/metadata/MetadataManager.java @@ -1635,19 +1635,22 @@ public FunctionMetadata getFunctionMetadata(ResolvedFunction resolvedFunction) @Override public AggregationFunctionMetadata getAggregationFunctionMetadata(ResolvedFunction resolvedFunction) { - return functions.getAggregationFunctionMetadata(this, toFunctionBinding(resolvedFunction)); + FunctionDependencies functionDependencies = new FunctionDependencies(this, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies()); + return functions.getAggregationFunctionMetadata(toFunctionBinding(resolvedFunction), functionDependencies); } @Override public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction) { - return functions.getWindowFunctionImplementation(this, toFunctionBinding(resolvedFunction)); + FunctionDependencies functionDependencies = new FunctionDependencies(this, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies()); + return functions.getWindowFunctionImplementation(toFunctionBinding(resolvedFunction), functionDependencies); } @Override public InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFunction resolvedFunction) { - return functions.getAggregateFunctionImplementation(this, toFunctionBinding(resolvedFunction)); + FunctionDependencies functionDependencies = new FunctionDependencies(this, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies()); + return functions.getAggregateFunctionImplementation(toFunctionBinding(resolvedFunction), functionDependencies); } @Override diff --git a/presto-main/src/main/java/io/prestosql/metadata/SqlAggregationFunction.java b/presto-main/src/main/java/io/prestosql/metadata/SqlAggregationFunction.java index e33f22e81330..1d37cc16bf97 100644 --- a/presto-main/src/main/java/io/prestosql/metadata/SqlAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/metadata/SqlAggregationFunction.java @@ -67,7 +67,7 @@ public boolean isDecomposable() return decomposable; } - public InternalAggregationFunction specialize(FunctionBinding functionBinding, Metadata metadata) + public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { return specialize(functionBinding); } diff --git a/presto-main/src/main/java/io/prestosql/operator/ParametricFunctionHelpers.java b/presto-main/src/main/java/io/prestosql/operator/ParametricFunctionHelpers.java index 52624744e429..021b5718c2df 100644 --- a/presto-main/src/main/java/io/prestosql/operator/ParametricFunctionHelpers.java +++ b/presto-main/src/main/java/io/prestosql/operator/ParametricFunctionHelpers.java @@ -13,10 +13,8 @@ */ package io.prestosql.operator; -import io.prestosql.metadata.BoundVariables; import io.prestosql.metadata.FunctionBinding; import io.prestosql.metadata.FunctionDependencies; -import io.prestosql.metadata.Metadata; import io.prestosql.operator.annotations.ImplementationDependency; import java.lang.invoke.MethodHandle; @@ -27,14 +25,6 @@ public final class ParametricFunctionHelpers { private ParametricFunctionHelpers() {} - public static MethodHandle bindDependencies(MethodHandle handle, List dependencies, BoundVariables variables, Metadata metadata) - { - for (ImplementationDependency dependency : dependencies) { - handle = MethodHandles.insertArguments(handle, 0, dependency.resolve(variables, metadata)); - } - return handle; - } - public static MethodHandle bindDependencies(MethodHandle handle, List dependencies, FunctionBinding functionBinding, FunctionDependencies functionDependencies) { for (ImplementationDependency dependency : dependencies) { diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java index b33cf01b9a68..48793d05ed03 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/AbstractMinMaxAggregationFunction.java @@ -18,9 +18,9 @@ import io.prestosql.annotation.UsedByGeneratedCode; import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionBinding; +import io.prestosql.metadata.FunctionDependencies; +import io.prestosql.metadata.FunctionDependencyDeclaration; import io.prestosql.metadata.FunctionMetadata; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.Signature; import io.prestosql.metadata.SqlAggregationFunction; import io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor; @@ -104,11 +104,19 @@ protected AbstractMinMaxAggregationFunction(String name, boolean min, String des } @Override - public InternalAggregationFunction specialize(FunctionBinding functionBinding, Metadata metadata) + public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding) { Type type = functionBinding.getTypeVariable("E"); - ResolvedFunction resolvedFunction = metadata.resolveOperator(operatorType, ImmutableList.of(type, type)); - MethodHandle compareMethodHandle = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle(); + return FunctionDependencyDeclaration.builder() + .addOperator(operatorType, ImmutableList.of(type, type)) + .build(); + } + + @Override + public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) + { + Type type = functionBinding.getTypeVariable("E"); + MethodHandle compareMethodHandle = functionDependencies.getOperatorInvoker(operatorType, ImmutableList.of(type, type), Optional.empty()).getMethodHandle(); return generateAggregation(type, compareMethodHandle); } diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/ParametricAggregation.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/ParametricAggregation.java index 8bf6ecca0bff..7225920ba83c 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/ParametricAggregation.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/ParametricAggregation.java @@ -17,10 +17,11 @@ import com.google.common.collect.ImmutableList; import io.airlift.bytecode.DynamicClassLoader; import io.prestosql.metadata.BoundSignature; -import io.prestosql.metadata.BoundVariables; import io.prestosql.metadata.FunctionBinding; +import io.prestosql.metadata.FunctionDependencies; +import io.prestosql.metadata.FunctionDependencyDeclaration; +import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.prestosql.metadata.FunctionMetadata; -import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; import io.prestosql.metadata.SqlAggregationFunction; import io.prestosql.operator.ParametricImplementationsGroup; @@ -28,6 +29,7 @@ import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata; import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType; import io.prestosql.operator.aggregation.state.StateCompiler; +import io.prestosql.operator.annotations.ImplementationDependency; import io.prestosql.spi.PrestoException; import io.prestosql.spi.function.AccumulatorStateFactory; import io.prestosql.spi.function.AccumulatorStateSerializer; @@ -35,6 +37,7 @@ import io.prestosql.spi.type.TypeSignature; import java.lang.invoke.MethodHandle; +import java.util.Collection; import java.util.List; import java.util.Optional; @@ -78,7 +81,32 @@ public ParametricAggregation( } @Override - public InternalAggregationFunction specialize(FunctionBinding functionBinding, Metadata metadata) + public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding) + { + FunctionDependencyDeclarationBuilder builder = FunctionDependencyDeclaration.builder(); + declareDependencies(functionBinding, builder, implementations.getExactImplementations().values()); + declareDependencies(functionBinding, builder, implementations.getSpecializedImplementations()); + declareDependencies(functionBinding, builder, implementations.getGenericImplementations()); + return builder.build(); + } + + private static void declareDependencies(FunctionBinding functionBinding, FunctionDependencyDeclarationBuilder builder, Collection implementations) + { + for (AggregationImplementation implementation : implementations) { + for (ImplementationDependency dependency : implementation.getInputDependencies()) { + dependency.declareDependencies(functionBinding, builder); + } + for (ImplementationDependency dependency : implementation.getCombineDependencies()) { + dependency.declareDependencies(functionBinding, builder); + } + for (ImplementationDependency dependency : implementation.getOutputDependencies()) { + dependency.declareDependencies(functionBinding, builder); + } + } + } + + @Override + public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { // Bind variables Signature signature = getFunctionMetadata().getSignature(); @@ -100,12 +128,11 @@ public InternalAggregationFunction specialize(FunctionBinding functionBinding, M AccumulatorStateFactory stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader); // Bind provided dependencies to aggregation method handlers - BoundVariables boundVariables = new BoundVariables(functionBinding.getTypeVariables(), functionBinding.getLongVariables()); - MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), boundVariables, metadata); + MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), functionBinding, functionDependencies); Optional removeInputHandle = concreteImplementation.getRemoveInputFunction().map( - removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), boundVariables, metadata)); - MethodHandle combineHandle = bindDependencies(concreteImplementation.getCombineFunction(), concreteImplementation.getCombineDependencies(), boundVariables, metadata); - MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), boundVariables, metadata); + removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), functionBinding, functionDependencies)); + MethodHandle combineHandle = bindDependencies(concreteImplementation.getCombineFunction(), concreteImplementation.getCombineDependencies(), functionBinding, functionDependencies); + MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), functionBinding, functionDependencies); // Build metadata of input parameters List parametersMetadata = buildParameterMetadata(concreteImplementation.getInputParameterMetadataTypes(), inputTypes); diff --git a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java index f39b47b1b85f..e2eb7ae089b7 100644 --- a/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java +++ b/presto-main/src/main/java/io/prestosql/operator/aggregation/minmaxby/AbstractMinMaxBy.java @@ -25,9 +25,9 @@ import io.airlift.bytecode.expression.BytecodeExpression; import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionBinding; +import io.prestosql.metadata.FunctionDependencies; +import io.prestosql.metadata.FunctionDependencyDeclaration; import io.prestosql.metadata.FunctionMetadata; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.ResolvedFunction; import io.prestosql.metadata.Signature; import io.prestosql.metadata.SqlAggregationFunction; import io.prestosql.operator.aggregation.AccumulatorCompiler; @@ -116,14 +116,23 @@ protected AbstractMinMaxBy(boolean min, String description) } @Override - public InternalAggregationFunction specialize(FunctionBinding functionBinding, Metadata metadata) + public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding) + { + Type keyType = functionBinding.getTypeVariable("K"); + return FunctionDependencyDeclaration.builder() + .addOperator(min ? LESS_THAN : GREATER_THAN, ImmutableList.of(keyType, keyType)) + .build(); + } + + @Override + public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { Type keyType = functionBinding.getTypeVariable("K"); Type valueType = functionBinding.getTypeVariable("V"); - return generateAggregation(valueType, keyType, metadata); + return generateAggregation(valueType, keyType, functionDependencies); } - private InternalAggregationFunction generateAggregation(Type valueType, Type keyType, Metadata metadata) + private InternalAggregationFunction generateAggregation(Type valueType, Type keyType, FunctionDependencies functionDependencies) { Class stateClazz = getStateClass(keyType.getJavaType(), valueType.getJavaType()); DynamicClassLoader classLoader = new DynamicClassLoader(getClass().getClassLoader()); @@ -157,8 +166,7 @@ private InternalAggregationFunction generateAggregation(Type valueType, Type key CallSiteBinder binder = new CallSiteBinder(); OperatorType operator = min ? LESS_THAN : GREATER_THAN; - ResolvedFunction resolvedFunction = metadata.resolveOperator(operator, ImmutableList.of(keyType, keyType)); - MethodHandle compareMethod = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle(); + MethodHandle compareMethod = functionDependencies.getOperatorInvoker(operator, ImmutableList.of(keyType, keyType), Optional.empty()).getMethodHandle(); ClassDefinition definition = new ClassDefinition( a(PUBLIC, FINAL), diff --git a/presto-main/src/main/java/io/prestosql/operator/annotations/CastImplementationDependency.java b/presto-main/src/main/java/io/prestosql/operator/annotations/CastImplementationDependency.java index e24e74781ab7..42653fc1971c 100644 --- a/presto-main/src/main/java/io/prestosql/operator/annotations/CastImplementationDependency.java +++ b/presto-main/src/main/java/io/prestosql/operator/annotations/CastImplementationDependency.java @@ -18,8 +18,6 @@ import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.prestosql.metadata.FunctionInvoker; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.ResolvedFunction; import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.type.TypeSignature; @@ -61,14 +59,6 @@ public void declareDependencies(FunctionBinding functionBinding, FunctionDepende applyBoundVariables(toType, boundVariables)); } - @Override - protected ResolvedFunction getResolvedFunction(BoundVariables boundVariables, Metadata metadata) - { - return metadata.getCoercion( - metadata.getType(applyBoundVariables(fromType, boundVariables)), - metadata.getType(applyBoundVariables(toType, boundVariables))); - } - @Override protected FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, Optional invocationConvention) { diff --git a/presto-main/src/main/java/io/prestosql/operator/annotations/FunctionImplementationDependency.java b/presto-main/src/main/java/io/prestosql/operator/annotations/FunctionImplementationDependency.java index 08518069382d..d4dcd795984a 100644 --- a/presto-main/src/main/java/io/prestosql/operator/annotations/FunctionImplementationDependency.java +++ b/presto-main/src/main/java/io/prestosql/operator/annotations/FunctionImplementationDependency.java @@ -18,8 +18,6 @@ import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.prestosql.metadata.FunctionInvoker; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.ResolvedFunction; import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.type.TypeSignature; import io.prestosql.sql.tree.QualifiedName; @@ -29,7 +27,6 @@ import java.util.Optional; import static io.prestosql.metadata.SignatureBinder.applyBoundVariables; -import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypeSignatures; import static java.util.Objects.requireNonNull; public final class FunctionImplementationDependency @@ -52,12 +49,6 @@ public void declareDependencies(FunctionBinding functionBinding, FunctionDepende builder.addFunctionSignature(name, applyBoundVariables(argumentTypes, boundVariables)); } - @Override - protected ResolvedFunction getResolvedFunction(BoundVariables boundVariables, Metadata metadata) - { - return metadata.resolveFunction(name, fromTypeSignatures(applyBoundVariables(argumentTypes, boundVariables))); - } - @Override protected FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, Optional invocationConvention) { diff --git a/presto-main/src/main/java/io/prestosql/operator/annotations/ImplementationDependency.java b/presto-main/src/main/java/io/prestosql/operator/annotations/ImplementationDependency.java index df29357fd139..d281b600ccc7 100644 --- a/presto-main/src/main/java/io/prestosql/operator/annotations/ImplementationDependency.java +++ b/presto-main/src/main/java/io/prestosql/operator/annotations/ImplementationDependency.java @@ -14,11 +14,9 @@ package io.prestosql.operator.annotations; import com.google.common.collect.ImmutableSet; -import io.prestosql.metadata.BoundVariables; import io.prestosql.metadata.FunctionBinding; import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.prestosql.metadata.Metadata; import io.prestosql.spi.function.CastDependency; import io.prestosql.spi.function.Convention; import io.prestosql.spi.function.FunctionDependency; @@ -50,8 +48,6 @@ public interface ImplementationDependency { void declareDependencies(FunctionBinding functionBinding, FunctionDependencyDeclarationBuilder builder); - Object resolve(BoundVariables boundVariables, Metadata metadata); - Object resolve(FunctionBinding functionBinding, FunctionDependencies functionDependencies); static boolean isImplementationDependencyAnnotation(Annotation annotation) diff --git a/presto-main/src/main/java/io/prestosql/operator/annotations/LiteralImplementationDependency.java b/presto-main/src/main/java/io/prestosql/operator/annotations/LiteralImplementationDependency.java index 89a7e34f1e38..c1b827ea7f59 100644 --- a/presto-main/src/main/java/io/prestosql/operator/annotations/LiteralImplementationDependency.java +++ b/presto-main/src/main/java/io/prestosql/operator/annotations/LiteralImplementationDependency.java @@ -13,11 +13,9 @@ */ package io.prestosql.operator.annotations; -import io.prestosql.metadata.BoundVariables; import io.prestosql.metadata.FunctionBinding; import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.prestosql.metadata.Metadata; import static java.util.Objects.requireNonNull; @@ -34,12 +32,6 @@ public LiteralImplementationDependency(String literalName) @Override public void declareDependencies(FunctionBinding functionBinding, FunctionDependencyDeclarationBuilder builder) {} - @Override - public Long resolve(BoundVariables boundVariables, Metadata metadata) - { - return boundVariables.getLongVariable(literalName); - } - @Override public Object resolve(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { diff --git a/presto-main/src/main/java/io/prestosql/operator/annotations/OperatorImplementationDependency.java b/presto-main/src/main/java/io/prestosql/operator/annotations/OperatorImplementationDependency.java index 8def9548ad3b..2d8a3cde1331 100644 --- a/presto-main/src/main/java/io/prestosql/operator/annotations/OperatorImplementationDependency.java +++ b/presto-main/src/main/java/io/prestosql/operator/annotations/OperatorImplementationDependency.java @@ -19,11 +19,8 @@ import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; import io.prestosql.metadata.FunctionInvoker; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.ResolvedFunction; import io.prestosql.spi.function.InvocationConvention; import io.prestosql.spi.function.OperatorType; -import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; import java.util.List; @@ -31,7 +28,6 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; -import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.metadata.SignatureBinder.applyBoundVariables; import static io.prestosql.spi.function.OperatorType.CAST; import static io.prestosql.spi.function.OperatorType.SATURATED_FLOOR_CAST; @@ -68,15 +64,6 @@ public void declareDependencies(FunctionBinding functionBinding, FunctionDepende builder.addOperatorSignature(operator, applyBoundVariables(argumentTypes, boundVariables)); } - @Override - protected ResolvedFunction getResolvedFunction(BoundVariables boundVariables, Metadata metadata) - { - List argumentTypes = applyBoundVariables(this.argumentTypes, boundVariables).stream() - .map(metadata::getType) - .collect(toImmutableList()); - return metadata.resolveOperator(operator, argumentTypes); - } - @Override protected FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, Optional invocationConvention) { diff --git a/presto-main/src/main/java/io/prestosql/operator/annotations/ScalarImplementationDependency.java b/presto-main/src/main/java/io/prestosql/operator/annotations/ScalarImplementationDependency.java index 5a85850d5194..f95c06c5b2f8 100644 --- a/presto-main/src/main/java/io/prestosql/operator/annotations/ScalarImplementationDependency.java +++ b/presto-main/src/main/java/io/prestosql/operator/annotations/ScalarImplementationDependency.java @@ -13,12 +13,9 @@ */ package io.prestosql.operator.annotations; -import io.prestosql.metadata.BoundVariables; import io.prestosql.metadata.FunctionBinding; import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.FunctionInvoker; -import io.prestosql.metadata.Metadata; -import io.prestosql.metadata.ResolvedFunction; import io.prestosql.spi.function.InvocationConvention; import java.lang.invoke.MethodHandle; @@ -39,17 +36,8 @@ protected ScalarImplementationDependency(Optional invocati } } - protected abstract ResolvedFunction getResolvedFunction(BoundVariables boundVariables, Metadata metadata); - protected abstract FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, Optional invocationConvention); - @Override - public MethodHandle resolve(BoundVariables boundVariables, Metadata metadata) - { - ResolvedFunction resolvedFunction = getResolvedFunction(boundVariables, metadata); - return metadata.getScalarFunctionInvoker(resolvedFunction, invocationConvention).getMethodHandle(); - } - @Override public MethodHandle resolve(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { diff --git a/presto-main/src/main/java/io/prestosql/operator/annotations/TypeImplementationDependency.java b/presto-main/src/main/java/io/prestosql/operator/annotations/TypeImplementationDependency.java index 3db8b8f7de7e..0b9b0d330e20 100644 --- a/presto-main/src/main/java/io/prestosql/operator/annotations/TypeImplementationDependency.java +++ b/presto-main/src/main/java/io/prestosql/operator/annotations/TypeImplementationDependency.java @@ -18,8 +18,6 @@ import io.prestosql.metadata.FunctionBinding; import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder; -import io.prestosql.metadata.Metadata; -import io.prestosql.spi.type.Type; import io.prestosql.spi.type.TypeSignature; import java.util.Objects; @@ -45,12 +43,6 @@ public void declareDependencies(FunctionBinding functionBinding, FunctionDepende builder.addType(applyBoundVariables(signature, boundVariables)); } - @Override - public Type resolve(BoundVariables boundVariables, Metadata metadata) - { - return metadata.getType(applyBoundVariables(signature, boundVariables)); - } - @Override public Object resolve(FunctionBinding functionBinding, FunctionDependencies functionDependencies) { diff --git a/presto-main/src/main/java/io/prestosql/operator/window/SqlWindowFunction.java b/presto-main/src/main/java/io/prestosql/operator/window/SqlWindowFunction.java index 5c1c6f38c79c..df90c4bec514 100644 --- a/presto-main/src/main/java/io/prestosql/operator/window/SqlWindowFunction.java +++ b/presto-main/src/main/java/io/prestosql/operator/window/SqlWindowFunction.java @@ -15,6 +15,7 @@ import io.prestosql.metadata.FunctionArgumentDefinition; import io.prestosql.metadata.FunctionBinding; +import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.FunctionMetadata; import io.prestosql.metadata.Signature; import io.prestosql.metadata.SqlFunction; @@ -51,6 +52,11 @@ public FunctionMetadata getFunctionMetadata() return functionMetadata; } + public WindowFunctionSupplier specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies) + { + return specialize(functionBinding); + } + public WindowFunctionSupplier specialize(FunctionBinding functionBinding) { return supplier; diff --git a/presto-main/src/test/java/io/prestosql/operator/TestAnnotationEngineForAggregates.java b/presto-main/src/test/java/io/prestosql/operator/TestAnnotationEngineForAggregates.java index f6033cda171d..32acf02e1654 100644 --- a/presto-main/src/test/java/io/prestosql/operator/TestAnnotationEngineForAggregates.java +++ b/presto-main/src/test/java/io/prestosql/operator/TestAnnotationEngineForAggregates.java @@ -15,9 +15,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.prestosql.metadata.BoundSignature; import io.prestosql.metadata.FunctionBinding; +import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.LongVariableConstraint; import io.prestosql.metadata.Metadata; import io.prestosql.metadata.Signature; @@ -76,6 +78,7 @@ public class TestAnnotationEngineForAggregates extends TestAnnotationEngine { private static final Metadata METADATA = createTestMetadataManager(); + protected static final FunctionDependencies NO_FUNCTION_DEPENDENCIES = new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of()); @AggregationFunction("simple_exact_aggregate") @Description("Simple exact aggregate description") @@ -127,7 +130,7 @@ public void testSimpleExactAggregationParse() new BoundSignature(expectedSignature.getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)), ImmutableMap.of(), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, NO_FUNCTION_DEPENDENCIES); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "simple_exact_aggregate"); } @@ -210,7 +213,7 @@ public void testNotAnnotatedAggregateStateAggregationParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)), ImmutableMap.of(), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "no_aggregation_state_aggregate"); } @@ -263,7 +266,7 @@ public void testNotDecomposableAggregationParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)), ImmutableMap.of(), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "custom_decomposable_aggregate"); } @@ -361,7 +364,7 @@ public void testSimpleGenericAggregationFunctionParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)), ImmutableMap.of("T", DoubleType.DOUBLE), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertTrue(specialized.getParameterTypes().equals(ImmutableList.of(DoubleType.DOUBLE))); assertEquals(specialized.name(), "simple_generic_implementations"); @@ -424,7 +427,7 @@ public void testSimpleBlockInputAggregationParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)), ImmutableMap.of(), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "block_input_aggregate"); } @@ -519,7 +522,7 @@ public void testSimpleImplicitSpecializedAggregationParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(new ArrayType(DoubleType.DOUBLE))), ImmutableMap.of("T", DoubleType.DOUBLE), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "implicit_specialized_aggregate"); } @@ -614,7 +617,7 @@ public void testSimpleExplicitSpecializedAggregationParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(new ArrayType(DoubleType.DOUBLE))), ImmutableMap.of("T", DoubleType.DOUBLE), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "implicit_specialized_aggregate"); } @@ -703,7 +706,7 @@ public void testMultiOutputAggregationParse() new BoundSignature(aggregation1.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)), ImmutableMap.of(), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation1.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation1.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "multi_output_aggregate_1"); } @@ -772,7 +775,7 @@ public void testInjectOperatorAggregateParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)), ImmutableMap.of(), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "inject_operator_aggregate"); } @@ -846,7 +849,7 @@ public void testInjectTypeAggregateParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), DoubleType.DOUBLE, ImmutableList.of(DoubleType.DOUBLE)), ImmutableMap.of("T", DoubleType.DOUBLE), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertEquals(specialized.name(), "inject_type_aggregate"); } @@ -917,7 +920,7 @@ public void testInjectLiteralAggregateParse() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), createVarcharType(17), ImmutableList.of(createVarcharType(17))), ImmutableMap.of(), ImmutableMap.of("x", 17L)); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), createVarcharType(17)); assertEquals(specialized.name(), "inject_literal_aggregate"); } @@ -991,7 +994,7 @@ public void testLongConstraintAggregateFunctionParse() .put("y", 13L) .put("z", 30L) .build()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), createVarcharType(30)); assertEquals(specialized.name(), "parametric_aggregate_long_constraint"); } @@ -1126,7 +1129,7 @@ public void testPartiallyFixedTypeParameterInjectionAggregateFunctionParse() .put("T2", DoubleType.DOUBLE) .build(), ImmutableMap.of()); - InternalAggregationFunction specialized = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction specialized = aggregation.specialize(functionBinding, new FunctionDependencies(METADATA, ImmutableSet.of(), ImmutableSet.of())); assertEquals(specialized.getFinalType(), DoubleType.DOUBLE); assertTrue(specialized.getParameterTypes().equals(ImmutableList.of(DoubleType.DOUBLE))); assertEquals(specialized.name(), "partially_fixed_type_parameter_injection"); diff --git a/presto-ml/src/test/java/io/prestosql/plugin/ml/TestLearnAggregations.java b/presto-ml/src/test/java/io/prestosql/plugin/ml/TestLearnAggregations.java index 95ffb810c7fb..c18970a1241a 100644 --- a/presto-ml/src/test/java/io/prestosql/plugin/ml/TestLearnAggregations.java +++ b/presto-ml/src/test/java/io/prestosql/plugin/ml/TestLearnAggregations.java @@ -15,10 +15,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; import io.prestosql.RowPageBuilder; import io.prestosql.metadata.BoundSignature; import io.prestosql.metadata.FunctionBinding; +import io.prestosql.metadata.FunctionDependencies; import io.prestosql.metadata.MetadataManager; import io.prestosql.operator.aggregation.Accumulator; import io.prestosql.operator.aggregation.InternalAggregationFunction; @@ -52,6 +54,7 @@ public class TestLearnAggregations { private static final MetadataManager METADATA = createTestMetadataManager(); + protected static final FunctionDependencies NO_FUNCTION_DEPENDENCIES = new FunctionDependencies(METADATA, ImmutableMap.of(), ImmutableSet.of()); static { METADATA.addParametricType(new ClassifierParametricType()); @@ -73,7 +76,7 @@ public void testLearn() new BoundSignature(aggregation.getFunctionMetadata().getSignature().getName(), BIGINT_CLASSIFIER, ImmutableList.of(BIGINT, mapType)), ImmutableMap.of(), ImmutableMap.of()); - InternalAggregationFunction aggregationFunction = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction aggregationFunction = aggregation.specialize(functionBinding, NO_FUNCTION_DEPENDENCIES); assertLearnClassifer(aggregationFunction.bind(ImmutableList.of(0, 1), Optional.empty()).createAccumulator()); } @@ -93,7 +96,7 @@ public void testLearnLibSvm() ImmutableList.of(BIGINT, mapType, VARCHAR)), ImmutableMap.of(), ImmutableMap.of("x", (long) Integer.MAX_VALUE)); - InternalAggregationFunction aggregationFunction = aggregation.specialize(functionBinding, METADATA); + InternalAggregationFunction aggregationFunction = aggregation.specialize(functionBinding, NO_FUNCTION_DEPENDENCIES); assertLearnClassifer(aggregationFunction.bind(ImmutableList.of(0, 1, 2), Optional.empty()).createAccumulator()); }