Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
remove isnull when null count is zero (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
rui-mo authored Feb 14, 2021
1 parent 3532b28 commit d4d322a
Showing 1 changed file with 158 additions and 11 deletions.
169 changes: 158 additions & 11 deletions cpp/src/codegen/arrow_compute/ext/sort_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,14 @@ class SortArraysToIndicesKernel::Impl {
indice++;
}
std::string cached_insert_str = GetCachedInsert(
shuffle_typed_codegen_list.size(), projected_types_.size(), key_projector_);
shuffle_typed_codegen_list.size(), projected_types_.size(), key_projector_,
key_index_list_);
std::string comp_func_str =
GetCompFunction(key_index_list_, key_projector_, projected_types_,
key_field_list_, sort_directions_, nulls_order_);
std::string comp_func_str_without_null =
GetCompFunctionWithoutNull(key_index_list_, key_projector_, projected_types_,
key_field_list_, sort_directions_);

std::string pre_sort_valid_str = GetPreSortValid();

Expand Down Expand Up @@ -329,6 +333,7 @@ class TypedSorterImpl : public CodeGenBase {
// we should support nulls first and nulls last here
// we should also support desc and asc here
)" + comp_func_str +
comp_func_str_without_null +
R"(
// initiate buffer for all arrays
std::shared_ptr<arrow::Buffer> indices_buf;
Expand Down Expand Up @@ -375,6 +380,7 @@ class TypedSorterImpl : public CodeGenBase {
uint64_t num_batches_ = 0;
uint64_t items_total_ = 0;
uint64_t nulls_total_ = 0;
bool has_null_ = false;
class SortRelationResultIterator : public ResultIterator<SortRelation> {
public:
Expand Down Expand Up @@ -447,7 +453,8 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
)";
}
std::string GetCachedInsert(int shuffle_size, int projected_size,
const std::shared_ptr<gandiva::Projector>& key_projector) {
const std::shared_ptr<gandiva::Projector>& key_projector,
const std::vector<int>& sort_key_index_list) {
std::stringstream ss;
for (int i = 0; i < shuffle_size; i++) {
ss << "cached_" << i << "_.push_back(std::make_shared<ArrayType_" << i << ">(in["
Expand All @@ -459,6 +466,20 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
<< ">(projected_batch[" << i << "]));" << std::endl;
}
}
if (key_projector) {
for (int i = 0; i < projected_size; i++) {
ss << "if (!has_null_ && projected_" << i
<< "_[projected_0_.size() - 1]->null_count() > 0) { " << "has_null_ = true;}"
<< std::endl;
}
} else {
for (int i = 0; i < sort_key_index_list.size(); i++) {
int key_id = sort_key_index_list[i];
ss << "if (!has_null_ && cached_" << key_id << "_[cached_" << key_id
<< "_.size() - 1]->null_count() > 0) {"
<< "has_null_ = true;}" << std::endl;
}
}
return ss.str();
}
std::string GetCompFunction(
Expand All @@ -477,7 +498,26 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
ss << "auto comp = [this](ArrayItemIndexS x, ArrayItemIndexS y) {"
<< GetCompFunction_(0, projected, sort_key_index_list, key_field_list,
projected_types, sort_directions, nulls_order)
<< "};";
<< "};\n";
return ss.str();
}
std::string GetCompFunctionWithoutNull(
const std::vector<int>& sort_key_index_list,
const std::shared_ptr<gandiva::Projector>& key_projector,
const std::vector<std::shared_ptr<arrow::DataType>>& projected_types,
const std::vector<std::shared_ptr<arrow::Field>>& key_field_list,
const std::vector<bool>& sort_directions) {
std::stringstream ss;
bool projected;
if (key_projector) {
projected = true;
} else {
projected = false;
}
ss << "auto comp_without_null = [this](ArrayItemIndexS x, ArrayItemIndexS y) {"
<< GetCompFunction_Without_Null_(0, projected, sort_key_index_list, key_field_list,
projected_types, sort_directions)
<< "};\n";
return ss.str();
}
std::string GetCompFunction_(
Expand Down Expand Up @@ -515,21 +555,27 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
array + std::to_string(cur_key_id) + "_[y.array_id]->GetString(y.id)";
auto is_x_null = array + std::to_string(cur_key_id) + "_[x.array_id]->IsNull(x.id)";
auto is_y_null = array + std::to_string(cur_key_id) + "_[y.array_id]->IsNull(y.id)";
auto x_null_count =
array + std::to_string(cur_key_id) + "_[x.array_id]->null_count() > 0";
auto y_null_count =
array + std::to_string(cur_key_id) + "_[y.array_id]->null_count() > 0";
auto x_null = "(" + x_null_count + " && " + is_x_null + " )";
auto y_null = "(" + y_null_count + " && " + is_y_null + " )";
auto is_x_nan = "std::isnan(" + x_num_value + ")";
auto is_y_nan = "std::isnan(" + y_num_value + ")";

// Multiple keys sorting w/ nulls first/last is supported.
std::stringstream ss;
// We need to determine the position of nulls.
ss << "if (" << is_x_null << ") {\n";
ss << "if (" << x_null << ") {\n";
// If value accessed from x is null, return true to make nulls first.
if (nulls_first) {
ss << "return true;\n}";
} else {
ss << "return false;\n}";
}
// If value accessed from y is null, return false to make nulls first.
ss << " else if (" << is_y_null << ") {\n";
ss << " else if (" << y_null << ") {\n";
if (nulls_first) {
ss << "return false;\n}";
} else {
Expand Down Expand Up @@ -578,17 +624,17 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
// clear the contents of stringstream
ss.str(std::string());
if (data_type->id() == arrow::Type::STRING) {
ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << x_str_value
ss << "if ((" << x_null << " && " << y_null << ") || (" << x_str_value
<< " == " << y_str_value << ")) {";
} else {
if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE ||
data_type->id() == arrow::Type::FLOAT)) {
// need to check NaN
ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << is_x_nan
ss << "if ((" << x_null << " && " << y_null << ") || (" << is_x_nan
<< " && " << is_y_nan << ") || (" << x_num_value << " == " << y_num_value
<< ")) {";
} else {
ss << "if ((" << is_x_null << " && " << is_y_null << ") || (" << x_num_value
ss << "if ((" << x_null << " && " << y_null << ") || (" << x_num_value
<< " == " << y_num_value << ")) {";
}
}
Expand All @@ -597,6 +643,104 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
<< "} else { " << comp_str << "}";
return ss.str();
}
std::string GetCompFunction_Without_Null_(
int cur_key_index, bool projected, const std::vector<int>& sort_key_index_list,
const std::vector<std::shared_ptr<arrow::Field>>& key_field_list,
const std::vector<std::shared_ptr<arrow::DataType>>& projected_types,
const std::vector<bool>& sort_directions) {
std::string comp_str;
int cur_key_id;
auto field = key_field_list[cur_key_index];
bool asc = sort_directions[cur_key_index];
std::shared_ptr<arrow::DataType> data_type;
std::string array;
// if projected, use projected batch to compare, and use projected type
if (projected) {
array = "projected_";
data_type = projected_types[cur_key_index];
// use the index of projected key
cur_key_id = cur_key_index;
} else {
array = "cached_";
data_type = field->type();
// use the key_id
cur_key_id = sort_key_index_list[cur_key_index];
}

auto x_num_value =
array + std::to_string(cur_key_id) + "_[x.array_id]->GetView(x.id)";
auto x_str_value =
array + std::to_string(cur_key_id) + "_[x.array_id]->GetString(x.id)";
auto y_num_value =
array + std::to_string(cur_key_id) + "_[y.array_id]->GetView(y.id)";
auto y_str_value =
array + std::to_string(cur_key_id) + "_[y.array_id]->GetString(y.id)";
auto is_x_nan = "std::isnan(" + x_num_value + ")";
auto is_y_nan = "std::isnan(" + y_num_value + ")";

// Multiple keys sorting w/ nulls first/last is supported.
std::stringstream ss;
// If datatype is floating, we need to do partition for NaN if NaN check is enabled
if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE ||
data_type->id() == arrow::Type::FLOAT)) {
ss << "if (" << is_x_nan << ") {\n";
if (asc) {
ss << "return false;\n}";
} else {
ss << "return true;\n}";
}
ss << "else if (" << is_y_nan << ") {\n";
if (asc) {
ss << "return true;\n}";
} else {
ss << "return false;\n}";
}
// If values accessed from x and y are both not nan
ss << " else {\n";
}

