Skip to content

Commit

Permalink
Implement aggregation pushdown for ClickHouse
Browse files Browse the repository at this point in the history
  • Loading branch information
wgzhao authored and hashhar committed May 12, 2021
1 parent 776838e commit cdb56d4
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 2 deletions.
18 changes: 17 additions & 1 deletion docs/src/main/sphinx/connector/clickhouse.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,23 @@ Property Name Default Value Description

=========================== ================ ==============================================================================================================

Currently the connector only supports ``Log`` and ``MergeTree`` table engines in create table.
Currently the connector only supports ``Log`` and ``MergeTree`` table engines
in create table statement. ``ReplicatedMergeTree`` engine is not yet supported.

Pushdown
--------

The connector supports pushdown for a number of operations:

* :ref:`limit-pushdown`

:ref:`Aggregate pushdown <aggregation-pushdown>` for the following functions:

* :func:`avg`
* :func:`count`
* :func:`max`
* :func:`min`
* :func:`sum`

Limitations
-----------
Expand Down
4 changes: 4 additions & 0 deletions plugin/trino-clickhouse/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
<artifactId>trino-base-jdbc</artifactId>
</dependency>

<dependency>
<groupId>io.trino</groupId>
<artifactId>trino-matching</artifactId>
</dependency>
<dependency>
<groupId>io.airlift</groupId>
<artifactId>configuration</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,28 @@
package io.trino.plugin.clickhouse;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.plugin.jdbc.BaseJdbcClient;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ColumnMapping;
import io.trino.plugin.jdbc.ConnectionFactory;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTableHandle;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.RemoteTableName;
import io.trino.plugin.jdbc.WriteMapping;
import io.trino.plugin.jdbc.expression.AggregateFunctionRewriter;
import io.trino.plugin.jdbc.expression.AggregateFunctionRule;
import io.trino.plugin.jdbc.expression.ImplementAvgFloatingPoint;
import io.trino.plugin.jdbc.expression.ImplementCount;
import io.trino.plugin.jdbc.expression.ImplementCountAll;
import io.trino.plugin.jdbc.expression.ImplementMinMax;
import io.trino.plugin.jdbc.expression.ImplementSum;
import io.trino.plugin.jdbc.mapping.IdentifierMapping;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorTableMetadata;
Expand Down Expand Up @@ -101,6 +112,9 @@ public class ClickHouseClient
extends BaseJdbcClient
{
private final boolean mapStringAsVarchar;
private final AggregateFunctionRewriter aggregateFunctionRewriter;

static final int CLICKHOUSE_MAX_DECIMAL_PRECISION = 76;

@Inject
public ClickHouseClient(
Expand All @@ -114,6 +128,30 @@ public ClickHouseClient(

// TODO (https://github.com/trinodb/trino/issues/7102) define session property
this.mapStringAsVarchar = clickHouseConfig.isMapStringAsVarchar();
JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
this.aggregateFunctionRewriter = new AggregateFunctionRewriter(
this::quoted,
ImmutableSet.<AggregateFunctionRule>builder()
.add(new ImplementCountAll(bigintTypeHandle))
.add(new ImplementCount(bigintTypeHandle))
.add(new ImplementMinMax())
.add(new ImplementSum(ClickHouseClient::toTypeHandle))
.add(new ImplementAvgFloatingPoint())
.add(new ImplementAvgDecimal())
.add(new ImplementAvgBigint())
.build());
}

@Override
public Optional<JdbcExpression> implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map<String, ColumnHandle> assignments)
{
// TODO support complex ConnectorExpressions
return aggregateFunctionRewriter.rewrite(session, aggregate, assignments);
}

private static Optional<JdbcTypeHandle> toTypeHandle(DecimalType decimalType)
{
return Optional.of(new JdbcTypeHandle(Types.DECIMAL, Optional.of("Decimal"), Optional.of(decimalType.getPrecision()), Optional.of(decimalType.getScale()), Optional.empty(), Optional.empty()));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.clickhouse;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.expression.AggregateFunctionRule;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;

import java.sql.Types;
import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;

public class ImplementAvgBigint
implements AggregateFunctionRule
{
private static final Capture<Variable> INPUT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(variable().with(expressionType().equalTo(BIGINT)).capturedAs(INPUT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
verify(aggregateFunction.getOutputType() == DOUBLE);

return Optional.of(new JdbcExpression(
format("avg((%s * 1.0))", context.getIdentifierQuote().apply(columnHandle.getColumnName())),
new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.clickhouse;

import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.expression.AggregateFunctionRule;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.DecimalType;

import java.util.Optional;

import static com.google.common.base.Verify.verify;
import static io.trino.matching.Capture.newCapture;
import static io.trino.plugin.clickhouse.ClickHouseClient.CLICKHOUSE_MAX_DECIMAL_PRECISION;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.basicAggregation;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.expressionType;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.functionName;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.singleInput;
import static io.trino.plugin.jdbc.expression.AggregateFunctionPatterns.variable;
import static java.lang.String.format;

/**
* Implements {@code avg(decimal(p, s)}
*/
public class ImplementAvgDecimal
implements AggregateFunctionRule
{
private static final Capture<Variable> INPUT = newCapture();

@Override
public Pattern<AggregateFunction> getPattern()
{
return basicAggregation()
.with(functionName().equalTo("avg"))
.with(singleInput().matching(
variable()
.with(expressionType().matching(DecimalType.class::isInstance))
.capturedAs(INPUT)));
}

@Override
public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context)
{
Variable input = captures.get(INPUT);
JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName());
DecimalType type = (DecimalType) columnHandle.getColumnType();
verify(aggregateFunction.getOutputType().equals(type));

// When decimal type has maximum precision we can get result that does not match Trino avg semantics.
if (type.getPrecision() == CLICKHOUSE_MAX_DECIMAL_PRECISION) {
return Optional.of(new JdbcExpression(
format("avg(CAST(%s AS decimal(%s, %s)))", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision(), type.getScale()),
columnHandle.getJdbcTypeHandle()));
}

// ClickHouse avg function rounds down resulting decimal.
// To match Trino avg semantics, we extend scale by 1 and round result to target scale.
return Optional.of(new JdbcExpression(
format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", context.getIdentifierQuote().apply(columnHandle.getColumnName()), type.getPrecision() + 1, type.getScale() + 1, type.getScale()),
columnHandle.getJdbcTypeHandle()));
}
}
Loading

0 comments on commit cdb56d4

Please sign in to comment.