Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wip hashagg opt3 #3

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions native-sql-engine/cpp/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ include(FindPkgConfig)
include(GNUInstallDirs)
include(CheckCXXCompilerFlag)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 17)

set(CMAKE_CXX_STANDARD_REQUIRED ON)

Expand Down Expand Up @@ -520,7 +520,6 @@ set(SPARK_COLUMNAR_PLUGIN_SRCS
precompile/builder.cc
precompile/array.cc
precompile/type.cc
precompile/sort.cc
precompile/hash_arrays_kernel.cc
precompile/unsafe_array.cc
precompile/gandiva_projector.cc
Expand Down
183 changes: 128 additions & 55 deletions native-sql-engine/cpp/src/codegen/arrow_compute/ext/actions_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,20 +152,31 @@ class UniqueAction : public ActionBase {
row_id_ = 0;
in_null_count_ = in_->null_count();
// prepare evaluate lambda
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id_);
if (cache_validity_[dest_group_id] == false) {
if (!is_null) {
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id_);
if (cache_validity_[dest_group_id] == false) {
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] = (CType)in_->GetView(row_id_);
} else {
cache_validity_[dest_group_id] = true;
null_flag_[dest_group_id] = true;
}
}
row_id_++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
if (cache_validity_[dest_group_id] == false) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] = (CType)in_->GetView(row_id_);
} else {
cache_validity_[dest_group_id] = true;
null_flag_[dest_group_id] = true;
}
}
row_id_++;
return arrow::Status::OK();
};
row_id_++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id_++;
Expand Down Expand Up @@ -1802,15 +1813,25 @@ class SumAction<DataType, CType, ResDataType, ResCType,
// prepare evaluate lambda
data_ = const_cast<CType*>(in_->data()->GetValues<CType>(1));
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
if (!in_null_count_) {
*on_valid = [this](int dest_group_id) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += data_[row_id];
}
row_id++;
return arrow::Status::OK();
};

row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += data_[row_id];
}
row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -1952,15 +1973,24 @@ class SumAction<DataType, CType, ResDataType, ResCType,
in_null_count_ = in_->null_count();
// prepare evaluate lambda
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += in_->GetView(row_id);
}
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += in_->GetView(row_id);
}
row_id++;
return arrow::Status::OK();
};
row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -2108,18 +2138,29 @@ class SumActionPartial<DataType, CType, ResDataType, ResCType,

in_ = in_list[0];
in_null_count_ = in_->null_count();
// prepare evaluate lambda

data_ = const_cast<CType*>(in_->data()->GetValues<CType>(1));
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
// prepare evaluate lambda
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += data_[row_id];
}
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += data_[row_id];
}
row_id++;
return arrow::Status::OK();
};

row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -2263,17 +2304,28 @@ class SumActionPartial<DataType, CType, ResDataType, ResCType,

in_ = std::make_shared<ArrayType>(in_list[0]);
in_null_count_ = in_->null_count();
// prepare evaluate lambda

row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
// prepare evaluate lambda
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += in_->GetView(row_id);
}
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
cache_validity_[dest_group_id] = true;
cache_[dest_group_id] += in_->GetView(row_id);
}
row_id++;
return arrow::Status::OK();
};

row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -2785,16 +2837,26 @@ class SumCountAction<DataType, CType, ResDataType, ResCType,
// prepare evaluate lambda
data_ = const_cast<CType*>(in_->data()->GetValues<CType>(1));
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
if (!in_null_count_) {
*on_valid = [this](int dest_group_id) {
cache_sum_[dest_group_id] += data_[row_id];
cache_count_[dest_group_id] += 1;
cache_validity_[dest_group_id] = true;
}
row_id++;
return arrow::Status::OK();
};
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_sum_[dest_group_id] += data_[row_id];
cache_count_[dest_group_id] += 1;
cache_validity_[dest_group_id] = true;
}
row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down Expand Up @@ -2963,16 +3025,27 @@ class SumCountAction<DataType, CType, ResDataType, ResCType,
in_null_count_ = in_->null_count();
// prepare evaluate lambda
row_id = 0;
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
if (in_null_count_) {
*on_valid = [this](int dest_group_id) {
const bool is_null = in_null_count_ > 0 && in_->IsNull(row_id);
if (!is_null) {
cache_sum_[dest_group_id] += in_->GetView(row_id);
cache_count_[dest_group_id] += 1;
cache_validity_[dest_group_id] = true;
}
row_id++;
return arrow::Status::OK();
};
} else {
*on_valid = [this](int dest_group_id) {
cache_sum_[dest_group_id] += in_->GetView(row_id);
cache_count_[dest_group_id] += 1;
cache_validity_[dest_group_id] = true;
}
row_id++;
return arrow::Status::OK();
};

row_id++;
return arrow::Status::OK();
};
}

