GH-15151: [C++] Adding RecordBatchReaderSource to solve an issue in R API #15183

Merged (12 commits) on Jan 12, 2023
30 changes: 27 additions & 3 deletions cpp/examples/arrow/execution_plan_documentation_examples.cc
@@ -561,8 +561,6 @@ arrow::Status SourceOrderBySinkExample() {

ARROW_ASSIGN_OR_RAISE(auto basic_data, MakeSortTestBasicBatches());

-  std::cout << "basic data created" << std::endl;

arrow::AsyncGenerator<std::optional<cp::ExecBatch>> sink_gen;

auto source_node_options = cp::SourceNodeOptions{basic_data.schema, basic_data.gen()};
@@ -761,8 +759,29 @@ arrow::Status TableSinkExample() {
std::cout << "Results : " << output_table->ToString() << std::endl;
return arrow::Status::OK();
}

// (Doc section: Table Sink Example)

// (Doc section: RecordBatchReaderSource Example)

/// \brief An example showing the usage of a RecordBatchReader as the data source.
///
/// RecordBatchReaderSourceSink Example
/// This example shows how a record_batch_reader_source can be used
/// in an execution plan, with the source node receiving its data
/// from a TableBatchReader.

arrow::Status RecordBatchReaderSourceSinkExample() {
ARROW_ASSIGN_OR_RAISE(auto table, GetTable());
std::shared_ptr<arrow::RecordBatchReader> reader =
std::make_shared<arrow::TableBatchReader>(table);
cp::Declaration reader_source{"record_batch_reader_source",
cp::RecordBatchReaderSourceNodeOptions{reader}};
return ExecutePlanAndCollectAsTable(std::move(reader_source));
}

// (Doc section: RecordBatchReaderSource Example)

enum ExampleMode {
SOURCE_SINK = 0,
TABLE_SOURCE_SINK = 1,
@@ -777,7 +796,8 @@ enum ExampleMode {
KSELECT = 10,
WRITE = 11,
UNION = 12,
-  TABLE_SOURCE_TABLE_SINK = 13
+  TABLE_SOURCE_TABLE_SINK = 13,
+  RECORD_BATCH_READER_SOURCE = 14
};

int main(int argc, char** argv) {
@@ -848,6 +868,10 @@ int main(int argc, char** argv) {
PrintBlock("TableSink Example");
status = TableSinkExample();
break;
case RECORD_BATCH_READER_SOURCE:
PrintBlock("RecordBatchReaderSource Example");
status = RecordBatchReaderSourceSinkExample();
break;
default:
break;
}
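(For context: the new example returns through ExecutePlanAndCollectAsTable, a helper this file defines outside the hunks shown above. A minimal sketch of such a helper, assuming the cp namespace alias used in the examples and the arrow::compute::DeclarationToTable API, neither of which appears in this diff:)

arrow::Status ExecutePlanAndCollectAsTable(cp::Declaration plan) {
  // Sketch only: run the declaration to completion, collect the output
  // into a single in-memory table, then print it as the examples above do.
  ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::Table> response_table,
                        cp::DeclarationToTable(std::move(plan)));
  std::cout << "Results : " << response_table->ToString() << std::endl;
  return arrow::Status::OK();
}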
15 changes: 15 additions & 0 deletions cpp/src/arrow/compute/exec/options.h
@@ -126,6 +126,21 @@ class ARROW_EXPORT SchemaSourceNodeOptions : public ExecNodeOptions {
arrow::internal::Executor* io_executor;
};

class ARROW_EXPORT RecordBatchReaderSourceNodeOptions : public ExecNodeOptions {
public:
RecordBatchReaderSourceNodeOptions(std::shared_ptr<RecordBatchReader> reader,
arrow::internal::Executor* io_executor = NULLPTR)
: reader(std::move(reader)), io_executor(io_executor) {}

/// \brief The RecordBatchReader which acts as the data source
std::shared_ptr<RecordBatchReader> reader;

/// \brief The executor to use for the reader
///
/// Defaults to the default I/O executor.
arrow::internal::Executor* io_executor;
};

using ArrayVectorIteratorMaker = std::function<Iterator<std::shared_ptr<ArrayVector>>()>;
/// \brief An extended Source node which accepts a schema and array-vectors
class ARROW_EXPORT ArrayVectorSourceNodeOptions
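(Usage sketch for the new options: the helper name MakeReaderSourceNode below is hypothetical, but the factory string matches the registration added in source_node.cc later in this PR. Leaving io_executor unset exercises the default documented above.)

// Hypothetical helper: attach a record_batch_reader_source node to a plan.
arrow::Result<arrow::compute::ExecNode*> MakeReaderSourceNode(
    arrow::compute::ExecPlan* plan, const std::shared_ptr<arrow::Table>& table) {
  // Any RecordBatchReader can act as the source; a TableBatchReader is simplest.
  std::shared_ptr<arrow::RecordBatchReader> reader =
      std::make_shared<arrow::TableBatchReader>(table);
  // No io_executor given, so the node falls back to the default I/O executor.
  arrow::compute::RecordBatchReaderSourceNodeOptions options{std::move(reader)};
  return arrow::compute::MakeExecNode("record_batch_reader_source", plan,
                                      /*inputs=*/{}, options);
}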
41 changes: 41 additions & 0 deletions cpp/src/arrow/compute/exec/plan_test.cc
@@ -344,6 +344,39 @@ void TestSourceSink(
Finishes(ResultWith(UnorderedElementsAreArray(exp_batches.batches))));
}

void TestRecordBatchReaderSourceSink(
std::function<Result<std::shared_ptr<RecordBatchReader>>(const BatchesWithSchema&)>
to_reader) {
for (bool parallel : {false, true}) {
SCOPED_TRACE(parallel ? "parallel/merged" : "serial");
auto exp_batches = MakeBasicBatches();
ASSERT_OK_AND_ASSIGN(std::shared_ptr<RecordBatchReader> reader,
to_reader(exp_batches));
RecordBatchReaderSourceNodeOptions options{reader};
Declaration plan("record_batch_reader_source", std::move(options));
ASSERT_OK_AND_ASSIGN(auto result, DeclarationToExecBatches(plan, parallel));
AssertExecBatchesEqualIgnoringOrder(result.schema, result.batches,
exp_batches.batches);
}
}

void TestRecordBatchReaderSourceSinkError(
std::function<Result<std::shared_ptr<RecordBatchReader>>(const BatchesWithSchema&)>
to_reader) {
ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make());
auto source_factory_name = "record_batch_reader_source";
auto exp_batches = MakeBasicBatches();
ASSERT_OK_AND_ASSIGN(std::shared_ptr<RecordBatchReader> reader, to_reader(exp_batches));

auto null_executor_options = RecordBatchReaderSourceNodeOptions{reader};
ASSERT_OK(MakeExecNode(source_factory_name, plan.get(), {}, null_executor_options));

std::shared_ptr<RecordBatchReader> no_reader;
auto null_reader_options = RecordBatchReaderSourceNodeOptions{no_reader};
ASSERT_THAT(MakeExecNode(source_factory_name, plan.get(), {}, null_reader_options),
Raises(StatusCode::Invalid, HasSubstr("not null")));
}

TEST(ExecPlanExecution, ArrayVectorSourceSink) {
TestSourceSink<std::shared_ptr<ArrayVector>, ArrayVectorSourceNodeOptions>(
"array_vector_source", ToArrayVectors);
@@ -374,6 +407,14 @@ TEST(ExecPlanExecution, RecordBatchSourceSinkError) {
"record_batch_source", ToRecordBatches);
}

TEST(ExecPlanExecution, RecordBatchReaderSourceSink) {
TestRecordBatchReaderSourceSink(ToRecordBatchReader);
}

TEST(ExecPlanExecution, RecordBatchReaderSourceSinkError) {
TestRecordBatchReaderSourceSinkError(ToRecordBatchReader);
}

TEST(ExecPlanExecution, SinkNodeBackpressure) {
std::optional<ExecBatch> batch =
ExecBatchFromJSON({int32(), boolean()},
48 changes: 48 additions & 0 deletions cpp/src/arrow/compute/exec/source_node.cc
@@ -327,6 +327,52 @@ struct SchemaSourceNode : public SourceNode {
}
};

struct RecordBatchReaderSourceNode : public SourceNode {
RecordBatchReaderSourceNode(ExecPlan* plan, std::shared_ptr<Schema> schema,
arrow::AsyncGenerator<std::optional<ExecBatch>> generator)
: SourceNode(plan, schema, generator) {}

static Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
RETURN_NOT_OK(ValidateExecNodeInputs(plan, inputs, 0, kKindName));
const auto& cast_options =
checked_cast<const RecordBatchReaderSourceNodeOptions&>(options);
auto& reader = cast_options.reader;
auto io_executor = cast_options.io_executor;

if (reader == nullptr) {
return Status::Invalid(kKindName, " requires a reader which is not null");
}

if (io_executor == nullptr) {
io_executor = io::internal::GetIOThreadPool();
}

ARROW_ASSIGN_OR_RAISE(auto generator, MakeGenerator(reader, io_executor));
return plan->EmplaceNode<RecordBatchReaderSourceNode>(plan, reader->schema(),
generator);
}

static Result<arrow::AsyncGenerator<std::optional<ExecBatch>>> MakeGenerator(
const std::shared_ptr<RecordBatchReader>& reader,
arrow::internal::Executor* io_executor) {
auto to_exec_batch =
[](const std::shared_ptr<RecordBatch>& batch) -> std::optional<ExecBatch> {
if (batch == NULLPTR) {
return std::nullopt;
}
return std::optional<ExecBatch>(ExecBatch(*batch));
};
Iterator<std::shared_ptr<RecordBatch>> batch_it = MakeIteratorFromReader(reader);
auto exec_batch_it = MakeMapIterator(to_exec_batch, std::move(batch_it));
return MakeBackgroundGenerator(std::move(exec_batch_it), io_executor);
}

static const char kKindName[];
};

const char RecordBatchReaderSourceNode::kKindName[] = "RecordBatchReaderSourceNode";

struct RecordBatchSourceNode
: public SchemaSourceNode<RecordBatchSourceNode, RecordBatchSourceNodeOptions> {
using RecordBatchSchemaSourceNode =
@@ -444,6 +490,8 @@ void RegisterSourceNode(ExecFactoryRegistry* registry) {
DCHECK_OK(registry->AddFactory("source", SourceNode::Make));
DCHECK_OK(registry->AddFactory("table_source", TableSourceNode::Make));
DCHECK_OK(registry->AddFactory("record_batch_source", RecordBatchSourceNode::Make));
DCHECK_OK(registry->AddFactory("record_batch_reader_source",
RecordBatchReaderSourceNode::Make));
DCHECK_OK(registry->AddFactory("exec_batch_source", ExecBatchSourceNode::Make));
DCHECK_OK(registry->AddFactory("array_vector_source", ArrayVectorSourceNode::Make));
DCHECK_OK(registry->AddFactory("named_table", MakeNamedTableNode));
17 changes: 14 additions & 3 deletions cpp/src/arrow/compute/exec/test_util.cc
@@ -273,8 +273,7 @@ Result<std::vector<std::shared_ptr<ExecBatch>>> ToExecBatches(
const BatchesWithSchema& batches_with_schema) {
std::vector<std::shared_ptr<ExecBatch>> exec_batches;
for (auto batch : batches_with_schema.batches) {
-    auto exec_batch = std::make_shared<ExecBatch>(batch);
-    exec_batches.push_back(exec_batch);
+    exec_batches.push_back(std::make_shared<ExecBatch>(batch));
}
return exec_batches;
}
@@ -285,11 +284,23 @@ Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches(
for (auto batch : batches_with_schema.batches) {
ARROW_ASSIGN_OR_RAISE(auto record_batch,
batch.ToRecordBatch(batches_with_schema.schema));
-    record_batches.push_back(record_batch);
+    record_batches.push_back(std::move(record_batch));
}
return record_batches;
}

Result<std::shared_ptr<RecordBatchReader>> ToRecordBatchReader(
const BatchesWithSchema& batches_with_schema) {
std::vector<std::shared_ptr<RecordBatch>> record_batches;
for (auto batch : batches_with_schema.batches) {
ARROW_ASSIGN_OR_RAISE(auto record_batch,
batch.ToRecordBatch(batches_with_schema.schema));
record_batches.push_back(std::move(record_batch));
}
ARROW_ASSIGN_OR_RAISE(auto table, Table::FromRecordBatches(std::move(record_batches)));
return std::make_shared<arrow::TableBatchReader>(std::move(table));
}

Result<std::shared_ptr<Table>> SortTableOnAllFields(const std::shared_ptr<Table>& tab) {
std::vector<SortKey> sort_keys;
for (int i = 0; i < tab->num_columns(); i++) {
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/exec/test_util.h
@@ -127,6 +127,10 @@ ARROW_TESTING_EXPORT
Result<std::vector<std::shared_ptr<RecordBatch>>> ToRecordBatches(
const BatchesWithSchema& batches);

ARROW_TESTING_EXPORT
Result<std::shared_ptr<RecordBatchReader>> ToRecordBatchReader(
const BatchesWithSchema& batches_with_schema);

ARROW_TESTING_EXPORT
Result<std::vector<std::shared_ptr<ArrayVector>>> ToArrayVectors(
const BatchesWithSchema& batches_with_schema);
17 changes: 2 additions & 15 deletions r/src/compute-exec.cpp
@@ -30,15 +30,6 @@
#include <iostream>
#include <optional>

-// GH-15151: Best path forward to make this available without a hack like this one
-namespace arrow {
-namespace io {
-namespace internal {
-arrow::internal::ThreadPool* GetIOThreadPool();
-}
-}  // namespace io
-}  // namespace arrow

namespace compute = ::arrow::compute;

std::shared_ptr<compute::FunctionOptions> make_compute_options(std::string func_name,
@@ -459,12 +450,8 @@ std::shared_ptr<compute::ExecNode> ExecNode_Union(
std::shared_ptr<compute::ExecNode> ExecNode_SourceNode(
const std::shared_ptr<compute::ExecPlan>& plan,
const std::shared_ptr<arrow::RecordBatchReader>& reader) {
-  arrow::compute::SourceNodeOptions options{
-      /*output_schema=*/reader->schema(),
-      /*generator=*/ValueOrStop(
-          compute::MakeReaderGenerator(reader, arrow::io::internal::GetIOThreadPool()))};
-
-  return MakeExecNodeOrStop("source", plan.get(), {}, options);
+  arrow::compute::RecordBatchReaderSourceNodeOptions options{reader};
+  return MakeExecNodeOrStop("record_batch_reader_source", plan.get(), {}, options);
}

// [[arrow::export]]
Expand Down