Skip to content

Commit

Permalink
Replace Metadata with FunctionDependencies in agg and window
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Jul 30, 2020
1 parent 45c96e7 commit d9da4f7
Show file tree
Hide file tree
Showing 17 changed files with 108 additions and 124 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ public FunctionMetadata get(FunctionId functionId)
return functions.get(functionId).getFunctionMetadata();
}

public AggregationFunctionMetadata getAggregationFunctionMetadata(Metadata metadata, FunctionBinding functionBinding)
public AggregationFunctionMetadata getAggregationFunctionMetadata(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
SqlFunction function = functions.get(functionBinding.getFunctionId());
checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", functionBinding.getBoundSignature());
Expand All @@ -803,47 +803,47 @@ public AggregationFunctionMetadata getAggregationFunctionMetadata(Metadata metad
return new AggregationFunctionMetadata(aggregationFunction.isOrderSensitive(), Optional.empty());
}

InternalAggregationFunction implementation = getAggregateFunctionImplementation(metadata, functionBinding);
InternalAggregationFunction implementation = getAggregateFunctionImplementation(functionBinding, functionDependencies);
return new AggregationFunctionMetadata(aggregationFunction.isOrderSensitive(), Optional.of(implementation.getIntermediateType().getTypeSignature()));
}

public WindowFunctionSupplier getWindowFunctionImplementation(Metadata metadata, FunctionBinding functionBinding)
public WindowFunctionSupplier getWindowFunctionImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
SqlFunction function = functions.get(functionBinding.getFunctionId());
try {
if (function instanceof SqlAggregationFunction) {
InternalAggregationFunction aggregationFunction = specializedAggregationCache.get(functionBinding, () -> specializedAggregation(metadata, functionBinding));
InternalAggregationFunction aggregationFunction = specializedAggregationCache.get(functionBinding, () -> specializedAggregation(functionBinding, functionDependencies));
return supplier(function.getFunctionMetadata().getSignature(), aggregationFunction);
}
return specializedWindowCache.get(functionBinding, () -> specializeWindow(functionBinding));
return specializedWindowCache.get(functionBinding, () -> specializeWindow(functionBinding, functionDependencies));
}
catch (ExecutionException | UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), PrestoException.class);
throw new RuntimeException(e.getCause());
}
}

private WindowFunctionSupplier specializeWindow(FunctionBinding functionBinding)
private WindowFunctionSupplier specializeWindow(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
SqlWindowFunction function = (SqlWindowFunction) functions.get(functionBinding.getFunctionId());
return function.specialize(functionBinding);
return function.specialize(functionBinding, functionDependencies);
}

public InternalAggregationFunction getAggregateFunctionImplementation(Metadata metadata, FunctionBinding functionBinding)
public InternalAggregationFunction getAggregateFunctionImplementation(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
try {
return specializedAggregationCache.get(functionBinding, () -> specializedAggregation(metadata, functionBinding));
return specializedAggregationCache.get(functionBinding, () -> specializedAggregation(functionBinding, functionDependencies));
}
catch (ExecutionException | UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), PrestoException.class);
throw new RuntimeException(e.getCause());
}
}

private InternalAggregationFunction specializedAggregation(Metadata metadata, FunctionBinding functionBinding)
private InternalAggregationFunction specializedAggregation(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
SqlAggregationFunction function = (SqlAggregationFunction) functions.get(functionBinding.getFunctionId());
return function.specialize(functionBinding, metadata);
return function.specialize(functionBinding, functionDependencies);
}

public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1635,19 +1635,22 @@ public FunctionMetadata getFunctionMetadata(ResolvedFunction resolvedFunction)
@Override
public AggregationFunctionMetadata getAggregationFunctionMetadata(ResolvedFunction resolvedFunction)
{
return functions.getAggregationFunctionMetadata(this, toFunctionBinding(resolvedFunction));
FunctionDependencies functionDependencies = new FunctionDependencies(this, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies());
return functions.getAggregationFunctionMetadata(toFunctionBinding(resolvedFunction), functionDependencies);
}

@Override
public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction)
{
return functions.getWindowFunctionImplementation(this, toFunctionBinding(resolvedFunction));
FunctionDependencies functionDependencies = new FunctionDependencies(this, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies());
return functions.getWindowFunctionImplementation(toFunctionBinding(resolvedFunction), functionDependencies);
}

@Override
public InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFunction resolvedFunction)
{
return functions.getAggregateFunctionImplementation(this, toFunctionBinding(resolvedFunction));
FunctionDependencies functionDependencies = new FunctionDependencies(this, resolvedFunction.getTypeDependencies(), resolvedFunction.getFunctionDependencies());
return functions.getAggregateFunctionImplementation(toFunctionBinding(resolvedFunction), functionDependencies);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public boolean isDecomposable()
return decomposable;
}

public InternalAggregationFunction specialize(FunctionBinding functionBinding, Metadata metadata)
public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
return specialize(functionBinding);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,8 @@
*/
package io.prestosql.operator;

