Skip to content

Commit

Permalink
[Feature](function) support function array_flatten (apache#47404)
Browse files Browse the repository at this point in the history
### What problem does this PR solve?

### Check List (For Author)

- Test <!-- At least one of them must be included. -->
    - [X] Regression test
    - [ ] Unit Test
    - [ ] Manual test (add detailed scripts or steps below)
    - [ ] No need to test or manual test. Explain why:
- [ ] This is a refactor/code format and no logic has been changed.
        - [ ] Previous test can cover this change.
        - [ ] No code files have been changed.
        - [ ] Other reason <!-- Add your reason?  -->

- Behavior changed:
    - [X] No.
    - [ ] Yes. <!-- Explain the behavior change -->

- Does this need documentation?
    - [ ] No.
    - [X] Yes. <apache/doris-website#1951>

### Check List (For Reviewer who merge this PR)

- [X] Confirm the release note
- [X] Confirm test cases
- [X] Confirm document
- [X] Add branch pick label <!-- Add branch pick label that this PR
should merge into -->
  • Loading branch information
BiteTheDDDDt authored Feb 6, 2025
1 parent 3fb37ce commit 5f46933
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 1 deletion.
92 changes: 92 additions & 0 deletions be/src/vec/functions/array/function_array_flatten.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "common/status.h"
#include "vec/aggregate_functions/aggregate_function.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_nullable.h"
#include "vec/common/assert_cast.h"
#include "vec/core/block.h"
#include "vec/core/column_numbers.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/core/types.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_array.h"
#include "vec/functions/function.h"
#include "vec/functions/simple_function_factory.h"

namespace doris::vectorized {
#include "common/compile_check_begin.h"

/// Scalar function `array_flatten`.
///
/// Takes a single (possibly multi-level) array argument and produces a one-level array
/// whose element type is the nullable form of the innermost non-array type.
class FunctionArrayFlatten : public IFunction {
public:
    static constexpr auto name = "array_flatten";

    static FunctionPtr create() { return std::make_shared<FunctionArrayFlatten>(); }

    /// Get function name.
    String get_name() const override { return name; }

    size_t get_number_of_arguments() const override { return 1; }

    /// Strips array layers from the argument type (dropping the nullability of each nested
    /// type) until a non-array type remains, then returns Array(Nullable(that type)).
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        DataTypePtr element_type = arguments[0];
        while (is_array(element_type)) {
            const auto* array_type = assert_cast<const DataTypeArray*>(element_type.get());
            element_type = remove_nullable(array_type->get_nested_type());
        }
        return std::make_shared<DataTypeArray>(make_nullable(element_type));
    }

    /// Builds the flattened result column for `arguments[0]` and stores it at `result`.
    ///
    /// The outermost offsets are cloned and then, for every additional nesting level,
    /// remapped through that level's offsets so they end up addressing the innermost data
    /// column directly. The innermost data column is cloned into the result.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        uint32_t result, size_t input_rows_count) const override {
        auto input = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();

        // `current_level` walks inward one array layer at a time; `innermost` remembers the
        // deepest array column reached so far (the input itself for a one-level array).
        auto* current_level =
                assert_cast<ColumnArray*>(remove_nullable(input)->assume_mutable().get());
        ColumnArray* innermost = current_level;

        // Work on a private copy of the outermost offsets; it becomes the result's offsets.
        auto flat_offsets_column =
                assert_cast<ColumnArray::ColumnOffsets&>(current_level->get_offsets_column())
                        .clone();
        auto* flat_offsets =
                assert_cast<ColumnArray::ColumnOffsets*>(flat_offsets_column.get())
                        ->get_data()
                        .data();

        for (; current_level->get_data_ptr()->is_column_array(); current_level = innermost) {
            innermost = assert_cast<ColumnArray*>(
                    remove_nullable(current_level->get_data_ptr())->assume_mutable().get());
            auto& inner_offsets = innermost->get_offsets();

            // Translate each row's end offset (expressed in inner arrays) into an end offset
            // expressed in the inner arrays' own elements.
            // NOTE(review): when an offset is 0 this reads index -1 of the inner offsets;
            // presumably the offsets container guarantees a readable zero there - confirm.
            for (size_t row = 0; row < input_rows_count; ++row) {
                flat_offsets[row] = inner_offsets[flat_offsets[row] - 1];
            }
        }

        // The cast asserts that the innermost data column is a ColumnNullable.
        auto flattened_data = assert_cast<const ColumnNullable&>(innermost->get_data()).clone();
        block.replace_by_position(result, ColumnArray::create(std::move(flattened_data),
                                                              std::move(flat_offsets_column)));
        return Status::OK();
    }
};