*on_null = [this]() {
row_id++;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -911,14 +911,14 @@ class HashAggregateKernel::Impl {
for (int i = 0; i < length; i++) {
aggr_key_unsafe_row->reset();

for (auto payload_arr : payloads) {
for (const auto& payload_arr : payloads) {
payload_arr->Append(i, &aggr_key_unsafe_row);
}
aggr_key = arrow::util::string_view(aggr_key_unsafe_row->data,
aggr_key_unsafe_row->cursor);

// FIXME(): all keys are null?
aggr_hash_table_->GetOrInsert(
aggr_key, [](int) {}, [](int) {}, &(indices[i]));
aggr_key_unsafe_row->data, aggr_key_unsafe_row->cursor, [](int) {},
[](int) {}, &(indices[i]));
}
} else {
for (int i = 0; i < length; i++) {
Expand Down Expand Up @@ -973,7 +973,6 @@ class HashAggregateKernel::Impl {

arrow::Status Next(std::shared_ptr<arrow::RecordBatch>* out) {
uint64_t out_length = 0;
int gp_idx = 0;
std::vector<std::shared_ptr<arrow::Array>> outputs;
for (auto action : action_impl_list_) {
action->Finish(offset_, batch_size_, &outputs);
Expand Down
52 changes: 30 additions & 22 deletions native-sql-engine/cpp/src/precompile/hash_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,27 +27,6 @@
namespace sparkcolumnarplugin {
namespace precompile {

#define TYPED_SPARSE_HASH_MAP_IMPL(TYPENAME, TYPE) \
class TYPENAME::Impl : public SparseHashMap<TYPE> { \
public: \
Impl(arrow::MemoryPool* pool) : SparseHashMap<TYPE>(pool) {} \
}; \
\
TYPENAME::TYPENAME(arrow::MemoryPool* pool) { impl_ = std::make_shared<Impl>(pool); } \
arrow::Status TYPENAME::GetOrInsert(const TYPE& value, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), \
int32_t* out_memo_index) { \
return impl_->GetOrInsert(value, on_found, on_not_found, out_memo_index); \
} \
int32_t TYPENAME::GetOrInsertNull(void (*on_found)(int32_t), \
void (*on_not_found)(int32_t)) { \
return impl_->GetOrInsertNull(on_found, on_not_found); \
} \
int32_t TYPENAME::Get(const TYPE& value) { return impl_->Get(value); } \
int32_t TYPENAME::GetNull() { return impl_->GetNull(); }

#undef TYPED_SPARSE_HASH_MAP_IMPL

#define TYPED_ARROW_HASH_MAP_IMPL(HASHMAPNAME, TYPENAME, TYPE, MEMOTABLETYPE) \
using MEMOTABLETYPE = \
typename arrow::internal::HashTraits<arrow::TYPENAME>::MemoTableType; \
Expand All @@ -72,6 +51,35 @@ namespace precompile {
int32_t HASHMAPNAME::Get(const TYPE& value) { return impl_->Get(value); } \
int32_t HASHMAPNAME::GetNull() { return impl_->GetNull(); }

#define TYPED_ARROW_HASH_MAP_BINARY_IMPL(HASHMAPNAME, TYPENAME, TYPE, MEMOTABLETYPE) \
using MEMOTABLETYPE = \
typename arrow::internal::HashTraits<arrow::TYPENAME>::MemoTableType; \
class HASHMAPNAME::Impl : public MEMOTABLETYPE { \
public: \
Impl(arrow::MemoryPool* pool) : MEMOTABLETYPE(pool, 128) {} \
}; \
\
HASHMAPNAME::HASHMAPNAME(arrow::MemoryPool* pool) { \
impl_ = std::make_shared<Impl>(pool); \
} \
arrow::Status HASHMAPNAME::GetOrInsert(const TYPE& value, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), \
int32_t* out_memo_index) { \
return impl_->GetOrInsert(value, on_found, on_not_found, out_memo_index); \
} \
arrow::Status HASHMAPNAME::GetOrInsert(const void* value, int len, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), \
int32_t* out_memo_index) { \
return impl_->GetOrInsert(value, len, on_found, on_not_found, out_memo_index); \
} \
int32_t HASHMAPNAME::GetOrInsertNull(void (*on_found)(int32_t), \
void (*on_not_found)(int32_t)) { \
return impl_->GetOrInsertNull(on_found, on_not_found); \
} \
int32_t HASHMAPNAME::Size() { return impl_->size(); } \
int32_t HASHMAPNAME::Get(const TYPE& value) { return impl_->Get(value); } \
int32_t HASHMAPNAME::GetNull() { return impl_->GetNull(); }

#define TYPED_ARROW_HASH_MAP_DECIMAL_IMPL(HASHMAPNAME, TYPENAME, TYPE, MEMOTABLETYPE) \
using MEMOTABLETYPE = \
typename arrow::internal::HashTraits<arrow::TYPENAME>::MemoTableType; \
Expand Down Expand Up @@ -103,7 +111,7 @@ TYPED_ARROW_HASH_MAP_IMPL(FloatHashMap, FloatType, float, FloatMemoTableType)
TYPED_ARROW_HASH_MAP_IMPL(DoubleHashMap, DoubleType, double, DoubleMemoTableType)
TYPED_ARROW_HASH_MAP_IMPL(Date32HashMap, Date32Type, int32_t, Date32MemoTableType)
TYPED_ARROW_HASH_MAP_IMPL(Date64HashMap, Date64Type, int64_t, Date64MemoTableType)
TYPED_ARROW_HASH_MAP_IMPL(StringHashMap, StringType, arrow::util::string_view,
TYPED_ARROW_HASH_MAP_BINARY_IMPL(StringHashMap, StringType, arrow::util::string_view,
StringMemoTableType)
TYPED_ARROW_HASH_MAP_DECIMAL_IMPL(Decimal128HashMap, Decimal128Type, arrow::Decimal128,
DecimalMemoTableType)
Expand Down
2 changes: 2 additions & 0 deletions native-sql-engine/cpp/src/precompile/hash_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ namespace precompile {
TYPENAME(arrow::MemoryPool* pool); \
arrow::Status GetOrInsert(const TYPE& value, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), int32_t* out_memo_index); \
arrow::Status GetOrInsert(const void* value, int len, void (*on_found)(int32_t), \
void (*on_not_found)(int32_t), int32_t* out_memo_index); \
int32_t GetOrInsertNull(void (*on_found)(int32_t), void (*on_not_found)(int32_t)); \
int32_t Get(const TYPE& value); \
int32_t Size(); \
Expand Down
Loading