extract sum and count to GroupedAggregator interface
bkietz committed Feb 17, 2021
1 parent 0e9f430 commit 971d27b
Showing 3 changed files with 220 additions and 121 deletions.
11 changes: 10 additions & 1 deletion cpp/src/arrow/compute/api_aggregate.h
@@ -130,10 +130,19 @@ struct ARROW_EXPORT QuantileOptions : public FunctionOptions {
 // TODO(michalursa) add docstring
 struct ARROW_EXPORT GroupByOptions : public FunctionOptions {
   struct Aggregate {
-    std::string name;
+    /// the name of the aggregation function
+    std::string function;
+
+    /// options for the aggregation function
     const FunctionOptions* options;
+
+    /// the name of the resulting column in output
+    std::string name;
   };
   std::vector<Aggregate> aggregates;
+
+  /// the names of key columns
+  std::vector<std::string> key_names;
 };
 
 /// @}
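To make the new Aggregate layout concrete, here is a caller-side sketch (not part of this commit; the column and output names are hypothetical, and the CountOptions constructor is assumed from the API of the same era):

arrow::compute::CountOptions count_options(
    arrow::compute::CountOptions::COUNT_NON_NULL);  // assumed constructor
arrow::compute::GroupByOptions options;
options.aggregates = {
    {/*function=*/"sum", /*options=*/nullptr, /*name=*/"total"},
    {/*function=*/"count", /*options=*/&count_options, /*name=*/"n"}};
options.key_names = {"id"};  // group by the hypothetical "id" column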
191 changes: 152 additions & 39 deletions cpp/src/arrow/compute/kernels/aggregate_basic.cc
@@ -15,13 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
+#include <map>
 #include "arrow/compute/api_aggregate.h"
 #include "arrow/compute/kernels/aggregate_basic_internal.h"
 #include "arrow/compute/kernels/aggregate_internal.h"
 #include "arrow/compute/kernels/common.h"
 #include "arrow/util/cpu_info.h"
 #include "arrow/util/make_unique.h"
-#include <map>
 
 namespace arrow {
 namespace compute {
@@ -43,7 +43,8 @@ void AggregateFinalize(KernelContext* ctx, Datum* out) {
 }  // namespace
 
 void AddAggKernel(std::shared_ptr<KernelSignature> sig, KernelInit init,
-                  ScalarAggregateFunction* func, SimdLevel::type simd_level, bool nomerge) {
+                  ScalarAggregateFunction* func, SimdLevel::type simd_level,
+                  bool nomerge) {
   ScalarAggregateKernel kernel(std::move(sig), init, AggregateConsume, AggregateMerge,
                                AggregateFinalize, nomerge);
   // Set the simd level
@@ -92,6 +93,104 @@ struct CountImpl : public ScalarAggregator {
   int64_t nulls = 0;
 };
 
+struct GroupedAggregator {
+  // GroupedAggregator subclasses are expected to be constructible from
+  // const FunctionOptions*. Will probably need an Init method as well
+  virtual ~GroupedAggregator() = default;
+
+  virtual void Consume(KernelContext*, const ExecBatch& batch,
+                       const uint32_t* group_ids) = 0;
+
+  virtual void Finalize(KernelContext* ctx, Datum* out) = 0;
+
+  static Result<std::unique_ptr<GroupedAggregator>> Make(std::string function,
+                                                         const FunctionOptions* options);
+};
+
+struct GroupedCountImpl : public GroupedAggregator {
+  explicit GroupedCountImpl(const FunctionOptions* options)
+      : options(checked_cast<const CountOptions&>(*options)) {}
+
+  void Consume(KernelContext* ctx, const ExecBatch& batch,
+               const uint32_t* group_ids) override {
+    if (batch.length == 0) return;
+
+    // maybe a batch of group_ids should include the min/max group id
+    auto max_group = *std::max_element(group_ids, group_ids + batch.length);
+    if (max_group >= counts.size()) {
+      counts.resize(max_group + 1, 0);
+    }
+
+    if (options.count_mode == CountOptions::COUNT_NON_NULL) {
+      auto input = batch[0].make_array();
+
+      for (int64_t i = 0; i < input->length(); ++i) {
+        if (input->IsNull(i)) continue;
+        counts[group_ids[i]]++;
+      }
+    } else {
+      for (int64_t i = 0; i < batch.length; ++i) {
+        counts[group_ids[i]]++;
+      }
+    }
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    KERNEL_ASSIGN_OR_RAISE(auto counts_buf, ctx,
+                           ctx->Allocate(sizeof(int64_t) * counts.size()));
+    std::copy(counts.begin(), counts.end(),
+              reinterpret_cast<int64_t*>(counts_buf->mutable_data()));
+    *out = std::make_shared<Int64Array>(counts.size(), std::move(counts_buf));
+  }
+
+  CountOptions options;
+  std::vector<int64_t> counts;
+};
+
+struct GroupedSumImpl : public GroupedAggregator {
+  explicit GroupedSumImpl(const FunctionOptions*) {}
+
+  void Consume(KernelContext* ctx, const ExecBatch& batch,
+               const uint32_t* group_ids) override {
+    if (batch.length == 0) return;
+
+    // maybe a batch of group_ids should include the min/max group id
+    auto max_group = *std::max_element(group_ids, group_ids + batch.length);
+    if (max_group >= sums.size()) {
+      sums.resize(max_group + 1, 0.0);
+    }
+
+    DCHECK_EQ(batch[0].type()->id(), Type::DOUBLE);
+    auto input = batch[0].array_as<DoubleArray>();
+
+    for (int64_t i = 0; i < input->length(); ++i) {
+      if (input->IsNull(i)) continue;
+      sums[group_ids[i]] += input->Value(i);
+    }
+  }
+
+  void Finalize(KernelContext* ctx, Datum* out) override {
+    KERNEL_ASSIGN_OR_RAISE(auto sums_buf, ctx,
+                           ctx->Allocate(sizeof(double) * sums.size()));
+    std::copy(sums.begin(), sums.end(),
+              reinterpret_cast<double*>(sums_buf->mutable_data()));
+    *out = std::make_shared<DoubleArray>(sums.size(), std::move(sums_buf));
+  }
+
+  std::vector<double> sums;
+};
+
+Result<std::unique_ptr<GroupedAggregator>> GroupedAggregator::Make(
+    std::string function, const FunctionOptions* options) {
+  if (function == "count") {
+    return ::arrow::internal::make_unique<GroupedCountImpl>(options);
+  }
+  if (function == "sum") {
+    return ::arrow::internal::make_unique<GroupedSumImpl>(options);
+  }
+  return Status::NotImplemented("Grouped aggregate ", function);
+}
+
 std::unique_ptr<KernelState> CountInit(KernelContext*, const KernelInitArgs& args) {
   return ::arrow::internal::make_unique<CountImpl>(
       static_cast<const CountOptions&>(*args.options));
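Before the next hunk, a sketch of the new interface and factory working together, mirroring how GroupByInit uses them further down (ctx, batch, and group_ids are assumed to be in scope):

std::unique_ptr<GroupedAggregator> agg;
ctx->SetStatus(GroupedAggregator::Make("sum", /*options=*/nullptr).Value(&agg));
if (!ctx->HasError()) {
  agg->Consume(ctx, batch, group_ids.data());  // batch[0] must be DOUBLE for "sum"
  Datum out;
  agg->Finalize(ctx, &out);  // yields a DoubleArray of per-group sums
}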
Expand Down Expand Up @@ -232,41 +331,41 @@ std::unique_ptr<KernelState> AllInit(KernelContext*, const KernelInitArgs& args)

struct GroupByImpl : public ScalarAggregator {
void Consume(KernelContext* ctx, const ExecBatch& batch) override {
std::vector<std::shared_ptr<ArrayData>> aggregands, keys;
ArrayDataVector aggregands, keys;

size_t i;
for (i = 0; i < aggregates.size(); ++i) {
for (i = 0; i < aggregators.size(); ++i) {
aggregands.push_back(batch[i].array());
}
while (i < static_cast<size_t>(batch.num_values())) {
keys.push_back(batch[i++].array());
}

auto key64 = batch[aggregates.size()].array_as<Int64Array>();
auto key64 = batch[aggregators.size()].array_as<Int64Array>();
if (key64->null_count() != 0) {
ctx->SetStatus(Status::NotImplemented("nulls in key column"));
return;
}

const int64_t* key64_raw = key64->raw_values();

auto valuesDouble = batch[0].array_as<DoubleArray>();
const double* valuesDouble_raw = valuesDouble->raw_values();

std::vector<uint32_t> group_ids(batch.length);
for (int64_t i = 0; i < batch.length; ++i) {
uint64_t key = key64_raw[i];
double value = valuesDouble_raw[i];
uint32_t groupid;
auto iter = map_.find(key);
if (iter == map_.end()) {
groupid = static_cast<uint32_t>(keys_.size());
group_ids[i] = static_cast<uint32_t>(keys_.size());
keys_.push_back(key);
sums_.push_back(0.0);
map_.insert(std::make_pair(key, groupid));
map_.insert(std::make_pair(key, group_ids[i]));
} else {
groupid = iter->second;
group_ids[i] = iter->second;
}
sums_[groupid] += value;
}

for (size_t i = 0; i < aggregators.size(); ++i) {
ExecBatch aggregand_batch{{aggregands[i]}, batch.length};
aggregators[i]->Consume(ctx, aggregand_batch, group_ids.data());
if (ctx->HasError()) return;
}
}

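A small worked example of the grouping pass above, with hypothetical input: for a key column {3, 7, 3}, map_ assigns group_ids = {0, 1, 0} and keys_ ends up as {3, 7}; each GroupedAggregator then consumes its aggregand column of length 3 alongside those ids.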
@@ -275,51 +374,63 @@ struct GroupByImpl : public ScalarAggregator {
   }
 
   void Finalize(KernelContext* ctx, Datum* out) override {
-    auto pool = ctx->memory_pool();
-    size_t length = keys_.size();
-    auto out_buffer = std::move(AllocateBuffer(sizeof(double) * length, pool)).ValueUnsafe();
-    auto out_values = out_buffer->mutable_data();
-    for (size_t i = 0; i < length; ++i) {
-      (reinterpret_cast<double*>(out_values))[i] = sums_[i];
+    FieldVector out_fields(aggregators.size() + 1);
+    ArrayDataVector out_columns(aggregators.size() + 1);
+    for (size_t i = 0; i < aggregators.size(); ++i) {
+      Datum aggregand;
+      aggregators[i]->Finalize(ctx, &aggregand);
+      if (ctx->HasError()) return;
+      out_columns[i] = aggregand.array();
+      out_fields[i] = field(options.aggregates[i].name, aggregand.type());
     }
-    std::shared_ptr<Buffer> null_bitmap = nullptr;
-    Datum datum_sum = ArrayData::Make(float64(), length, {null_bitmap, std::move(out_buffer)}, 0);
 
-    auto out_buffer_key = std::move(AllocateBuffer(sizeof(int64_t) * length, pool)).ValueUnsafe();
-    auto out_keys = out_buffer_key->mutable_data();
-    for (size_t i = 0; i < length; ++i) {
-      (reinterpret_cast<int64_t*>(out_keys))[i] = keys_[i];
-    }
-    std::shared_ptr<Buffer> null_bitmap_key = nullptr;
-    Datum datum_key = ArrayData::Make(int64(), length, {null_bitmap_key, std::move(out_buffer_key)}, 0);
+    int64_t length = keys_.size();
+    KERNEL_ASSIGN_OR_RAISE(auto key_buf, ctx, ctx->Allocate(sizeof(int64_t) * length));
+    std::copy(keys_.begin(), keys_.end(),
+              reinterpret_cast<int64_t*>(key_buf->mutable_data()));
+    auto key = std::make_shared<Int64Array>(length, std::move(key_buf));
+
+    out_columns.back() = key->data();
+    out_fields.back() = field(options.key_names[0], key->type());
 
-    *out = Datum({std::move(datum_sum), std::move(datum_key)});
+    *out = ArrayData::Make(struct_(std::move(out_fields)), key->length(),
+                           {/*null_bitmap=*/nullptr}, std::move(out_columns));
   }
 
   std::map<uint64_t, uint32_t> map_;
   std::vector<uint64_t> keys_;
-  std::vector<double> sums_;
-  std::vector<GroupByOptions::Aggregate> aggregates;
+
+  GroupByOptions options;
+  std::vector<std::unique_ptr<GroupedAggregator>> aggregators;
 };
 
 std::unique_ptr<KernelState> GroupByInit(KernelContext* ctx, const KernelInitArgs& args) {
   // TODO(michalursa) do construction of group by implementation
   auto impl = ::arrow::internal::make_unique<GroupByImpl>();
-  impl->aggregates = checked_cast<const GroupByOptions*>(args.options)->aggregates;
+  impl->options = *checked_cast<const GroupByOptions*>(args.options);
+  const auto& aggregates = impl->options.aggregates;
 
-  if (impl->aggregates.size() > args.inputs.size()) {
+  if (aggregates.size() > args.inputs.size()) {
     ctx->SetStatus(Status::Invalid("more aggregates than inputs!"));
     return nullptr;
   }
 
-  size_t n_keys = args.inputs.size() - impl->aggregates.size();
+  impl->aggregators.resize(aggregates.size());
+  for (size_t i = 0; i < aggregates.size(); ++i) {
+    ctx->SetStatus(GroupedAggregator::Make(aggregates[i].function, aggregates[i].options)
+                       .Value(&impl->aggregators[i]));
+    if (ctx->HasError()) return nullptr;
+  }
+
+  size_t n_keys = args.inputs.size() - aggregates.size();
   if (n_keys != 1) {
     ctx->SetStatus(Status::NotImplemented("more than one key"));
     return nullptr;
   }
 
   if (args.inputs.back().type->id() != Type::INT64) {
-    ctx->SetStatus(Status::NotImplemented("key of type", args.inputs.back().type->ToString()));
+    ctx->SetStatus(
+        Status::NotImplemented("key of type", args.inputs.back().type->ToString()));
     return nullptr;
   }
 
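Downstream of Finalize, the result is one StructArray whose fields follow options.aggregates order, with the key column last. A sketch of unpacking it (field names reuse the hypothetical options sketch above):

auto result = out.array_as<StructArray>();  // out: the Datum from Finalize
auto totals = std::static_pointer_cast<DoubleArray>(result->GetFieldByName("total"));
auto ids = std::static_pointer_cast<Int64Array>(result->GetFieldByName("id"));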
@@ -470,12 +581,14 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
   DCHECK_OK(registry->AddFunction(std::move(func)));
 
   // group_by
-  func = std::make_shared<ScalarAggregateFunction>("group_by", Arity::VarArgs(), &group_by_doc);
+  func = std::make_shared<ScalarAggregateFunction>("group_by", Arity::VarArgs(),
+                                                   &group_by_doc);
   // aggregate::AddBasicAggKernels(aggregate::GroupByInit, {null()}, null(), func.get());
   {
     InputType any_array(ValueDescr::ARRAY);
     auto sig = KernelSignature::Make({any_array}, ValueDescr::Array(int64()), true);
-    AddAggKernel(std::move(sig), aggregate::GroupByInit, func.get(), SimdLevel::NONE, true);
+    AddAggKernel(std::move(sig), aggregate::GroupByInit, func.get(), SimdLevel::NONE,
+                 true);
   }
   DCHECK_OK(registry->AddFunction(std::move(func)));
   // TODO(michalursa) add Kernels to the function named "group_by"
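Finally, an end-to-end sketch of exercising the registered function (values is assumed to be a double-array Datum, keys an int64-array Datum, and options the GroupByOptions sketch from earlier; per GroupByImpl::Consume, aggregand columns come first, then the key):

ARROW_ASSIGN_OR_RAISE(Datum out, CallFunction("group_by", {values, keys}, &options));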