Skip to content

Commit

Permalink
[FLINK-34587][table] Introduce MODE aggregate function
Browse files Browse the repository at this point in the history
  • Loading branch information
snuyanzin committed Mar 6, 2024
1 parent 9b13755 commit 50709ba
Show file tree
Hide file tree
Showing 11 changed files with 320 additions and 0 deletions.
4 changes: 4 additions & 0 deletions docs/data/sql_functions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,10 @@ aggregate:
Divides the rows for each window partition into `n` buckets ranging from 1 to at most `n`.
If the number of rows in the window partition doesn't divide evenly into the number of buckets, then the remainder values are distributed one per bucket, starting with the first bucket.
For example, with 6 rows and 4 buckets, the bucket values would be as follows: 1 1 2 2 3 4
- sql: MODE(n)
table: FIELD.modeAgg
description: Returns the most frequent value in a group of values.
If several values are tied for the highest frequency, one of them is returned; which one is not specified.
NULL values are ignored. If there is no non-null value, the function returns NULL.
- sql: ARRAY_AGG([ ALL | DISTINCT ] expression [ RESPECT NULLS | IGNORE NULLS ])
table: FIELD.arrayAgg
description: |
Expand Down
4 changes: 4 additions & 0 deletions docs/data/sql_functions_zh.yml
Original file line number Diff line number Diff line change
Expand Up @@ -1187,6 +1187,10 @@ aggregate:
将窗口分区中的所有数据按照顺序划分为 n 个分组,返回分配给各行数据的分组编号(从 1 开始,最大为 n)。
如果不能均匀划分为 n 个分组,则剩余值从第 1 个分组开始,为每一分组分配一个。
比如某个窗口分区有 6 行数据,划分为 4 个分组,则各行的分组编号为:1,1,2,2,3,4。
- sql: MODE(n)
table: FIELD.modeAgg
description: Returns the most frequent value in a group of values.
If several values are tied for the highest frequency, one of them is returned; which one is not specified.
NULL values are ignored. If there is no non-null value, the function returns NULL.
- sql: ARRAY_AGG([ ALL | DISTINCT ] expression [ RESPECT NULLS | IGNORE NULLS ])
table: FIELD.arrayAgg
description: |
Expand Down
9 changes: 9 additions & 0 deletions flink-python/pyflink/table/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,15 @@ def collect(self) -> 'Expression':
def array_agg(self) -> 'Expression':
return _unary_op("arrayAgg")(self)

@property
def mode_agg(self) -> 'Expression':
    """
    Returns the most frequent value in a group of values.

    If several values are tied for the highest frequency, one of them is returned.
    NULL values are ignored. If there is no non-null value, the function returns NULL.
    """
    # Build the unary operator for "modeAgg" first, then apply it to this expression.
    mode_op = _unary_op("modeAgg")
    return mode_op(self)

