Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Dec 26, 2024
1 parent af07e9b commit 987ae4f
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 356 deletions.
205 changes: 0 additions & 205 deletions cpp/src/arrow/acero/hash_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3448,208 +3448,3 @@ TEST(HashJoin, LARGE_MEMORY_TEST(BuildSideOver4GBVarLength)) {

} // namespace acero
} // namespace arrow

using namespace arrow;
using namespace arrow::acero;

// Large-scale hash-join scenario labelled GH44513 (presumably the GitHub issue
// number -- TODO confirm). It joins a "small" table (~18M rows, 9 int64 columns)
// with a "large" table (~360M rows, 3 int64 keys + 1 int64 payload) on
// (key0, key1, key2) and aggregates count(*) / sum(payload) over the join output.
//
// The same join is executed in two orientations:
//   1. RIGHT_OUTER with the large table as the left input,
//   2. LEFT_OUTER  with the small table as the left input.
//
// The test only prints the aggregate tables; it does not compare them against
// expected values. A plan that returns an error Status is reported as a regular
// gtest failure (ASSERT_OK_AND_ASSIGN) rather than aborting the test process.
TEST(HashJoin, GH44513_Old) {
  const int64_t num_small_rows = 18201475;
  const int64_t num_large_rows = 360449051;
  const int64_t num_batches = 8;
  const int64_t seed = 42;

  // Each table is one generated batch repeated `num_batches` times, so the
  // per-batch length is the (truncating) quotient of the total row count.
  const int64_t small_batch_rows = num_small_rows / num_batches;
  const int64_t large_batch_rows = num_large_rows / num_batches;

  auto small_schema =
      schema({field("c0", int64()), field("key0", int64()), field("key1", int64()),
              field("c1", int64()), field("c2", int64()), field("key2", int64()),
              field("c3", int64()), field("c4", int64()), field("c5", int64())});
  auto large_schema = schema({field("key0", int64()), field("key1", int64()),
                              field("key2", int64()), field("payload", int64())});

  // Small-side values: a single random int64 array without nulls, reused for
  // every column of the small batch (keys and non-key columns alike).
  const int64_t small_key_min = 28487150299;
  const int64_t small_key_max = 98486720299;
  const double small_key_null_probability = 0;
  auto small_key_arr = RandomArrayGenerator(seed).Int64(
      small_batch_rows, small_key_min, small_key_max, small_key_null_probability);

  // Large-side keys: a single random int64 array without nulls, reused for all
  // three key columns of the large batch.
  const int64_t large_key_min = 2018570299;
  const int64_t large_key_max = 99756520299;
  const double large_key_null_probability = 0;
  auto large_key_arr = RandomArrayGenerator(seed).Int64(
      large_batch_rows, large_key_min, large_key_max, large_key_null_probability);

  // Large-side payload: small integers, null with probability
  // 283328672 / num_large_rows (roughly 0.79).
  const int64_t payload_min = 0;
  const int64_t payload_max = 96920;
  const double payload_null_probability = 283328672.0 / num_large_rows;
  auto large_payload_arr = RandomArrayGenerator(seed).Int64(
      large_batch_rows, payload_min, payload_max, payload_null_probability);

  ExecBatch small_batch(
      {small_key_arr, small_key_arr, small_key_arr, small_key_arr, small_key_arr,
       small_key_arr, small_key_arr, small_key_arr, small_key_arr},
      small_batch_rows);
  auto small_batches =
      BatchesWithSchema{std::vector<ExecBatch>(num_batches, small_batch), small_schema};

  ExecBatch large_batch({large_key_arr, large_key_arr, large_key_arr, large_payload_arr},
                        large_batch_rows);
  auto large_batches =
      BatchesWithSchema{std::vector<ExecBatch>(num_batches, large_batch), large_schema};

  // Builds `exec_batch_source x2 -> hashjoin -> aggregate` for the given join
  // type, with `left` / `right` as the first / second input of the join node.
  auto make_plan = [](JoinType join_type, const BatchesWithSchema& left,
                      const BatchesWithSchema& right) {
    Declaration left_source{"exec_batch_source",
                            ExecBatchSourceNodeOptions(left.schema, left.batches)};
    Declaration right_source{"exec_batch_source",
                             ExecBatchSourceNodeOptions(right.schema, right.batches)};

    HashJoinNodeOptions join_opts(join_type,
                                  /*left_keys=*/{"key0", "key1", "key2"},
                                  /*right_keys=*/{"key0", "key1", "key2"});
    Declaration join{
        "hashjoin", {std::move(left_source), std::move(right_source)}, join_opts};

    AggregateNodeOptions agg_opts{/*aggregates=*/{
        {"count_all", "count(*)"}, {"sum", nullptr, "payload", "sum(payload)"}}};
    Declaration agg{"aggregate", {std::move(join)}, std::move(agg_opts)};
    return agg;
  };

  {
    // Large table on the left, small table on the right.
    ASSERT_OK_AND_ASSIGN(
        auto result, DeclarationToTable(make_plan(JoinType::RIGHT_OUTER, large_batches,
                                                  small_batches)));
    // std::endl is intentional: flush so this output is not lost should the
    // second plan below crash the process.
    std::cout << result->ToString() << std::endl;
  }
  {
    // Same join with the inputs swapped: small table on the left.
    ASSERT_OK_AND_ASSIGN(
        auto result, DeclarationToTable(make_plan(JoinType::LEFT_OUTER, small_batches,
                                                  large_batches)));
    std::cout << result->ToString() << std::endl;
  }
}

