Skip to content

Commit

Permalink
Pass partition ID to rand expression for achieving genuine random dis…
Browse files Browse the repository at this point in the history
…tribution globally (apache#80)

* Add an offset for seed to achieve genuine random distribution globally

* Filter out rand in projection cache

* Evaluate expr with literal input for getting seed value
  • Loading branch information
PHILO-HE authored and zhztheplayer committed Feb 28, 2022
1 parent 8dfaa7a commit 71f60de
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 17 deletions.
8 changes: 8 additions & 0 deletions cpp/src/gandiva/expression_cache_key.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,14 @@ class ExpressionCacheKey {

size_t Hash() const { return hash_code_; }

std::string ToString() {
std::stringstream stringstream;
for (const auto &item : expressions_as_strings_) {
stringstream << item << " || ";
}
return stringstream.str();
}

bool operator==(const ExpressionCacheKey& other) const {
if (hash_code_ != other.hash_code_) {
return false;
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/gandiva/function_registry_math_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ std::vector<NativeFunction> GetMathOpsFunctionRegistry() {
NativeFunction::kNeedsFunctionHolder),
NativeFunction("random", {"rand"}, DataTypeVector{int64()}, float64(),
kResultNullNever, "gdv_fn_random_with_seed64",
NativeFunction::kNeedsFunctionHolder),
NativeFunction("random", {"rand"}, DataTypeVector{int64(), int32()}, float64(),
kResultNullNever, "gdv_fn_random_with_seed64_offset",
NativeFunction::kNeedsFunctionHolder)};

return math_fn_registry_;
Expand Down
7 changes: 7 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ double gdv_fn_random_with_seed64(int64_t ptr, int64_t seed, bool seed_validity)
return (*holder)();
}

double gdv_fn_random_with_seed64_offset(int64_t ptr, int64_t seed, bool seed_validity,
int32_t offset, bool offset_validity) {
gandiva::RandomGeneratorHolder* holder =
reinterpret_cast<gandiva::RandomGeneratorHolder*>(ptr);
return (*holder)();
}

int64_t gdv_fn_to_date_utf8_utf8(int64_t context_ptr, int64_t holder_ptr,
const char* data, int data_len, bool in1_validity,
const char* pattern, int pattern_len, bool in2_validity,
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/gandiva/gdv_function_stubs.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ double gdv_fn_random_with_seed(int64_t ptr, int32_t seed, bool seed_validity);

double gdv_fn_random_with_seed64(int64_t ptr, int64_t seed, bool seed_validity);

double gdv_fn_random_with_seed64_offset(int64_t ptr, int64_t seed, bool seed_validity,
int32_t offset, bool offset_validity);

GANDIVA_EXPORT
const char* gdv_fn_base64_encode_binary(int64_t context, const char* in, int32_t in_len,
int32_t* out_len);
Expand Down
11 changes: 7 additions & 4 deletions cpp/src/gandiva/projector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,13 @@ Status Projector::Make(SchemaPtr schema, const ExpressionVector& exprs,
output_fields.push_back(expr->result());
}

// Instantiate the projector with the completely built llvm generator
*projector = std::shared_ptr<Projector>(
new Projector(std::move(llvm_gen), schema, output_fields, configuration));
projector->get()->SetBuiltFromCache(llvm_flag);
// For statful projection, we should not cache it.
if (cache_key.ToString().find(" rand(") == std::string::npos) {
// Instantiate the projector with the completely built llvm generator
*projector = std::shared_ptr<Projector>(
new Projector(std::move(llvm_gen), schema, output_fields, configuration));
projector->get()->SetBuiltFromCache(llvm_flag);
}

return Status::OK();
}
Expand Down
83 changes: 70 additions & 13 deletions cpp/src/gandiva/random_generator_holder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,34 +17,91 @@

#include "gandiva/random_generator_holder.h"
#include "gandiva/node.h"
#include "gandiva/projector.h"
#include "arrow/memory_pool.h"
#include "arrow/status.h"
#include "gandiva/tree_expr_builder.h"
//#include "gandiva/tests/test_util.h"
#include "arrow/type_traits.h"
#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/array/builder_base.h"
#include <type_traits>


namespace gandiva {
Status RandomGeneratorHolder::Make(const FunctionNode& node,
std::shared_ptr<RandomGeneratorHolder>* holder) {
ARROW_RETURN_IF(node.children().size() > 1,
Status::Invalid("'random' function requires at most one parameter"));
ARROW_RETURN_IF(node.children().size() > 2,
Status::Invalid("'random' function requires at most two parameters"));

if (node.children().size() == 0) {
*holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder());
return Status::OK();
}

auto literal = dynamic_cast<LiteralNode*>(node.children().at(0).get());
ARROW_RETURN_IF(literal == nullptr,
Status::Invalid("'random' function requires a literal as parameter"));

auto literal_type = literal->return_type()->id();
ARROW_RETURN_IF(
//ARROW_RETURN_IF(literal == nullptr,
// Status::Invalid("'random' function requires a literal as parameter"));
int64_t seed;
if (literal != nullptr) {
auto literal_type = literal->return_type()->id();
ARROW_RETURN_IF(
literal_type != arrow::Type::INT32 && literal_type != arrow::Type::INT64,
Status::Invalid("'random' function requires an int32/int64 literal as parameter"));

if (literal_type == arrow::Type::INT32) {
*holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder(
literal->is_null() ? 0 : arrow::util::get<int32_t>(literal->holder())));
if (literal_type == arrow::Type::INT32) {
seed = literal->is_null() ? 0 : arrow::util::get<int32_t>(literal->holder());
} else {
seed = literal->is_null() ? 0 : arrow::util::get<int64_t>(literal->holder());
}
} else {
*holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder(
literal->is_null() ? 0 : arrow::util::get<int64_t>(literal->holder())));
auto first_children_node_ptr = node.children().at(0);
//auto schema = arrow::schema({});
auto f0 = arrow::field("f0", arrow::float64());
auto schema = arrow::schema({f0});
std::shared_ptr<Projector> projector;
// Not actually used.
auto res = field("res", arrow::int32());
auto expr = TreeExprBuilder::MakeExpression(first_children_node_ptr, res);
auto builder = ConfigurationBuilder();
auto config = builder.DefaultConfiguration();
auto status = Projector::Make(schema, {expr}, config, &projector);
arrow::ArrayVector outputs;
arrow::MemoryPool* pool = arrow::default_memory_pool();

//arrow::ArrayVector inputs;
//auto in_batch = arrow::RecordBatch::Make(schema, 0, inputs);

std::vector<int> input0 = {16, 10, -14, 8};
std::vector<bool> validity = {true, true, true, true};
std::shared_ptr<arrow::Array> array0;
//arrow::ArrayFromVector<arrow::DoubleType, double>(validity, values, &array0);
//auto array0 = MakeArrowArray<arrow::DoubleType, double>(input0, validity);

auto type = arrow::TypeTraits<arrow::Int32Type>::type_singleton();
std::unique_ptr<arrow::ArrayBuilder> builder_ptr;
MakeBuilder(pool, type, &builder_ptr);
auto& arrow_array_builder = dynamic_cast<typename arrow::TypeTraits<arrow::Int32Type>::BuilderType&>(*builder_ptr);
for (size_t i = 0; i < input0.size(); ++i) {
arrow_array_builder.Append(input0[i]);
}
arrow_array_builder.Finish(&array0);

auto in_batch = arrow::RecordBatch::Make(schema, 4, {array0});

//arrow::RecordBatch in_batch;
projector->Evaluate(*in_batch, pool, &outputs);
auto result_arr = std::dynamic_pointer_cast<arrow::Int32Array>(outputs.at(0));
//seed = dynamic_cast<int32_t>(result_arr->Value(0));
seed = result_arr->Value(0);
}
// The offset is a partition ID in spark SQL. It is used to achieve genuine random distribution globally.
int32_t offset = 0;
if (node.children().size() > 1) {
auto offset_node = dynamic_cast<LiteralNode*>(node.children().at(1).get());
offset = offset_node->is_null() ? 0 : arrow::util::get<int32_t>(offset_node->holder());
}
*holder = std::shared_ptr<RandomGeneratorHolder>(new RandomGeneratorHolder(seed + offset));
return Status::OK();
}
} // namespace gandiva
30 changes: 30 additions & 0 deletions cpp/src/gandiva/random_generator_holder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,20 @@ class TestRandGenHolder : public ::testing::Test {
std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(seed), seed_is_null);
return FunctionNode("rand", {seed_node}, arrow::float64());
}