import io.prestosql.metadata.BoundVariables;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.Metadata;
import io.prestosql.operator.annotations.ImplementationDependency;

import java.lang.invoke.MethodHandle;
Expand All @@ -27,14 +25,6 @@ public final class ParametricFunctionHelpers
{
private ParametricFunctionHelpers() {}

public static MethodHandle bindDependencies(MethodHandle handle, List<ImplementationDependency> dependencies, BoundVariables variables, Metadata metadata)
{
for (ImplementationDependency dependency : dependencies) {
handle = MethodHandles.insertArguments(handle, 0, dependency.resolve(variables, metadata));
}
return handle;
}

public static MethodHandle bindDependencies(MethodHandle handle, List<ImplementationDependency> dependencies, FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
for (ImplementationDependency dependency : dependencies) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
import io.prestosql.annotation.UsedByGeneratedCode;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.FunctionDependencyDeclaration;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
Expand Down Expand Up @@ -104,11 +104,19 @@ protected AbstractMinMaxAggregationFunction(String name, boolean min, String des
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding, Metadata metadata)
public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding)
{
Type type = functionBinding.getTypeVariable("E");
ResolvedFunction resolvedFunction = metadata.resolveOperator(operatorType, ImmutableList.of(type, type));
MethodHandle compareMethodHandle = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();
return FunctionDependencyDeclaration.builder()
.addOperator(operatorType, ImmutableList.of(type, type))
.build();
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
Type type = functionBinding.getTypeVariable("E");
MethodHandle compareMethodHandle = functionDependencies.getOperatorInvoker(operatorType, ImmutableList.of(type, type), Optional.empty()).getMethodHandle();
return generateAggregation(type, compareMethodHandle);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,27 @@
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.DynamicClassLoader;
import io.prestosql.metadata.BoundSignature;
import io.prestosql.metadata.BoundVariables;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.FunctionDependencyDeclaration;
import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.ParametricImplementationsGroup;
import io.prestosql.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata;
import io.prestosql.operator.aggregation.AggregationMetadata.ParameterMetadata.ParameterType;
import io.prestosql.operator.aggregation.state.StateCompiler;
import io.prestosql.operator.annotations.ImplementationDependency;
import io.prestosql.spi.PrestoException;
import io.prestosql.spi.function.AccumulatorStateFactory;
import io.prestosql.spi.function.AccumulatorStateSerializer;
import io.prestosql.spi.type.Type;
import io.prestosql.spi.type.TypeSignature;

import java.lang.invoke.MethodHandle;
import java.util.Collection;
import java.util.List;
import java.util.Optional;

Expand Down Expand Up @@ -78,7 +81,32 @@ public ParametricAggregation(
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding, Metadata metadata)
public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding)
{
FunctionDependencyDeclarationBuilder builder = FunctionDependencyDeclaration.builder();
declareDependencies(functionBinding, builder, implementations.getExactImplementations().values());
declareDependencies(functionBinding, builder, implementations.getSpecializedImplementations());
declareDependencies(functionBinding, builder, implementations.getGenericImplementations());
return builder.build();
}

private static void declareDependencies(FunctionBinding functionBinding, FunctionDependencyDeclarationBuilder builder, Collection<AggregationImplementation> implementations)
{
for (AggregationImplementation implementation : implementations) {
for (ImplementationDependency dependency : implementation.getInputDependencies()) {
dependency.declareDependencies(functionBinding, builder);
}
for (ImplementationDependency dependency : implementation.getCombineDependencies()) {
dependency.declareDependencies(functionBinding, builder);
}
for (ImplementationDependency dependency : implementation.getOutputDependencies()) {
dependency.declareDependencies(functionBinding, builder);
}
}
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
// Bind variables
Signature signature = getFunctionMetadata().getSignature();
Expand All @@ -100,12 +128,11 @@ public InternalAggregationFunction specialize(FunctionBinding functionBinding, M
AccumulatorStateFactory<?> stateFactory = StateCompiler.generateStateFactory(stateClass, classLoader);

// Bind provided dependencies to aggregation method handlers
BoundVariables boundVariables = new BoundVariables(functionBinding.getTypeVariables(), functionBinding.getLongVariables());
MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), boundVariables, metadata);
MethodHandle inputHandle = bindDependencies(concreteImplementation.getInputFunction(), concreteImplementation.getInputDependencies(), functionBinding, functionDependencies);
Optional<MethodHandle> removeInputHandle = concreteImplementation.getRemoveInputFunction().map(
removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), boundVariables, metadata));
MethodHandle combineHandle = bindDependencies(concreteImplementation.getCombineFunction(), concreteImplementation.getCombineDependencies(), boundVariables, metadata);
MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), boundVariables, metadata);
removeInputFunction -> bindDependencies(removeInputFunction, concreteImplementation.getRemoveInputDependencies(), functionBinding, functionDependencies));
MethodHandle combineHandle = bindDependencies(concreteImplementation.getCombineFunction(), concreteImplementation.getCombineDependencies(), functionBinding, functionDependencies);
MethodHandle outputHandle = bindDependencies(concreteImplementation.getOutputFunction(), concreteImplementation.getOutputDependencies(), functionBinding, functionDependencies);