// Reduced hash-join scenario labelled GH44513 (presumably the GitHub issue
// number -- TODO confirm).
//
// The small table holds exactly one row (key0_match, key1_match, key2_match).
// The large table holds `num_batches` copies of a big batch whose three key
// columns all share one random array drawn from [key0_match + 1, 99756520299]
// and whose payload is entirely null, followed by a single one-row batch that
// carries the matching keys and payload 42.
//
// The two tables are inner-joined on (key0, key1, key2) with the small table as
// the left input, and count(*) / sum(payload) of the join output is printed.
// A plan that returns an error Status is reported as a regular gtest failure
// (ASSERT_OK_AND_ASSIGN) rather than aborting the test process.
TEST(HashJoin, GH44513) {
  // Computed in 64-bit so the multiplication cannot overflow `int` if the
  // factor is ever increased.
  const int64_t num_large_rows = int64_t{360449051} * 4;
  const int64_t num_batches = 8;
  const int64_t seed = 42;
  const int64_t large_batch_rows = num_large_rows / num_batches;

  auto small_schema =
      schema({field("key0", int64()), field("key1", int64()), field("key2", int64())});
  auto large_schema = schema({field("key0", int64()), field("key1", int64()),
                              field("key2", int64()), field("payload", int64())});

  const int64_t key0_match = 88506230299;
  const int64_t key1_match = 16556030299;
  const int64_t key2_match = 11240299;
  const int64_t payload_match = 42;

  // Small side: a single row holding the matching key triple.
  ASSERT_OK_AND_ASSIGN(auto small_key0_arr,
                       Constant(MakeScalar(key0_match))->Generate(1));
  ASSERT_OK_AND_ASSIGN(auto small_key1_arr,
                       Constant(MakeScalar(key1_match))->Generate(1));
  ASSERT_OK_AND_ASSIGN(auto small_key2_arr,
                       Constant(MakeScalar(key2_match))->Generate(1));
  ExecBatch small_batch({small_key0_arr, small_key1_arr, small_key2_arr}, 1);

  // Large side, bulk rows: the generator's lower bound is key0_match + 1, so
  // key0 is intended never to equal key0_match (assumes the generator honors
  // its [min, max] bounds). The payload column is all null.
  auto large_unmatch_key_arr =
      RandomArrayGenerator(seed).Int64(large_batch_rows, key0_match + 1, 99756520299);
  ASSERT_OK_AND_ASSIGN(auto large_unmatch_payload_arr,
                       MakeArrayOfNull(int64(), large_batch_rows));
  ExecBatch large_unmatch_batch({large_unmatch_key_arr, large_unmatch_key_arr,
                                 large_unmatch_key_arr, large_unmatch_payload_arr},
                                large_batch_rows);

  // Large side, the one matching row: reuses the small side's key arrays.
  ASSERT_OK_AND_ASSIGN(auto large_match_payload_arr,
                       Constant(MakeScalar(payload_match))->Generate(1));
  ExecBatch large_match_batch(
      {small_key0_arr, small_key1_arr, small_key2_arr, large_match_payload_arr}, 1);

  auto small_batches =
      BatchesWithSchema{std::vector<ExecBatch>{small_batch}, small_schema};
  auto large_batches = BatchesWithSchema{
      std::vector<ExecBatch>(num_batches, large_unmatch_batch), large_schema};
  // The matching batch goes last, after all the non-matching batches.
  large_batches.batches.push_back(large_match_batch);

  Declaration small_source{
      "exec_batch_source",
      ExecBatchSourceNodeOptions(small_batches.schema, small_batches.batches)};
  Declaration large_source{
      "exec_batch_source",
      ExecBatchSourceNodeOptions(large_batches.schema, large_batches.batches)};

  HashJoinNodeOptions join_opts(JoinType::INNER,
                                /*left_keys=*/{"key0", "key1", "key2"},
                                /*right_keys=*/{"key0", "key1", "key2"});
  Declaration join{
      "hashjoin", {std::move(small_source), std::move(large_source)}, join_opts};

  AggregateNodeOptions agg_opts{/*aggregates=*/{
      {"count_all", "count(*)"}, {"sum", nullptr, "payload", "sum(payload)"}}};
  Declaration agg{"aggregate", {std::move(join)}, std::move(agg_opts)};

  ASSERT_OK_AND_ASSIGN(auto result, DeclarationToTable(std::move(agg)));
  // NOTE(review): by construction only the row of large_match_batch should join,
  // so the printed table is expected to show count(*) = 1 and sum(payload) = 42.
  // Consider asserting that instead of only printing it.
  std::cout << result->ToString() << std::endl;
}
10 changes: 6 additions & 4 deletions cpp/src/arrow/acero/swiss_join.cc
Original file line number Diff line number Diff line change
Expand Up @@ -633,10 +633,11 @@ void SwissTableMerge::MergePartition(SwissTable* target, const SwissTable* sourc
std::vector<uint32_t>* overflow_hashes) {
// Prepare parameters needed for scanning full slots in source.
//
int source_group_id_bits =
int64_t source_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(source->log_blocks());
uint64_t source_group_id_mask = ~0ULL >> (64 - source_group_id_bits);
int64_t source_block_bytes = source_group_id_bits + 8;
int64_t source_block_bytes =
SwissTable::num_block_bytes_from_num_groupid_bits(source_group_id_bits);
ARROW_DCHECK(source_block_bytes % sizeof(uint64_t) == 0);

// Compute index of the last block in target that corresponds to the given
Expand Down Expand Up @@ -694,9 +695,10 @@ inline bool SwissTableMerge::InsertNewGroup(SwissTable* target, uint64_t group_i
//
int64_t block_id = hash >> (SwissTable::bits_hash_ - target->log_blocks());
int64_t block_id_mask = ((1LL << target->log_blocks()) - 1);
int num_group_id_bits =
int64_t num_group_id_bits =
SwissTable::num_groupid_bits_from_log_blocks(target->log_blocks());
int64_t num_block_bytes = num_group_id_bits + sizeof(uint64_t);
int64_t num_block_bytes =
SwissTable::num_block_bytes_from_num_groupid_bits(num_group_id_bits);
ARROW_DCHECK(num_block_bytes % sizeof(uint64_t) == 0);
uint8_t* block_bytes = target->blocks() + block_id * num_block_bytes;
uint64_t block = *reinterpret_cast<const uint64_t*>(block_bytes);
Expand Down
Loading

0 comments on commit 987ae4f

Please sign in to comment.