FunctionNode BuildRandWithSeedFunc(int64_t seed, bool seed_is_null) {
auto seed_node =
std::make_shared<LiteralNode>(arrow::int64(), LiteralHolder(seed), seed_is_null);
return FunctionNode("rand", {seed_node}, arrow::float64());
}

FunctionNode BuildRandWithSeedFunc(int64_t seed, bool seed_is_null, int32_t offset,
bool offset_is_null) {
auto seed_node =
std::make_shared<LiteralNode>(arrow::int64(), LiteralHolder(seed), seed_is_null);
auto offset_node = std::make_shared<LiteralNode>(arrow::int32(), LiteralHolder(offset), offset_is_null);
return FunctionNode("rand", {seed_node, offset_node}, arrow::float64());
}
};

TEST_F(TestRandGenHolder, NoSeed) {
Expand Down Expand Up @@ -106,6 +120,22 @@ TEST_F(TestRandGenHolder, WithValidSeedsInLongType) {
EXPECT_NE(random_1(), random_2());
}

// Test valid seed with offset given.
TEST_F(TestRandGenHolder, WithValidSeedsAndOffset) {
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1;
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2;
FunctionNode rand_func_1 = BuildRandWithSeedFunc(1000L, false);
FunctionNode rand_func_2 = BuildRandWithSeedFunc(900L, false, 100, false);
auto status = RandomGeneratorHolder::Make(rand_func_1, &rand_gen_holder_1);
EXPECT_EQ(status.ok(), true) << status.message();
status = RandomGeneratorHolder::Make(rand_func_2, &rand_gen_holder_2);
EXPECT_EQ(status.ok(), true) << status.message();

auto& random_1 = *rand_gen_holder_1;
auto& random_2 = *rand_gen_holder_2;
EXPECT_EQ(random_1(), random_2());
}

TEST_F(TestRandGenHolder, WithInValidSeed) {
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_1;
std::shared_ptr<RandomGeneratorHolder> rand_gen_holder_2;
Expand Down

0 comments on commit 71f60de

Please sign in to comment.