Skip to content

Commit

Permalink
Remove argument channels from WindowFunctionSupplier
Browse files Browse the repository at this point in the history
  • Loading branch information
dain committed Dec 17, 2021
1 parent 408ca55 commit 1866a23
Show file tree
Hide file tree
Showing 13 changed files with 239 additions and 98 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.LambdaProvider;
import io.trino.operator.window.FrameInfo;
import io.trino.operator.window.MappedWindowFunction;
import io.trino.operator.window.WindowFunctionSupplier;
import io.trino.spi.function.WindowFunction;
import io.trino.spi.type.Type;
Expand Down Expand Up @@ -56,7 +57,13 @@ public static WindowFunctionDefinition window(WindowFunctionSupplier functionSup
return new WindowFunctionDefinition(functionSupplier, type, Optional.empty(), ignoreNulls, lambdaProviders, inputs);
}

WindowFunctionDefinition(WindowFunctionSupplier functionSupplier, Type type, Optional<FrameInfo> frameInfo, boolean ignoreNulls, List<LambdaProvider> lambdaProviders, List<Integer> argumentChannels)
private WindowFunctionDefinition(
WindowFunctionSupplier functionSupplier,
Type type,
Optional<FrameInfo> frameInfo,
boolean ignoreNulls,
List<LambdaProvider> lambdaProviders,
List<Integer> argumentChannels)
{
requireNonNull(functionSupplier, "functionSupplier is null");
requireNonNull(type, "type is null");
Expand Down Expand Up @@ -84,6 +91,6 @@ public Type getType()

public WindowFunction createWindowFunction()
{
return functionSupplier.createWindowFunction(argumentChannels, ignoreNulls, lambdaProviders);
return new MappedWindowFunction(functionSupplier.createWindowFunction(ignoreNulls, lambdaProviders), argumentChannels);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -391,15 +391,14 @@ private static void generateAddOrRemoveInputWindowIndex(
// TODO: implement masking based on maskChannel field once Window Functions support DISTINCT arguments to the functions.

Parameter index = arg("index", WindowIndex.class);
Parameter channels = arg("channels", type(List.class, Integer.class));
Parameter startPosition = arg("startPosition", int.class);
Parameter endPosition = arg("endPosition", int.class);

MethodDefinition method = definition.declareMethod(
a(PUBLIC),
generatedFunctionName,
type(void.class),
ImmutableList.of(index, channels, startPosition, endPosition));
ImmutableList.of(index, startPosition, endPosition));
Scope scope = method.getScope();

Variable position = scope.declareVariable(int.class, "position");
Expand All @@ -416,7 +415,6 @@ private static void generateAddOrRemoveInputWindowIndex(
lambdaProviderFields,
stateField,
index,
channels,
position));

method.getBody()
Expand All @@ -425,22 +423,20 @@ private static void generateAddOrRemoveInputWindowIndex(
.condition(BytecodeExpressions.lessThanOrEqual(position, endPosition))
.update(position.increment())
.body(new IfStatement()
.condition(anyParametersAreNull(argumentNullable, index, channels, position))
.condition(anyParametersAreNull(argumentNullable, index, position))
.ifFalse(invokeInputFunction)))
.ret();
}

private static BytecodeExpression anyParametersAreNull(
List<Boolean> argumentNullable,
Variable index,
Variable channels,
Variable position)
{
BytecodeExpression isNull = constantFalse();
for (int inputChannel = 0; inputChannel < argumentNullable.size(); inputChannel++) {
if (!argumentNullable.get(inputChannel)) {
BytecodeExpression getChannel = channels.invoke("get", Object.class, constantInt(inputChannel)).cast(int.class);
isNull = BytecodeExpressions.or(isNull, index.invoke("isNull", boolean.class, getChannel, position));
isNull = BytecodeExpressions.or(isNull, index.invoke("isNull", boolean.class, constantInt(inputChannel), position));
}
}

Expand All @@ -453,7 +449,6 @@ private static List<BytecodeExpression> getInvokeFunctionOnWindowIndexParameters
List<FieldDefinition> lambdaProviderFields,
List<FieldDefinition> stateField,
Variable index,
Variable channels,
Variable position)
{
List<BytecodeExpression> expressions = new ArrayList<>();
Expand All @@ -465,8 +460,7 @@ private static List<BytecodeExpression> getInvokeFunctionOnWindowIndexParameters

// input parameters
for (int i = 0; i < inputParameterCount; i++) {
BytecodeExpression getChannel = channels.invoke("get", Object.class, constantInt(i)).cast(int.class);
expressions.add(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, getChannel, position));
expressions.add(index.cast(InternalWindowIndex.class).invoke("getRawBlock", Block.class, constantInt(i), position));
}

// position parameter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.WindowIndex;

import java.util.List;

