Skip to content

Commit

Permalink
Add FunctionBinding
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Jul 30, 2020
1 parent da28161 commit 4336c25
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@ public BoundVariables(Map<String, Type> typeVariables, Map<String, Long> longVar
.collect(toImmutableSortedMap(CASE_INSENSITIVE_ORDER, Map.Entry::getKey, Map.Entry::getValue));
}

public Map<String, Type> getTypeVariables()
{
return typeVariables;
}

public Type getTypeVariable(String variableName)
{
return getValue(typeVariables, variableName);
Expand All @@ -53,6 +58,11 @@ public boolean containsTypeVariable(String variableName)
return containsValue(typeVariables, variableName);
}

public Map<String, Long> getLongVariables()
{
return longVariables;
}

public Long getLongVariable(String variableName)
{
return getValue(longVariables, variableName);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.prestosql.metadata;

import com.google.common.collect.ImmutableSortedMap;
import io.prestosql.spi.type.Type;

import java.util.Map;
import java.util.Objects;

import static java.lang.String.CASE_INSENSITIVE_ORDER;
import static java.util.Objects.requireNonNull;

public class FunctionBinding
{
private final FunctionId functionId;
private final Signature boundSignature;
private final Map<String, Type> typeVariables;
private final Map<String, Long> longVariables;

public FunctionBinding(FunctionId functionId, Signature boundSignature, Map<String, Type> typeVariables, Map<String, Long> longVariables)
{
this.functionId = requireNonNull(functionId, "functionId is null");
this.boundSignature = requireNonNull(boundSignature, "boundSignature is null");
this.typeVariables = ImmutableSortedMap.copyOf(requireNonNull(typeVariables, "typeVariables is null"), CASE_INSENSITIVE_ORDER);
this.longVariables = ImmutableSortedMap.copyOf(requireNonNull(longVariables, "longVariables is null"), CASE_INSENSITIVE_ORDER);
}

public FunctionId getFunctionId()
{
return functionId;
}

public Signature getBoundSignature()
{
return boundSignature;
}

public Map<String, Type> getTypeVariables()
{
return typeVariables;
}

public Map<String, Long> getLongVariables()
{
return longVariables;
}

@Override
public boolean equals(Object o)
{
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionBinding that = (FunctionBinding) o;
return Objects.equals(functionId, that.functionId) &&
Objects.equals(boundSignature, that.boundSignature) &&
Objects.equals(typeVariables, that.typeVariables) &&
Objects.equals(longVariables, that.longVariables);
}

@Override
public int hashCode()
{
return Objects.hash(functionId, boundSignature, typeVariables, longVariables);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@
import static io.prestosql.operator.scalar.ZipFunction.ZIP_FUNCTIONS;
import static io.prestosql.operator.scalar.ZipWithFunction.ZIP_WITH_FUNCTION;
import static io.prestosql.operator.window.AggregateWindowFunction.supplier;
import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;
import static io.prestosql.type.DecimalCasts.BIGINT_TO_DECIMAL_CAST;
import static io.prestosql.type.DecimalCasts.BOOLEAN_TO_DECIMAL_CAST;
import static io.prestosql.type.DecimalCasts.DECIMAL_TO_BIGINT_CAST;
Expand Down Expand Up @@ -379,9 +378,9 @@
@ThreadSafe
public class FunctionRegistry
{
private final Cache<SpecializedFunctionKey, ScalarFunctionImplementation> specializedScalarCache;
private final Cache<SpecializedFunctionKey, InternalAggregationFunction> specializedAggregationCache;
private final Cache<SpecializedFunctionKey, WindowFunctionSupplier> specializedWindowCache;
private final Cache<FunctionBinding, ScalarFunctionImplementation> specializedScalarCache;
private final Cache<FunctionBinding, InternalAggregationFunction> specializedAggregationCache;
private final Cache<FunctionBinding, WindowFunctionSupplier> specializedWindowCache;
private volatile FunctionMap functions = new FunctionMap();

public FunctionRegistry(Metadata metadata, FeaturesConfig featuresConfig)
Expand Down Expand Up @@ -798,57 +797,63 @@ public FunctionMetadata get(FunctionId functionId)
return functions.get(functionId).getFunctionMetadata();
}

public AggregationFunctionMetadata getAggregationFunctionMetadata(Metadata metadata, ResolvedFunction resolvedFunction)
public AggregationFunctionMetadata getAggregationFunctionMetadata(Metadata metadata, FunctionBinding functionBinding)
{
SqlFunction function = functions.get(resolvedFunction.getFunctionId());
checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", resolvedFunction);
SqlFunction function = functions.get(functionBinding.getFunctionId());
checkArgument(function instanceof SqlAggregationFunction, "%s is not an aggregation function", functionBinding.getBoundSignature());

SqlAggregationFunction aggregationFunction = (SqlAggregationFunction) function;
if (!aggregationFunction.isDecomposable()) {
return new AggregationFunctionMetadata(aggregationFunction.isOrderSensitive(), Optional.empty());
}

InternalAggregationFunction implementation = getAggregateFunctionImplementation(metadata, resolvedFunction);
InternalAggregationFunction implementation = getAggregateFunctionImplementation(metadata, functionBinding);
return new AggregationFunctionMetadata(aggregationFunction.isOrderSensitive(), Optional.of(implementation.getIntermediateType().getTypeSignature()));
}

public WindowFunctionSupplier getWindowFunctionImplementation(Metadata metadata, ResolvedFunction resolvedFunction)
public WindowFunctionSupplier getWindowFunctionImplementation(Metadata metadata, FunctionBinding functionBinding)
{
SpecializedFunctionKey key = getSpecializedFunctionKey(metadata, resolvedFunction);
SqlFunction function = functions.get(functionBinding.getFunctionId());
try {
if (key.getFunction() instanceof SqlAggregationFunction) {
return supplier(key.getFunction().getFunctionMetadata().getSignature(), specializedAggregationCache.get(key, () -> specializedAggregation(metadata, key)));
if (function instanceof SqlAggregationFunction) {
InternalAggregationFunction aggregationFunction = specializedAggregationCache.get(functionBinding, () -> specializedAggregation(metadata, functionBinding));
return supplier(function.getFunctionMetadata().getSignature(), aggregationFunction);
}
return specializedWindowCache.get(key, () -> specializeWindow(metadata, key));
return specializedWindowCache.get(functionBinding, () -> specializeWindow(metadata, functionBinding));
}
catch (ExecutionException | UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), PrestoException.class);
throw new RuntimeException(e.getCause());
}
}

private static WindowFunctionSupplier specializeWindow(Metadata metadata, SpecializedFunctionKey key)
private WindowFunctionSupplier specializeWindow(Metadata metadata, FunctionBinding functionBinding)
{
return ((SqlWindowFunction) key.getFunction())
.specialize(key.getBoundVariables(), key.getArity(), metadata);
SqlWindowFunction function = (SqlWindowFunction) functions.get(functionBinding.getFunctionId());
return function.specialize(
new BoundVariables(functionBinding.getTypeVariables(), functionBinding.getLongVariables()),
functionBinding.getBoundSignature().getArgumentTypes().size(),
metadata);
}

public InternalAggregationFunction getAggregateFunctionImplementation(Metadata metadata, ResolvedFunction resolvedFunction)
public InternalAggregationFunction getAggregateFunctionImplementation(Metadata metadata, FunctionBinding functionBinding)
{
SpecializedFunctionKey key = getSpecializedFunctionKey(metadata, resolvedFunction);
try {
return specializedAggregationCache.get(key, () -> specializedAggregation(metadata, key));
return specializedAggregationCache.get(functionBinding, () -> specializedAggregation(metadata, functionBinding));
}
catch (ExecutionException | UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), PrestoException.class);
throw new RuntimeException(e.getCause());
}
}