/// Registers FunctionArrayFlatten with the given function factory.
void register_function_array_flatten(SimpleFunctionFactory& factory) {
factory.register_function<FunctionArrayFlatten>();
}

} // namespace doris::vectorized
3 changes: 2 additions & 1 deletion be/src/vec/functions/array/function_array_register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
#include "vec/functions/simple_function_factory.h"

namespace doris::vectorized {

void register_function_array_flatten(SimpleFunctionFactory&);
void register_function_array_shuffle(SimpleFunctionFactory&);
void register_function_array_exists(SimpleFunctionFactory&);
void register_function_array_element(SimpleFunctionFactory&);
Expand Down Expand Up @@ -59,6 +59,7 @@ void register_function_array_contains_all(SimpleFunctionFactory&);
void register_function_array_match(SimpleFunctionFactory&);

void register_function_array(SimpleFunctionFactory& factory) {
register_function_array_flatten(factory);
register_function_array_shuffle(factory);
register_function_array_exists(factory);
register_function_array_element(factory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirst;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirstIndex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFlatten;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayJoin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLast;
Expand Down Expand Up @@ -516,6 +517,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ArrayFilter.class, "array_filter"),
scalar(ArrayFirst.class, "array_first"),
scalar(ArrayFirstIndex.class, "array_first_index"),
scalar(ArrayFlatten.class, "array_flatten"),
scalar(ArrayIntersect.class, "array_intersect"),
scalar(ArrayJoin.class, "array_join"),
scalar(ArrayLast.class, "array_last"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;

import com.google.common.base.Preconditions;

import java.util.List;

/**
 * ScalarFunction 'array_flatten'.
 *
 * <p>Its signature is computed from the argument: the return type is an array of the
 * innermost non-array item type of the argument, and the single parameter type is the
 * argument's own data type.
 */
public class ArrayFlatten extends ScalarFunction
        implements CustomSignature, PropagateNullable {

    /**
     * Constructor with one argument.
     */
    public ArrayFlatten(Expression arg) {
        super("array_flatten", arg);
    }

    /**
     * Unwraps every ArrayType layer of the argument type to find the innermost item type,
     * then declares the function as {@code Array<item> array_flatten(<argument type>)}.
     */
    @Override
    public FunctionSignature customSignature() {
        final DataType argumentType = getArgument(0).getDataType();
        DataType itemType = argumentType;
        while (itemType instanceof ArrayType) {
            ArrayType arrayLayer = (ArrayType) itemType;
            itemType = arrayLayer.getItemType();
        }
        return FunctionSignature.ret(ArrayType.of(itemType)).args(argumentType);
    }

    /**
     * withChildren: rebuilds the function from exactly one child expression.
     */
    @Override
    public ArrayFlatten withChildren(List<Expression> children) {
        Preconditions.checkArgument(children.size() == 1);
        Expression onlyChild = children.get(0);
        return new ArrayFlatten(onlyChild);
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitArrayFlatten(this, context);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirstIndex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFlatten;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayJoin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLastIndex;
Expand Down Expand Up @@ -677,6 +678,10 @@ default R visitArrayShuffle(ArrayShuffle arrayShuffle, C context) {
return visitScalarFunction(arrayShuffle, context);
}

/** Visits an ArrayFlatten expression; the default forwards to visitScalarFunction. */
default R visitArrayFlatten(ArrayFlatten arrayFlatten, C context) {
return visitScalarFunction(arrayFlatten, context);
}

/**
 * Visits an ArrayMap expression; the default forwards to visitScalarFunction.
 * NOTE(review): the parameter is named arraySort although its type is ArrayMap; this
 * looks like a copy-paste leftover - confirm before renaming.
 */
default R visitArrayMap(ArrayMap arraySort, C context) {
return visitScalarFunction(arraySort, context);
}
Expand Down
30 changes: 30 additions & 0 deletions regression-test/data/nereids_function_p0/scalar_function/Array.out
Original file line number Diff line number Diff line change
Expand Up @@ -16923,3 +16923,33 @@ false false
-- !sql --
false false

-- !sql --
[1, 2, 3, 4, 5]

-- !sql --
[]

-- !sql --
[1]

-- !sql --
[1, 2, 3]

-- !sql --
[1, 2, 3, 4, 5]

-- !sql --
[null, null]

-- !sql --
[1, 2, 3, 4, 5]

-- !sql --
[1, 2, 3, 4, 5]

-- !sql --
[1, 2, 3, 4, 5, 6, 7, 8, 9]

-- !sql --
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !test --
1 [1, 2, 3] [1, 2, 3] [1, 2, 3] [1, 2, 3] ["a", "b", "c"]
2 \N \N [] \N ["b", null]
3 [1, 2, null] [null] [null, 2] [null, null, 3] [null, "aaaab", "ccc"]

Original file line number Diff line number Diff line change
Expand Up @@ -1424,4 +1424,15 @@ suite("nereids_scalar_fn_Array") {
// map_contains_value
qt_sql """ select map_contains_value(map(1,1), 257), map_contains_value(map(1,2), 258);"""

// array_flatten on literals: two-level arrays, empty inner arrays, null inner arrays,
// a single inner array, and three- to six-level nesting.
qt_sql """select array_flatten([[1,2,3],[4,5]]);"""
qt_sql """select array_flatten([[],[]]);"""
qt_sql """select array_flatten([[1],[]]);"""
qt_sql """select array_flatten([[1,2,3],null]);"""
qt_sql """select array_flatten([[1,2,3],null,[4,5]]);"""
qt_sql """select array_flatten([null,null]);"""
qt_sql """select array_flatten([[1,2,3,4,5]]);"""
qt_sql """select array_flatten([[[1,2,3,4,5]]]);;"""
qt_sql """select array_flatten([ [[1,2,3,4,5]],[[6,7],[8,9]] ]);"""
qt_sql """select array_flatten([[[[[[1,2,3,4,5],[6,7],[8,9],[10,11],[12]]]]]]);"""

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// Regression suite for array_flatten over table columns with one-, two- and three-level
// array types.
suite("array_flatten") {
// Start from a clean table so the suite can be re-run.
sql """DROP TABLE IF EXISTS t_array_flatten"""
// One column per nesting depth / element type; aa3 is the only NOT NULL array column.
sql """
CREATE TABLE IF NOT EXISTS t_array_flatten (
`k1` int(11) NULL COMMENT "",
`a1` array<tinyint(4)> NULL COMMENT "",
`aaa1` array<array<array<tinyint(4)>>> NULL COMMENT "",
`aa3` array<array<int(11)>> NOT NULL COMMENT "",
`aa5` array<array<largeint(40)>> NULL COMMENT "",
`aa14` array<array<string>> NULL COMMENT ""
) ENGINE=OLAP
DUPLICATE KEY(`k1`)
DISTRIBUTED BY HASH(`k1`) BUCKETS 1
PROPERTIES (
"replication_allocation" = "tag.location.default: 1",
"storage_format" = "V2"
)
"""
// Row 1: fully populated arrays.
// Row 2: NULL for the nullable array columns, an empty array for aa3, and a NULL inner array in aa14.
// Row 3: NULL elements and empty inner arrays mixed with values.
sql """ INSERT INTO t_array_flatten VALUES(1, [1, 2, 3],[[[1]],[[2],[3]]],[[1,2],[3]],[[1,2],[3]],[['a'],['b','c']]) """
sql """ INSERT INTO t_array_flatten VALUES(2, null,null,[],null,[null,['b',null]]) """
sql """ INSERT INTO t_array_flatten VALUES(3, [1, 2, null],[[[]],[[null],[]]],[[null,2],[]],[[null,null],[3]],[[null],['aaaab','ccc']]) """



// Flatten every array column; ordering by k1 keeps the recorded output stable.
qt_test """
select k1, array_flatten(a1), array_flatten(aaa1), array_flatten(aa3), array_flatten(aa5), array_flatten(aa14) from t_array_flatten order by k1;
"""
}

0 comments on commit 5f46933

Please sign in to comment.