Skip to content

Commit

Permalink
Convert histogram aggregation to annotated function
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed May 27, 2022
1 parent dc516cf commit dc45401
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 107 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ public static FunctionBundle create(FeaturesConfig featuresConfig, TypeOperators
.functions(DECIMAL_TO_INTEGER_SATURATED_FLOOR_CAST, INTEGER_TO_DECIMAL_SATURATED_FLOOR_CAST)
.functions(DECIMAL_TO_SMALLINT_SATURATED_FLOOR_CAST, SMALLINT_TO_DECIMAL_SATURATED_FLOOR_CAST)
.functions(DECIMAL_TO_TINYINT_SATURATED_FLOOR_CAST, TINYINT_TO_DECIMAL_SATURATED_FLOOR_CAST)
.function(new Histogram(blockTypeOperators))
.aggregates(Histogram.class)
.aggregates(ChecksumAggregationFunction.class)
.aggregates(ArbitraryAggregationFunction.class)
.functions(GREATEST, LEAST)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ public class GroupedHistogramState
private TypedHistogram typedHistogram;
private long size;

public GroupedHistogramState(Type keyType, BlockPositionEqual equalOperator, BlockPositionHashCode hashCodeOperator, int expectedEntriesCount)
public GroupedHistogramState(Type type, BlockPositionEqual equalOperator, BlockPositionHashCode hashCodeOperator, int expectedEntriesCount)
{
this.type = requireNonNull(keyType, "keyType is null");
this.type = requireNonNull(type, "type is null");
this.equalOperator = requireNonNull(equalOperator, "equalOperator is null");
this.hashCodeOperator = requireNonNull(hashCodeOperator, "hashCodeOperator is null");
typedHistogram = new GroupedTypedHistogram(keyType, equalOperator, hashCodeOperator, expectedEntriesCount);
typedHistogram = new GroupedTypedHistogram(type, equalOperator, hashCodeOperator, expectedEntriesCount);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,90 +13,44 @@
*/
package io.trino.operator.aggregation.histogram;

import com.google.common.collect.ImmutableList;
import io.trino.metadata.AggregationFunctionMetadata;
import io.trino.metadata.BoundSignature;
import io.trino.metadata.FunctionMetadata;
import io.trino.metadata.Signature;
import io.trino.metadata.SqlAggregationFunction;
import io.trino.operator.aggregation.AggregationMetadata;
import io.trino.operator.aggregation.AggregationMetadata.AccumulatorStateDescriptor;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AggregationFunction;
import io.trino.spi.function.AggregationState;
import io.trino.spi.function.BlockIndex;
import io.trino.spi.function.BlockPosition;
import io.trino.spi.function.CombineFunction;
import io.trino.spi.function.Description;
import io.trino.spi.function.InputFunction;
import io.trino.spi.function.OutputFunction;
import io.trino.spi.function.SqlType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeSignature;
import io.trino.type.BlockTypeOperators;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;

import java.lang.invoke.MethodHandle;
import java.util.Optional;

import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.TypeSignature.mapType;
import static io.trino.util.Reflection.methodHandle;
import static java.util.Objects.requireNonNull;

public class Histogram
extends SqlAggregationFunction
@AggregationFunction("histogram")
@Description("Count the number of times each value occurs")
public final class Histogram
{
public static final String NAME = "histogram";
private static final MethodHandle OUTPUT_FUNCTION = methodHandle(Histogram.class, "output", Type.class, HistogramState.class, BlockBuilder.class);
private static final MethodHandle INPUT_FUNCTION = methodHandle(Histogram.class, "input", Type.class, HistogramState.class, Block.class, int.class);
private static final MethodHandle COMBINE_FUNCTION = methodHandle(Histogram.class, "combine", HistogramState.class, HistogramState.class);

public static final int EXPECTED_SIZE_FOR_HASHING = 10;
private final BlockTypeOperators blockTypeOperators;

public Histogram(BlockTypeOperators blockTypeOperators)
{
super(
FunctionMetadata.aggregateBuilder()
.signature(Signature.builder()
.name(NAME)
.comparableTypeParameter("K")
.returnType(mapType(new TypeSignature("K"), BIGINT.getTypeSignature()))
.argumentType(new TypeSignature("K"))
.build())
.description("Count the number of times each value occurs")
.build(),
AggregationFunctionMetadata.builder()
.intermediateType(mapType(new TypeSignature("K"), BIGINT.getTypeSignature()))
.build());
this.blockTypeOperators = blockTypeOperators;
}

@Override
public AggregationMetadata specialize(BoundSignature boundSignature)
{
Type keyType = boundSignature.getArgumentTypes().get(0);
BlockPositionEqual keyEqual = blockTypeOperators.getEqualOperator(keyType);
BlockPositionHashCode keyHashCode = blockTypeOperators.getHashCodeOperator(keyType);
Type outputType = boundSignature.getReturnType();
HistogramStateSerializer stateSerializer = new HistogramStateSerializer(outputType);
MethodHandle inputFunction = INPUT_FUNCTION.bindTo(keyType);
MethodHandle outputFunction = OUTPUT_FUNCTION.bindTo(outputType);

return new AggregationMetadata(
inputFunction,
Optional.empty(),
Optional.of(COMBINE_FUNCTION),
outputFunction,
ImmutableList.of(new AccumulatorStateDescriptor<>(
HistogramState.class,
stateSerializer,
new HistogramStateFactory(keyType, keyEqual, keyHashCode, EXPECTED_SIZE_FOR_HASHING))));
}
private Histogram() {}

public static void input(Type type, HistogramState state, Block key, int position)
@InputFunction
@TypeParameter("T")
public static void input(
@TypeParameter("T") Type type,
@AggregationState("T") HistogramState state,
@BlockPosition @SqlType("T") Block key,
@BlockIndex int position)
{
TypedHistogram typedHistogram = state.get();
long startSize = typedHistogram.getEstimatedSize();
typedHistogram.add(position, key, 1L);
state.addMemoryUsage(typedHistogram.getEstimatedSize() - startSize);
}

public static void combine(HistogramState state, HistogramState otherState)
@CombineFunction
public static void combine(@AggregationState("T") HistogramState state, @AggregationState("T") HistogramState otherState)
{
// NOTE: state = current merged state; otherState = scratchState (new data to be added)
// For both grouped and single histograms there is a single histogram object, so in neither case can otherState.get() return null.
Expand All @@ -110,7 +64,8 @@ public static void combine(HistogramState state, HistogramState otherState)
state.addMemoryUsage(typedHistogram.getEstimatedSize() - startSize);
}

public static void output(Type type, HistogramState state, BlockBuilder out)
@OutputFunction("map(T, BIGINT)")
public static void output(@TypeParameter("T") Type type, @AggregationState("T") HistogramState state, BlockBuilder out)
{
TypedHistogram typedHistogram = state.get();
typedHistogram.serialize(out);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
import io.trino.spi.function.AccumulatorState;
import io.trino.spi.function.AccumulatorStateMetadata;

@AccumulatorStateMetadata(stateFactoryClass = HistogramStateFactory.class, stateSerializerClass = HistogramStateSerializer.class)
@AccumulatorStateMetadata(
stateFactoryClass = HistogramStateFactory.class,
stateSerializerClass = HistogramStateSerializer.class,
typeParameters = "T",
serializedType = "map(T, BIGINT)")
public interface HistogramState
extends AccumulatorState
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,41 +14,55 @@
package io.trino.operator.aggregation.histogram;

import io.trino.spi.function.AccumulatorStateFactory;
import io.trino.spi.function.Convention;
import io.trino.spi.function.OperatorDependency;
import io.trino.spi.function.OperatorType;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;
import io.trino.type.BlockTypeOperators.BlockPositionEqual;
import io.trino.type.BlockTypeOperators.BlockPositionHashCode;

import static io.trino.spi.function.InvocationConvention.InvocationArgumentConvention.BLOCK_POSITION;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.FAIL_ON_NULL;
import static io.trino.spi.function.InvocationConvention.InvocationReturnConvention.NULLABLE_RETURN;
import static java.util.Objects.requireNonNull;

public class HistogramStateFactory
implements AccumulatorStateFactory<HistogramState>
{
private final Type keyType;
public static final int EXPECTED_SIZE_FOR_HASHING = 10;

private final Type type;
private final BlockPositionEqual equalOperator;
private final BlockPositionHashCode hashCodeOperator;
private final int expectedEntriesCount;

public HistogramStateFactory(
Type keyType,
BlockPositionEqual equalOperator,
BlockPositionHashCode hashCodeOperator,
int expectedEntriesCount)
@TypeParameter("T") Type type,
@OperatorDependency(
operator = OperatorType.EQUAL,
argumentTypes = {"T", "T"},
convention = @Convention(arguments = {BLOCK_POSITION, BLOCK_POSITION}, result = NULLABLE_RETURN))
BlockPositionEqual equalOperator,
@OperatorDependency(
operator = OperatorType.HASH_CODE,
argumentTypes = "T",
convention = @Convention(arguments = BLOCK_POSITION, result = FAIL_ON_NULL))
BlockPositionHashCode hashCodeOperator)
{
this.keyType = requireNonNull(keyType, "keyType is null");
this.type = requireNonNull(type, "type is null");
this.equalOperator = requireNonNull(equalOperator, "equalOperator is null");
this.hashCodeOperator = requireNonNull(hashCodeOperator, "hashCodeOperator is null");
this.expectedEntriesCount = expectedEntriesCount;
}

@Override
public HistogramState createSingleState()
{
return new SingleHistogramState(keyType, equalOperator, hashCodeOperator, expectedEntriesCount);
return new SingleHistogramState(type, equalOperator, hashCodeOperator, EXPECTED_SIZE_FOR_HASHING);
}

@Override
public HistogramState createGroupedState()
{
return new GroupedHistogramState(keyType, equalOperator, hashCodeOperator, expectedEntriesCount);
return new GroupedHistogramState(type, equalOperator, hashCodeOperator, EXPECTED_SIZE_FOR_HASHING);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.AccumulatorStateSerializer;
import io.trino.spi.function.TypeParameter;
import io.trino.spi.type.Type;

import static io.trino.operator.aggregation.histogram.Histogram.EXPECTED_SIZE_FOR_HASHING;
import static io.trino.operator.aggregation.histogram.HistogramStateFactory.EXPECTED_SIZE_FOR_HASHING;

public class HistogramStateSerializer
implements AccumulatorStateSerializer<HistogramState>
{
private final Type serializedType;

public HistogramStateSerializer(Type serializedType)
public HistogramStateSerializer(@TypeParameter("map(T, BIGINT)") Type serializedType)
{
this.serializedType = serializedType;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import com.google.common.collect.ImmutableList;
import io.trino.metadata.TestingFunctionResolution;
import io.trino.operator.GroupByIdBlock;
import io.trino.operator.aggregation.histogram.Histogram;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.sql.tree.QualifiedName;
Expand Down Expand Up @@ -134,7 +133,7 @@ public GroupedAggregator testSharedGroupWithLargeBlocksRunner(Data data)
private static TestingAggregationFunction getInternalAggregationFunctionVarChar()
{
TestingFunctionResolution functionResolution = new TestingFunctionResolution();
return functionResolution.getAggregateFunction(QualifiedName.of(Histogram.NAME), fromTypes(VARCHAR));
return functionResolution.getAggregateFunction(QualifiedName.of("histogram"), fromTypes(VARCHAR));
}

public static void main(String[] args)
Expand Down
Loading

0 comments on commit dc45401

Please sign in to comment.