private static InternalAggregationFunction specializedAggregation(Metadata metadata, SpecializedFunctionKey key)
private InternalAggregationFunction specializedAggregation(Metadata metadata, FunctionBinding functionBinding)
{
SqlAggregationFunction function = (SqlAggregationFunction) key.getFunction();
InternalAggregationFunction implementation = function.specialize(key.getBoundVariables(), key.getArity(), metadata);
SqlAggregationFunction function = (SqlAggregationFunction) functions.get(functionBinding.getFunctionId());
InternalAggregationFunction implementation = function.specialize(
new BoundVariables(functionBinding.getTypeVariables(), functionBinding.getLongVariables()),
functionBinding.getBoundSignature().getArgumentTypes().size(),
metadata);
checkArgument(
function.isOrderSensitive() == implementation.isOrderSensitive(),
"implementation order sensitivity doesn't match for: %s",
Expand All @@ -860,25 +865,27 @@ private static InternalAggregationFunction specializedAggregation(Metadata metad
return implementation;
}

public FunctionInvoker getScalarFunctionInvoker(Metadata metadata, ResolvedFunction resolvedFunction, InvocationConvention invocationConvention)
public FunctionInvoker getScalarFunctionInvoker(Metadata metadata, FunctionBinding functionBinding, InvocationConvention invocationConvention)
{
SpecializedFunctionKey key = getSpecializedFunctionKey(metadata, resolvedFunction);
ScalarFunctionImplementation scalarFunctionImplementation;
try {
scalarFunctionImplementation = specializedScalarCache.get(key, () -> specializeScalarFunction(metadata, key));
scalarFunctionImplementation = specializedScalarCache.get(functionBinding, () -> specializeScalarFunction(metadata, functionBinding));
}
catch (ExecutionException | UncheckedExecutionException e) {
throwIfInstanceOf(e.getCause(), PrestoException.class);
throw new RuntimeException(e.getCause());
}
FunctionInvokerProvider functionInvokerProvider = new FunctionInvokerProvider(metadata);
return functionInvokerProvider.createFunctionInvoker(scalarFunctionImplementation, resolvedFunction.getSignature(), invocationConvention);
return functionInvokerProvider.createFunctionInvoker(scalarFunctionImplementation, functionBinding.getBoundSignature(), invocationConvention);
}

private static ScalarFunctionImplementation specializeScalarFunction(Metadata metadata, SpecializedFunctionKey key)
private ScalarFunctionImplementation specializeScalarFunction(Metadata metadata, FunctionBinding functionBinding)
{
SqlScalarFunction function = (SqlScalarFunction) key.getFunction();
ScalarFunctionImplementation specialize = function.specialize(key.getBoundVariables(), key.getArity(), metadata);
SqlScalarFunction function = (SqlScalarFunction) functions.get(functionBinding.getFunctionId());
ScalarFunctionImplementation specialize = function.specialize(
new BoundVariables(functionBinding.getTypeVariables(), functionBinding.getLongVariables()),
functionBinding.getBoundSignature().getArgumentTypes().size(),
metadata);
FunctionMetadata functionMetadata = function.getFunctionMetadata();
for (ScalarImplementationChoice choice : specialize.getAllChoices()) {
checkArgument(choice.isNullable() == functionMetadata.isNullable(), "choice nullability doesn't match for: " + functionMetadata.getSignature());
Expand All @@ -900,19 +907,6 @@ else if (argumentProperty.getNullConvention() != BLOCK_AND_POSITION) {
return specialize;
}

private SpecializedFunctionKey getSpecializedFunctionKey(Metadata metadata, ResolvedFunction resolvedFunction)
{
SqlFunction function = functions.get(resolvedFunction.getFunctionId());
Signature signature = resolvedFunction.getSignature();
BoundVariables boundVariables = new SignatureBinder(metadata, function.getFunctionMetadata().getSignature(), false)
.bindVariables(fromTypeSignatures(signature.getArgumentTypes()), signature.getReturnType())
.orElseThrow(() -> new IllegalArgumentException("Could not extract bound variables"));
return new SpecializedFunctionKey(
function,
boundVariables,
signature.getArgumentTypes().size());
}

private static class FunctionMap
{
private final Map<FunctionId, SqlFunction> functions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1561,26 +1561,26 @@ public FunctionMetadata getFunctionMetadata(ResolvedFunction resolvedFunction)
@Override
public AggregationFunctionMetadata getAggregationFunctionMetadata(ResolvedFunction resolvedFunction)
{
return functions.getAggregationFunctionMetadata(this, resolvedFunction);
return functions.getAggregationFunctionMetadata(this, toFunctionBinding(resolvedFunction));
}

@Override
public WindowFunctionSupplier getWindowFunctionImplementation(ResolvedFunction resolvedFunction)
{
return functions.getWindowFunctionImplementation(this, resolvedFunction);
return functions.getWindowFunctionImplementation(this, toFunctionBinding(resolvedFunction));
}

@Override
public InternalAggregationFunction getAggregateFunctionImplementation(ResolvedFunction resolvedFunction)
{
return functions.getAggregateFunctionImplementation(this, resolvedFunction);
return functions.getAggregateFunctionImplementation(this, toFunctionBinding(resolvedFunction));
}

@Override
public FunctionInvoker getScalarFunctionInvoker(ResolvedFunction resolvedFunction, Optional<InvocationConvention> invocationConvention)
{
InvocationConvention expectedConvention = invocationConvention.orElseGet(() -> getDefaultCallingConvention(resolvedFunction));
return functions.getScalarFunctionInvoker(this, resolvedFunction, expectedConvention);
return functions.getScalarFunctionInvoker(this, toFunctionBinding(resolvedFunction), expectedConvention);
}

/**
Expand All @@ -1602,6 +1602,15 @@ private InvocationConvention getDefaultCallingConvention(ResolvedFunction resolv
false);
}

private FunctionBinding toFunctionBinding(ResolvedFunction resolvedFunction)
{
return SignatureBinder.bindFunction(
this,
resolvedFunction.getFunctionId(),
functions.get(resolvedFunction.getFunctionId()).getSignature(),
resolvedFunction.getSignature());
}

@Override
public ProcedureRegistry getProcedureRegistry()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import static io.prestosql.metadata.SignatureBinder.RelationshipType.EXPLICIT_COERCION_FROM;
import static io.prestosql.metadata.SignatureBinder.RelationshipType.EXPLICIT_COERCION_TO;
import static io.prestosql.metadata.SignatureBinder.RelationshipType.IMPLICIT_COERCION;
import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypeSignatures;
import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes;
import static io.prestosql.type.TypeCalculation.calculateLiteralValue;
import static io.prestosql.type.TypeCoercion.isCovariantTypeBase;
Expand Down Expand Up @@ -185,6 +186,18 @@ public static TypeSignature applyBoundVariables(TypeSignature typeSignature, Bou
return new TypeSignature(baseType, parameters);
}

public static FunctionBinding bindFunction(Metadata metadata, FunctionId functionId, Signature functionSignature, Signature boundSignature)
{
BoundVariables boundVariables = new SignatureBinder(metadata, functionSignature, false)
.bindVariables(fromTypeSignatures(boundSignature.getArgumentTypes()), boundSignature.getReturnType())
.orElseThrow(() -> new IllegalArgumentException("Could not extract bound variables"));
return new FunctionBinding(
functionId,
boundSignature,
boundVariables.getTypeVariables(),
boundVariables.getLongVariables());
}

/**
* Example of not allowed literal variable usages across typeSignatures:
* <p><ul>
Expand Down
Loading

0 comments on commit 4336c25

Please sign in to comment.