// Build metadata of input parameters
List<ParameterMetadata> parametersMetadata = buildParameterMetadata(concreteImplementation.getInputParameterMetadataTypes(), inputTypes);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@
import io.airlift.bytecode.expression.BytecodeExpression;
import io.prestosql.metadata.FunctionArgumentDefinition;
import io.prestosql.metadata.FunctionBinding;
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.FunctionDependencyDeclaration;
import io.prestosql.metadata.FunctionMetadata;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.metadata.Signature;
import io.prestosql.metadata.SqlAggregationFunction;
import io.prestosql.operator.aggregation.AccumulatorCompiler;
Expand Down Expand Up @@ -116,14 +116,23 @@ protected AbstractMinMaxBy(boolean min, String description)
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding, Metadata metadata)
public FunctionDependencyDeclaration getFunctionDependencies(FunctionBinding functionBinding)
{
Type keyType = functionBinding.getTypeVariable("K");
return FunctionDependencyDeclaration.builder()
.addOperator(min ? LESS_THAN : GREATER_THAN, ImmutableList.of(keyType, keyType))
.build();
}

@Override
public InternalAggregationFunction specialize(FunctionBinding functionBinding, FunctionDependencies functionDependencies)
{
Type keyType = functionBinding.getTypeVariable("K");
Type valueType = functionBinding.getTypeVariable("V");
return generateAggregation(valueType, keyType, metadata);
return generateAggregation(valueType, keyType, functionDependencies);
}

private InternalAggregationFunction generateAggregation(Type valueType, Type keyType, Metadata metadata)
private InternalAggregationFunction generateAggregation(Type valueType, Type keyType, FunctionDependencies functionDependencies)
{
Class<?> stateClazz = getStateClass(keyType.getJavaType(), valueType.getJavaType());
DynamicClassLoader classLoader = new DynamicClassLoader(getClass().getClassLoader());
Expand Down Expand Up @@ -157,8 +166,7 @@ private InternalAggregationFunction generateAggregation(Type valueType, Type key

CallSiteBinder binder = new CallSiteBinder();
OperatorType operator = min ? LESS_THAN : GREATER_THAN;
ResolvedFunction resolvedFunction = metadata.resolveOperator(operator, ImmutableList.of(keyType, keyType));
MethodHandle compareMethod = metadata.getScalarFunctionInvoker(resolvedFunction, Optional.empty()).getMethodHandle();
MethodHandle compareMethod = functionDependencies.getOperatorInvoker(operator, ImmutableList.of(keyType, keyType), Optional.empty()).getMethodHandle();

ClassDefinition definition = new ClassDefinition(
a(PUBLIC, FINAL),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder;
import io.prestosql.metadata.FunctionInvoker;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.TypeSignature;

Expand Down Expand Up @@ -61,14 +59,6 @@ public void declareDependencies(FunctionBinding functionBinding, FunctionDepende
applyBoundVariables(toType, boundVariables));
}

@Override
protected ResolvedFunction getResolvedFunction(BoundVariables boundVariables, Metadata metadata)
{
return metadata.getCoercion(
metadata.getType(applyBoundVariables(fromType, boundVariables)),
metadata.getType(applyBoundVariables(toType, boundVariables)));
}

@Override
protected FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, Optional<InvocationConvention> invocationConvention)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@
import io.prestosql.metadata.FunctionDependencies;
import io.prestosql.metadata.FunctionDependencyDeclaration.FunctionDependencyDeclarationBuilder;
import io.prestosql.metadata.FunctionInvoker;
import io.prestosql.metadata.Metadata;
import io.prestosql.metadata.ResolvedFunction;
import io.prestosql.spi.function.InvocationConvention;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.sql.tree.QualifiedName;
Expand All @@ -29,7 +27,6 @@
import java.util.Optional;

import static io.prestosql.metadata.SignatureBinder.applyBoundVariables;
import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;
import static java.util.Objects.requireNonNull;

public final class FunctionImplementationDependency
Expand All @@ -52,12 +49,6 @@ public void declareDependencies(FunctionBinding functionBinding, FunctionDepende
builder.addFunctionSignature(name, applyBoundVariables(argumentTypes, boundVariables));
}

@Override
protected ResolvedFunction getResolvedFunction(BoundVariables boundVariables, Metadata metadata)
{
return metadata.resolveFunction(name, fromTypeSignatures(applyBoundVariables(argumentTypes, boundVariables)));
}

@Override
protected FunctionInvoker getInvoker(FunctionBinding functionBinding, FunctionDependencies functionDependencies, Optional<InvocationConvention> invocationConvention)
{
Expand Down
Loading

0 comments on commit d9da4f7

Please sign in to comment.