def alias(self, name: str, *extra_names: str) -> 'Expression[T]':
"""
Specifies a name for an expression i.e. a field.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MIN;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MINUS;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MOD;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.MODE;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.NOT;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.NOT_BETWEEN;
import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.NOT_EQUALS;
Expand Down Expand Up @@ -535,6 +536,15 @@ public OutType arrayAgg() {
return toApiSpecificExpression(unresolvedCall(ARRAY_AGG, toExpr()));
}

/**
* Returns the most frequent value in a group of values.
*
* <p>If several values are tied for the highest frequency, one of them will be returned. NULL
* values are ignored. If there is no non-null value, the function returns NULL.
*
* @return an API-specific expression wrapping an unresolved call of {@code MODE} on this
*     expression
*/
public OutType modeAgg() {
return toApiSpecificExpression(unresolvedCall(MODE, toExpr()));
}

/**
* Returns a new value being cast to {@code toType}. A cast error throws an exception and fails
* the job. When performing a cast operation that may fail, like {@link DataTypes#STRING()} to
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,14 @@ ANY, and(logical(LogicalTypeRoot.BOOLEAN), LITERAL)
TypeStrategies.aggArg0(LogicalTypeMerging::findAvgAggType, true))
.build();

/**
* Definition of the {@code mode} aggregate function.
*
* <p>Accepts exactly one argument of any type. The output type is derived from the first
* argument via {@code TypeStrategies.aggArg0} with an identity mapping, i.e. the argument type
* itself is used. NOTE(review): the trailing {@code true} flag is presumably the nullability
* switch of {@code aggArg0} - confirm against its Javadoc.
*/
public static final BuiltInFunctionDefinition MODE =
BuiltInFunctionDefinition.newBuilder()
.name("mode")
.kind(AGGREGATE)
.inputTypeStrategy(sequence(ANY))
.outputTypeStrategy(TypeStrategies.aggArg0(t -> t, true))
.build();

public static final BuiltInFunctionDefinition VAR_POP =
BuiltInFunctionDefinition.newBuilder()
.name("varPop")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ public class SqlAggFunctionVisitor extends ExpressionDefaultVisitor<SqlAggFuncti
BuiltInFunctionDefinitions.COLLECT, FlinkSqlOperatorTable.COLLECT);
AGG_DEF_SQL_OPERATOR_MAPPING.put(
BuiltInFunctionDefinitions.ARRAY_AGG, FlinkSqlOperatorTable.ARRAY_AGG);
AGG_DEF_SQL_OPERATOR_MAPPING.put(
BuiltInFunctionDefinitions.MODE, FlinkSqlOperatorTable.MODE_AGG);
AGG_DEF_SQL_OPERATOR_MAPPING.put(
BuiltInFunctionDefinitions.JSON_OBJECTAGG_NULL_ON_NULL,
FlinkSqlOperatorTable.JSON_OBJECTAGG_NULL_ON_NULL);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,8 @@ public List<SqlGroupedWindowFunction> getAuxiliaryFunctions() {
.withSyntax(SqlSyntax.FUNCTION)
.withAllowsNullTreatment(true);

// MODE aggregate: no custom operator is defined here, SqlStdOperatorTable.MODE is re-exported as is.
public static final SqlAggFunction MODE_AGG = SqlStdOperatorTable.MODE;

// ARRAY OPERATORS
public static final SqlOperator ARRAY_VALUE_CONSTRUCTOR = new SqlArrayConstructor();
public static final SqlOperator ELEMENT = SqlStdOperatorTable.ELEMENT;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ class AggFunctionFactory(
case a: SqlAggFunction if a.getKind == SqlKind.ARRAY_AGG =>
createArrayAggFunction(argTypes, call.ignoreNulls)

case a: SqlAggFunction if a.getKind == SqlKind.MODE =>
createModeAggFunction(argTypes)

case fn: SqlAggFunction if fn.getKind == SqlKind.JSON_OBJECTAGG =>
val onNull = fn.asInstanceOf[SqlJsonObjectAggAggFunction].getNullClause
new JsonObjectAggFunction(argTypes, onNull == SqlJsonConstructorNullClause.ABSENT_ON_NULL)
Expand Down Expand Up @@ -629,4 +632,8 @@ class AggFunctionFactory(
ignoreNulls: Boolean): UserDefinedFunction = {
new ArrayAggFunction(types(0), ignoreNulls)
}

/** Builds the MODE aggregate function for the type of the first argument. */
private def createModeAggFunction(types: Array[LogicalType]): UserDefinedFunction = {
  // Only the first argument type is inspected.
  val valueType = types(0)
  new ModeAggFunction(valueType)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.planner.functions.aggfunctions;

import org.apache.flink.table.data.StringData;
import org.apache.flink.table.runtime.functions.aggregate.ModeAggFunction;
import org.apache.flink.table.types.logical.VarCharType;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.List;

import static org.apache.flink.table.data.StringData.fromString;

/**
 * Test fixture for the built-in MODE aggregate function.
 *
 * <p>Supplies input value sets, the expected mode of each set, and the reflective handles of the
 * {@code accumulate} and {@code retract} methods to {@link AggFunctionTestBase}.
 */
final class ModeAggFunctionTest
        extends AggFunctionTestBase<StringData, StringData, ModeAggFunction.ModeAcc<StringData>> {

    @Override
    protected List<List<StringData>> getInputValueSets() {
        return Arrays.asList(
                values("1", "1", "3"),
                values("4", null, "4"),
                values(null, null),
                values(null, "2", "1", "2", "2", "1"));
    }

    @Override
    protected List<StringData> getExpectedResults() {
        // One expected mode per input value set above, in the same order.
        return values("1", "4", null, "2");
    }

    @Override
    protected ModeAggFunction<StringData> getAggregator() {
        return new ModeAggFunction<>(new VarCharType());
    }

    @Override
    protected Class<?> getAccClass() {
        return ModeAggFunction.ModeAcc.class;
    }

    @Override
    protected Method getAccumulateFunc() throws NoSuchMethodException {
        final Class<?> aggregatorClass = getAggregator().getClass();
        return aggregatorClass.getMethod("accumulate", getAccClass(), Object.class);
    }

    @Override
    protected Method getRetractFunc() throws NoSuchMethodException {
        final Class<?> aggregatorClass = getAggregator().getClass();
        return aggregatorClass.getMethod("retract", getAccClass(), Object.class);
    }

    /** Converts plain strings to {@link StringData}, keeping {@code null} entries as null. */
    private static List<StringData> values(String... raw) {
        final StringData[] converted = new StringData[raw.length];
        for (int i = 0; i < raw.length; i++) {
            converted[i] = raw[i] == null ? null : fromString(raw[i]);
        }
        return Arrays.asList(converted);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2807,6 +2807,31 @@ class OverAggregateITCase extends BatchTestBase {
)
}

/** MODE as an over-aggregate: the third column is the mode of `d` within each `h` partition. */
@Test
def testMode(): Unit = {
  val query =
    "SELECT h, d, MODE(d) over(partition by h)" +
      " FROM Table5"

  val expected = Seq(
    // partition h = 1
    row(1, 1, 4),
    row(1, 2, 4),
    row(1, 4, 4),
    row(1, 4, 4),
    row(1, 5, 4),
    // partition h = 2
    row(2, 2, 3),
    row(2, 3, 3),
    row(2, 3, 3),
    row(2, 4, 3),
    row(2, 4, 3),
    row(2, 5, 3),
    row(2, 5, 3),
    // partition h = 3
    row(3, 3, 5),
    row(3, 5, 5),
    row(3, 5, 5)
  )

  checkResult(query, expected)
}

@Test
def testPercentRank(): Unit = {
checkResult(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.runtime.functions.aggregate;

import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.apache.flink.table.types.utils.DataTypeUtils.toInternalDataType;

/**
 * Built-in MODE aggregate function with retraction support.
 *
 * <p>The accumulator keeps one counter per distinct non-null value, together with the currently
 * most frequent value and its count. NULL inputs are ignored; if no non-null value has been
 * accumulated the result is NULL. When several values share the highest count, the reported value
 * depends on the order in which the values were processed.
 *
 * @param <T> internal type of the aggregated values
 */
@Internal
public class ModeAggFunction<T> extends BuiltInAggregateFunction<T, ModeAggFunction.ModeAcc<T>> {

    /** Internal data type of the aggregated value; used for argument, output and accumulator. */
    private final transient DataType valueDataType;

    /**
     * Creates the function for the given value type.
     *
     * @param valueType logical type of the aggregated argument
     */
    public ModeAggFunction(LogicalType valueType) {
        this.valueDataType = toInternalDataType(valueType);
    }

    @Override
    public List<DataType> getArgumentDataTypes() {
        return Collections.singletonList(valueDataType);
    }

    @Override
    public DataType getOutputDataType() {
        return valueDataType;
    }

    /** Returns the current mode, or {@code null} if the accumulator holds no value. */
    @Override
    public T getValue(ModeAcc<T> accumulator) {
        return accumulator.curMode;
    }

    @Override
    public ModeAcc<T> createAccumulator() {
        final ModeAcc<T> acc = new ModeAcc<>();
        acc.buffer = new MapView<>();
        acc.curCnt = 0;
        acc.curMode = null;
        return acc;
    }

    @Override
    public DataType getAccumulatorDataType() {
        return DataTypes.STRUCTURED(
                ModeAcc.class,
                DataTypes.FIELD("curCnt", DataTypes.BIGINT()),
                DataTypes.FIELD("curMode", valueDataType.nullable()),
                DataTypes.FIELD(
                        "buffer",
                        MapView.newMapViewDataType(valueDataType.notNull(), DataTypes.BIGINT())));
    }

    /** Accumulator for MODE. */
    public static class ModeAcc<T> {
        /** Number of occurrences of {@link #curMode}; 0 when there is no mode. */
        public long curCnt;

        /** Currently most frequent value, or {@code null} when no value is held. */
        public T curMode;

        /** Occurrence count per distinct non-null value. */
        public MapView<T, Long> buffer;

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || getClass() != o.getClass()) {
                return false;
            }
            ModeAcc<?> that = (ModeAcc<?>) o;
            return curCnt == that.curCnt
                    && Objects.equals(curMode, that.curMode)
                    && Objects.equals(buffer, that.buffer);
        }

        @Override
        public int hashCode() {
            return Objects.hash(curCnt, curMode, buffer);
        }
    }

    /**
     * Adds one occurrence of {@code value} to the accumulator.
     *
     * @param acc accumulator to update
     * @param value value to add; {@code null} is ignored
     * @throws Exception if accessing the map view fails
     */
    public void accumulate(ModeAcc<T> acc, T value) throws Exception {
        if (value == null) {
            return;
        }
        Long cnt = acc.buffer.get(value);
        cnt = cnt == null ? 1 : cnt + 1;
        acc.buffer.put(value, cnt);
        if (Objects.equals(value, acc.curMode)) {
            acc.curCnt = cnt;
        } else if (cnt > acc.curCnt) {
            // Strictly greater: on a tie the value that reached the count first stays the mode.
            acc.curMode = value;
            acc.curCnt = cnt;
        }
    }

    /**
     * Removes one occurrence of {@code value} from the accumulator.
     *
     * <p>The tracked mode is only re-evaluated when the retracted value is the current mode.
     * Retracting any other value can only lower that value's count, so the current mode and its
     * count stay valid and must not be touched.
     *
     * @param acc accumulator to update
     * @param value value to retract; {@code null} is ignored
     * @throws Exception if accessing the map view fails
     */
    public void retract(ModeAcc<T> acc, T value) throws Exception {
        if (value == null) {
            return;
        }
        final Long previous = acc.buffer.get(value);
        final long remaining = previous == null ? 0L : previous - 1;
        if (remaining <= 0) {
            acc.buffer.remove(value);
        } else {
            acc.buffer.put(value, remaining);
        }

        if (!Objects.equals(value, acc.curMode)) {
            return;
        }

        // The mode lost one occurrence: start from its reduced count and look for a value that
        // is now strictly more frequent.
        acc.curCnt = remaining;
        if (remaining <= 0) {
            acc.curMode = null;
        }
        for (Map.Entry<T, Long> entry : acc.buffer.entries()) {
            if (entry.getValue() > acc.curCnt) {
                acc.curMode = entry.getKey();
                acc.curCnt = entry.getValue();
            }
        }
    }

    /** Clears all counters and the tracked mode. */
    public void resetAccumulator(ModeAcc<T> acc) {
        acc.buffer.clear();
        acc.curCnt = 0;
        acc.curMode = null;
    }

    /**
     * Adds the per-value counts of all given accumulators to {@code acc} and updates its mode.
     *
     * @param acc accumulator receiving the merged counts
     * @param its accumulators to merge in; empty ones contribute nothing
     * @throws Exception if accessing a map view fails
     */
    public void merge(ModeAggFunction.ModeAcc<T> acc, Iterable<ModeAggFunction.ModeAcc<T>> its)
            throws Exception {
        for (ModeAggFunction.ModeAcc<T> otherAcc : its) {
            for (Map.Entry<T, Long> entry : otherAcc.buffer.entries()) {
                final T key = entry.getKey();
                final Long existing = acc.buffer.get(key);
                final long newValue =
                        existing == null ? entry.getValue() : existing + entry.getValue();
                acc.buffer.put(key, newValue);
                if (newValue > acc.curCnt) {
                    acc.curMode = key;
                    acc.curCnt = newValue;
                }
            }
        }
    }
}

0 comments on commit 50709ba

Please sign in to comment.