public interface WindowAccumulator
{
long getEstimatedSize();

WindowAccumulator copy();

void addInput(WindowIndex index, List<Integer> channels, int startPosition, int endPosition);
void addInput(WindowIndex index, int startPosition, int endPosition);

void removeInput(WindowIndex index, List<Integer> channels, int startPosition, int endPosition);
void removeInput(WindowIndex index, int startPosition, int endPosition);

void evaluateFinal(BlockBuilder blockBuilder);
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
*/
package io.trino.operator.window;

import com.google.common.collect.ImmutableList;
import io.trino.operator.aggregation.WindowAccumulator;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.WindowFunction;
import io.trino.spi.function.WindowIndex;

import java.util.List;
import java.util.function.Supplier;

import static java.lang.Math.max;
Expand All @@ -29,7 +27,6 @@
class AggregateWindowFunction
implements WindowFunction
{
private final List<Integer> argumentChannels;
private final Supplier<WindowAccumulator> accumulatorFactory;
private final boolean hasRemoveInput;

Expand All @@ -38,11 +35,10 @@ class AggregateWindowFunction
private int currentStart;
private int currentEnd;

public AggregateWindowFunction(Supplier<WindowAccumulator> accumulatorFactory, boolean hasRemoveInput, List<Integer> argumentChannels)
public AggregateWindowFunction(Supplier<WindowAccumulator> accumulatorFactory, boolean hasRemoveInput)
{
this.accumulatorFactory = requireNonNull(accumulatorFactory, "accumulatorFactory is null");
this.hasRemoveInput = hasRemoveInput;
this.argumentChannels = ImmutableList.copyOf(requireNonNull(argumentChannels, "argumentChannels is null"));
}

@Override
Expand Down Expand Up @@ -113,12 +109,12 @@ private void buildNewFrame(int frameStart, int frameEnd)

private void accumulate(int start, int end)
{
accumulator.addInput(windowIndex, argumentChannels, start, end);
accumulator.addInput(windowIndex, start, end);
}

private void remove(int start, int end)
{
accumulator.removeInput(windowIndex, argumentChannels, start, end);
accumulator.removeInput(windowIndex, start, end);
}

private void resetAccumulator()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ public List<Class<?>> getLambdaInterfaces()
}

@Override
public WindowFunction createWindowFunction(List<Integer> inputs, boolean ignoreNulls, List<LambdaProvider> lambdaProviders)
public WindowFunction createWindowFunction(boolean ignoreNulls, List<LambdaProvider> lambdaProviders)
{
return new AggregateWindowFunction(() -> createWindowAccumulator(lambdaProviders), hasRemoveInput, inputs);
return new AggregateWindowFunction(() -> createWindowAccumulator(lambdaProviders), hasRemoveInput);
}

private WindowAccumulator createWindowAccumulator(List<LambdaProvider> lambdaProviders)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,28 +13,36 @@
*/
package io.trino.operator.window;

import io.trino.spi.block.BlockBuilder;
import io.trino.spi.function.WindowFunction;
import io.trino.spi.function.WindowIndex;

import java.util.List;

import static java.util.Objects.requireNonNull;

public final class FramedWindowFunction
public final class MappedWindowFunction
implements WindowFunction
{
private final WindowFunction function;
private final FrameInfo frame;
private final MappedWindowIndex mappedWindowIndex;

public FramedWindowFunction(WindowFunction windowFunction, FrameInfo frameInfo)
public MappedWindowFunction(WindowFunction windowFunction, List<Integer> argumentChannels)
{
this.function = requireNonNull(windowFunction, "windowFunction is null");
this.frame = requireNonNull(frameInfo, "frameInfo is null");
this.mappedWindowIndex = new MappedWindowIndex(argumentChannels);
}

public WindowFunction getFunction()
@Override
public void reset(WindowIndex windowIndex)
{
return function;
mappedWindowIndex.setDelegate((InternalWindowIndex) windowIndex);
function.reset(mappedWindowIndex);
}

public FrameInfo getFrame()
@Override
public void processRow(BlockBuilder output, int peerGroupStart, int peerGroupEnd, int frameStart, int frameEnd)
{
return frame;
function.processRow(output, peerGroupStart, peerGroupEnd, frameStart, frameEnd);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.operator.window;

import com.google.common.primitives.Ints;
import io.airlift.slice.Slice;
import io.trino.annotation.UsedByGeneratedCode;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;

import java.util.List;

import static java.util.Objects.requireNonNull;

public class MappedWindowIndex
implements InternalWindowIndex
{
private final int[] channelMap;
private InternalWindowIndex delegate;

public MappedWindowIndex(List<Integer> argumentChannels)
{
this.channelMap = Ints.toArray(requireNonNull(argumentChannels, "argumentChannels is null"));
}

public void setDelegate(InternalWindowIndex delegate)
{
this.delegate = delegate;
}

@Override
public int size()
{
return delegate.size();
}

@Override
public boolean isNull(int channel, int position)
{
return delegate.isNull(toDelegateChannel(channel), position);
}

@Override
public boolean getBoolean(int channel, int position)
{
return delegate.getBoolean(toDelegateChannel(channel), position);
}

@Override
public long getLong(int channel, int position)
{
return delegate.getLong(toDelegateChannel(channel), position);
}

@Override
public double getDouble(int channel, int position)
{
return delegate.getDouble(toDelegateChannel(channel), position);
}

@Override
public Slice getSlice(int channel, int position)
{
return delegate.getSlice(toDelegateChannel(channel), position);
}

@Override
public Block getSingleValueBlock(int channel, int position)
{
return delegate.getSingleValueBlock(toDelegateChannel(channel), position);
}

@Override
public Object getObject(int channel, int position)
{
return delegate.getObject(toDelegateChannel(channel), position);
}

@Override
public void appendTo(int channel, int position, BlockBuilder output)
{
delegate.appendTo(toDelegateChannel(channel), position, output);
}

@Override
@UsedByGeneratedCode
public Block getRawBlock(int channel, int position)
{
return delegate.getRawBlock(toDelegateChannel(channel), position);
}

@Override
@UsedByGeneratedCode
public int getRawBlockPosition(int position)
{
return delegate.getRawBlockPosition(position);
}

private int toDelegateChannel(int channel)
{
return channelMap[channel];
}
}
Loading

0 comments on commit 1866a23

Please sign in to comment.