Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
icexelloss committed Jun 26, 2023
1 parent 007221a commit 378eb48
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
3 changes: 3 additions & 0 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2767,6 +2767,9 @@ def register_aggregate_function(func, function_name, function_doc, in_types, out
This is often used with ordered or segmented aggregation, where groups
can be emitted before accumulating all of the input data.
Note that currently the size of any input column cannot exceed the 2 GB limit
(all groups combined).
Parameters
----------
func : callable
Expand Down
8 changes: 4 additions & 4 deletions python/pyarrow/src/arrow/python/udf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.


#include "arrow/python/udf.h"
#include "arrow/array/builder_base.h"
#include "arrow/buffer_builder.h"
#include "arrow/compute/api_aggregate.h"
Expand All @@ -24,7 +24,6 @@
#include "arrow/compute/kernel.h"
#include "arrow/compute/row/grouper.h"
#include "arrow/python/common.h"
#include "arrow/python/udf.h"
#include "arrow/table.h"
#include "arrow/util/checked_cast.h"
#include "arrow/util/logging.h"
Expand Down Expand Up @@ -310,7 +309,7 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
const ArraySpan& groups_array_data = batch[batch.num_values() - 1].array;
DCHECK_EQ(groups_array_data.offset, 0);
int64_t batch_num_values = groups_array_data.length;
const auto* batch_groups = groups_array_data.GetValues<uint32_t>(1, 0);
const auto* batch_groups = groups_array_data.GetValues<uint32_t>(1);
RETURN_NOT_OK(groups.Append(batch_groups, batch_num_values));
values.push_back(std::move(rb));
num_values += batch_num_values;
Expand Down Expand Up @@ -352,7 +351,8 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
UdfContext udf_context{ctx->memory_pool(), table->num_rows()};

if (rb->num_rows() == 0) {
return Status::Invalid("Finalized is called with empty inputs");
*out = Datum();
return Status::OK();
}

ARROW_ASSIGN_OR_RAISE(RecordBatchVector rbs, ApplyGroupings(*groupings, rb));
Expand Down
7 changes: 5 additions & 2 deletions python/pyarrow/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,8 +747,11 @@ def test_hash_agg_empty(unary_agg_func_fixture):
arr2 = pa.array([], pa.int32())
table = pa.table([arr2, arr1], names=["id", "value"])

with pytest.raises(pa.ArrowInvalid, match='empty inputs'):
table.group_by("id").aggregate([("value", "mean_udf")])
result = table.group_by("id").aggregate([("value", "mean_udf")])
expected = pa.table([pa.array([], pa.int32()), pa.array(
[], pa.float64())], names=['id', 'value_mean_udf'])

assert result == expected


def test_hash_agg_wrong_output_dtype(wrong_output_dtype_agg_func_fixture):
Expand Down

0 comments on commit 378eb48

Please sign in to comment.