// Multiple keys sorting w/ different ordering is supported.
// For string type of data, GetString should be used instead of GetView.
if (asc) {
if (data_type->id() == arrow::Type::STRING) {
ss << "return " << x_str_value << " < " << y_str_value << ";\n";
} else {
ss << "return " << x_num_value << " < " << y_num_value << ";\n";
}
} else {
if (data_type->id() == arrow::Type::STRING) {
ss << "return " << x_str_value << " > " << y_str_value << ";\n";
} else {
ss << "return " << x_num_value << " > " << y_num_value << ";\n";
}
}
if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE ||
data_type->id() == arrow::Type::FLOAT)) {
ss << "}" << std::endl;
}
comp_str = ss.str();
if ((cur_key_index + 1) == sort_key_index_list.size()) {
return comp_str;
}
// clear the contents of stringstream
ss.str(std::string());
if (data_type->id() == arrow::Type::STRING) {
ss << "if (" << x_str_value << " == " << y_str_value << ") {";
} else {
if (NaN_check_ && (data_type->id() == arrow::Type::DOUBLE ||
data_type->id() == arrow::Type::FLOAT)) {
// need to check NaN
ss << "if ((" << is_x_nan << " && " << is_y_nan << ") || ("
<< x_num_value << " == " << y_num_value << ")) {";
} else {
ss << "if (" << x_num_value << " == " << y_num_value << ") {";
}
}
ss << GetCompFunction_Without_Null_(cur_key_index + 1, projected, sort_key_index_list,
key_field_list, projected_types, sort_directions)
<< "} else { " << comp_str << "}";
return ss.str();
}
std::string GetPreSortValid() {
if (nulls_first_) {
return R"(
Expand All @@ -620,9 +764,12 @@ extern "C" void MakeCodeGen(arrow::compute::FunctionContext* ctx,
}
}
std::string GetSortFunction() {
return "gfx::timsort(indices_begin, indices_begin + "
"items_total_, "
"comp);";
std::stringstream ss;
ss << "if (has_null_) {\n"
<< "gfx::timsort(indices_begin, indices_begin + items_total_, comp);} else {\n"
<< "gfx::timsort(indices_begin, indices_begin + items_total_, comp_without_null);}"
<< std::endl;
return ss.str();
}
std::string GetMakeResultIter(int shuffle_size) {
std::stringstream ss;
Expand Down

0 comments on commit d4d322a

Please sign in to comment.