Skip to content

Commit

Permalink
Reuse expression rewrites in aggregation rewrites
Browse files Browse the repository at this point in the history
This removes `identifierQuote` from `AggregateFunctionRewriter`.
Instead, leverage `ConnectorExpressionRewriter`.

Note on changes in Pinot connector: they are supposed to maintain status
quo around column name quoting.

Note that this changes how `Variable`-s in aggregates are handled, but
doesn't enable pushdown of aggregations on complex expressions yet.
Currently, `PushAggregationIntoTableScan` pushes aggregates over simple
references only.
  • Loading branch information
findepi committed Mar 16, 2022
1 parent b704cee commit 5b5b8c6
Show file tree
Hide file tree
Showing 37 changed files with 233 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,36 @@
import com.google.common.collect.ImmutableSet;
import io.trino.matching.Match;
import io.trino.plugin.base.aggregation.AggregateFunctionRule.RewriteContext;
import io.trino.plugin.base.expression.ConnectorExpressionRewriter;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.expression.ConnectorExpression;

import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;

import static java.util.Objects.requireNonNull;

public final class AggregateFunctionRewriter<Result>
public final class AggregateFunctionRewriter<AggregationResult, ExpressionResult>
{
private final Function<String, String> identifierQuote;
private final Set<AggregateFunctionRule<Result>> rules;
private final ConnectorExpressionRewriter<ExpressionResult> connectorExpressionRewriter;
private final Set<AggregateFunctionRule<AggregationResult, ExpressionResult>> rules;

public AggregateFunctionRewriter(Function<String, String> identifierQuote, Set<AggregateFunctionRule<Result>> rules)
public AggregateFunctionRewriter(ConnectorExpressionRewriter<ExpressionResult> connectorExpressionRewriter, Set<AggregateFunctionRule<AggregationResult, ExpressionResult>> rules)
{
this.identifierQuote = requireNonNull(identifierQuote, "identifierQuote is null");
this.connectorExpressionRewriter = requireNonNull(connectorExpressionRewriter, "connectorExpressionRewriter is null");
this.rules = ImmutableSet.copyOf(requireNonNull(rules, "rules is null"));
}

public Optional<Result> rewrite(ConnectorSession session, AggregateFunction aggregateFunction, Map<String, ColumnHandle> assignments)
public Optional<AggregationResult> rewrite(ConnectorSession session, AggregateFunction aggregateFunction, Map<String, ColumnHandle> assignments)
{
requireNonNull(aggregateFunction, "aggregateFunction is null");
requireNonNull(assignments, "assignments is null");

RewriteContext context = new RewriteContext()
RewriteContext<ExpressionResult> context = new RewriteContext<>()
{
@Override
public Map<String, ColumnHandle> getAssignments()
Expand All @@ -53,23 +54,23 @@ public Map<String, ColumnHandle> getAssignments()
}

@Override
public Function<String, String> getIdentifierQuote()
public ConnectorSession getSession()
{
return identifierQuote;
return session;
}

@Override
public ConnectorSession getSession()
public Optional<ExpressionResult> rewriteExpression(ConnectorExpression expression)
{
return session;
return connectorExpressionRewriter.rewrite(session, expression, assignments);
}
};

for (AggregateFunctionRule<Result> rule : rules) {
for (AggregateFunctionRule<AggregationResult, ExpressionResult> rule : rules) {
Iterator<Match> matches = rule.getPattern().match(aggregateFunction, context).iterator();
while (matches.hasNext()) {
Match match = matches.next();
Optional<Result> rewritten = rule.rewrite(aggregateFunction, match.captures(), context);
Optional<AggregationResult> rewritten = rule.rewrite(aggregateFunction, match.captures(), context);
if (rewritten.isPresent()) {
return rewritten;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,21 @@
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.expression.ConnectorExpression;

import java.util.Map;
import java.util.Optional;
import java.util.function.Function;

import static com.google.common.base.Verify.verifyNotNull;
import static java.util.Objects.requireNonNull;

public interface AggregateFunctionRule<Result>
public interface AggregateFunctionRule<AggregationResult, ExpressionResult>
{
Pattern<AggregateFunction> getPattern();

Optional<Result> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context);
Optional<AggregationResult> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<ExpressionResult> context);

interface RewriteContext
interface RewriteContext<ExpressionResult>
{
default ColumnHandle getAssignment(String name)
{
Expand All @@ -44,8 +44,8 @@ default ColumnHandle getAssignment(String name)

Map<String, ColumnHandle> getAssignments();

Function<String, String> getIdentifierQuote();

ConnectorSession getSession();

Optional<ExpressionResult> rewriteExpression(ConnectorExpression expression);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -43,7 +42,7 @@
* can result in rounding of the output to a bigint.
*/
public abstract class BaseImplementAvgBigint
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private final Capture<Variable> argument;

Expand All @@ -64,19 +63,17 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
Variable argument = captures.get(this.argument);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

String columnName = context.getIdentifierQuote().apply(columnHandle.getColumnName());

return Optional.of(new JdbcExpression(
format(getRewriteFormatExpression(), columnName),
format(getRewriteFormatExpression(), context.rewriteExpression(argument).orElseThrow()),
new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
}

// TODO String.format is not great for contract of an extensible API. Replace with formatting method.
/**
* Implement this method for each connector supporting avg(bigint) pushdown
* @return A format string expression with a single placeholder for the column name; The string expression pushes down avg to the remote database
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
* Implements {@code avg(decimal(p, s)}
*/
public class ImplementAvgDecimal
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private static final Capture<Variable> ARGUMENT = newCapture();

Expand All @@ -54,15 +54,15 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
DecimalType type = (DecimalType) columnHandle.getColumnType();
verify(aggregateFunction.getOutputType().equals(type));

return Optional.of(new JdbcExpression(
format("CAST(avg(%s) AS decimal(%s, %s))", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision(), type.getScale()),
format("CAST(avg(%s) AS decimal(%s, %s))", context.rewriteExpression(argument).orElseThrow(), type.getPrecision(), type.getScale()),
columnHandle.getJdbcTypeHandle()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
* Implements {@code avg(float)}
*/
public class ImplementAvgFloatingPoint
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private static final Capture<Variable> ARGUMENT = newCapture();

Expand All @@ -55,14 +55,14 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == columnHandle.getColumnType());

return Optional.of(new JdbcExpression(
format("avg(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
format("avg(%s)", context.rewriteExpression(argument).orElseThrow()),
columnHandle.getJdbcTypeHandle()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import static java.lang.String.format;

public class ImplementCorr
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private static final Capture<List<Variable>> ARGUMENTS = newCapture();

Expand All @@ -53,17 +53,18 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
List<Variable> arguments = captures.get(ARGUMENTS);
verify(arguments.size() == 2);

JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(arguments.get(0).getName());
JdbcColumnHandle columnHandle2 = (JdbcColumnHandle) context.getAssignment(arguments.get(1).getName());
Variable argument1 = arguments.get(0);
Variable argument2 = arguments.get(1);
JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(argument1.getName());
verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType()));

return Optional.of(new JdbcExpression(
format("corr(%s, %s)", context.getIdentifierQuote().apply(columnHandle1.getColumnName()), context.getIdentifierQuote().apply(columnHandle2.getColumnName())),
format("corr(%s, %s)", context.rewriteExpression(argument1).orElseThrow(), context.rewriteExpression(argument2).orElseThrow()),
columnHandle1.getJdbcTypeHandle()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import io.trino.matching.Pattern;
import io.trino.plugin.base.aggregation.AggregateFunctionRule;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.spi.connector.AggregateFunction;
Expand All @@ -41,7 +40,7 @@
* Implements {@code count(x)}.
*/
public class ImplementCount
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private static final Capture<Variable> ARGUMENT = newCapture();

Expand All @@ -64,14 +63,13 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
verify(aggregateFunction.getOutputType() == BIGINT);

return Optional.of(new JdbcExpression(
format("count(%s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
format("count(%s)", context.rewriteExpression(argument).orElseThrow()),
bigintTypeHandle));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
* Implements {@code count(*)}.
*/
public class ImplementCountAll
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private final JdbcTypeHandle bigintTypeHandle;

Expand All @@ -57,7 +57,7 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
verify(aggregateFunction.getOutputType() == BIGINT);
return Optional.of(new JdbcExpression("count(*)", bigintTypeHandle));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
* Implements {@code count(DISTINCT x)}.
*/
public class ImplementCountDistinct
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private static final Capture<Variable> ARGUMENT = newCapture();

Expand All @@ -71,7 +71,7 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
Variable argument = captures.get(ARGUMENT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(argument.getName());
Expand All @@ -84,7 +84,7 @@ public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Cap
}

return Optional.of(new JdbcExpression(
format("count(DISTINCT %s)", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
format("count(DISTINCT %s)", context.rewriteExpression(argument).orElseThrow()),
bigintTypeHandle));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import static java.lang.String.format;

public class ImplementCovariancePop
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private static final Capture<List<Variable>> ARGUMENTS = newCapture();

Expand All @@ -53,17 +53,18 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
List<Variable> arguments = captures.get(ARGUMENTS);
verify(arguments.size() == 2);

JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(arguments.get(0).getName());
JdbcColumnHandle columnHandle2 = (JdbcColumnHandle) context.getAssignment(arguments.get(1).getName());
Variable argument1 = arguments.get(0);
Variable argument2 = arguments.get(1);
JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(argument1.getName());
verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType()));

return Optional.of(new JdbcExpression(
format("covar_pop(%s, %s)", context.getIdentifierQuote().apply(columnHandle1.getColumnName()), context.getIdentifierQuote().apply(columnHandle2.getColumnName())),
format("covar_pop(%s, %s)", context.rewriteExpression(argument1).orElseThrow(), context.rewriteExpression(argument2).orElseThrow()),
columnHandle1.getJdbcTypeHandle()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
import static java.lang.String.format;

public class ImplementCovarianceSamp
implements AggregateFunctionRule<JdbcExpression>
implements AggregateFunctionRule<JdbcExpression, String>
{
private static final Capture<List<Variable>> ARGUMENTS = newCapture();

Expand All @@ -53,17 +53,18 @@ public Pattern<AggregateFunction> getPattern()
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext<String> context)
{
List<Variable> arguments = captures.get(ARGUMENTS);
verify(arguments.size() == 2);

JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(arguments.get(0).getName());
JdbcColumnHandle columnHandle2 = (JdbcColumnHandle) context.getAssignment(arguments.get(1).getName());
Variable argument1 = arguments.get(0);
Variable argument2 = arguments.get(1);
JdbcColumnHandle columnHandle1 = (JdbcColumnHandle) context.getAssignment(argument1.getName());
verify(aggregateFunction.getOutputType().equals(columnHandle1.getColumnType()));

return Optional.of(new JdbcExpression(
format("covar_samp(%s, %s)", context.getIdentifierQuote().apply(columnHandle1.getColumnName()), context.getIdentifierQuote().apply(columnHandle2.getColumnName())),
format("covar_samp(%s, %s)", context.rewriteExpression(argument1).orElseThrow(), context.rewriteExpression(argument2).orElseThrow()),
columnHandle1.getJdbcTypeHandle()));
}
}
Loading

0 comments on commit 5b5b8c6

Please sign in to comment.