Skip to content

Commit

Permalink
ARROW-17980: [C++] As-of-Join Substrait extension (apache#17)
Browse files Browse the repository at this point in the history
* ARROW-17980: [C++] As-of-Join Substrait extension

* add missing file

* add missing proto

* CI fixes

* distinct keys per input table

* CI fixes

* resolve conflict

* fix typo

* ARROW-17980: Change extensions package from arrow::substrait to arrow::substrait_ext

* ARROW-17980: Remove more instances of ::substrait

Co-authored-by: Yaron Gvili <rtpsw@hotmail.com>
  • Loading branch information
westonpace and rtpsw authored Oct 21, 2022
1 parent 73efef8 commit 25d83a3
Show file tree
Hide file tree
Showing 15 changed files with 643 additions and 91 deletions.
27 changes: 27 additions & 0 deletions cpp/cmake_modules/ThirdpartyToolchain.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1724,6 +1724,10 @@ macro(build_substrait)
# Note: not all protos in Substrait actually matter to plan
# consumption. No need to build the ones we don't need.
set(SUBSTRAIT_PROTOS algebra extensions/extensions plan type)
set(ARROW_SUBSTRAIT_PROTOS extension_rels)
set(ARROW_SUBSTRAIT_PROTOS_DIR "${CMAKE_SOURCE_DIR}/proto")
message("SOURCE DIR IS ${SOURCE_DIR} AND ${CMAKE_SOURCE_DIR} AND ${ARROW_SUBSTRAIT_PROTOS_DIR}"
)

externalproject_add(substrait_ep
CONFIGURE_COMMAND ""
Expand Down Expand Up @@ -1774,6 +1778,29 @@ macro(build_substrait)

list(APPEND SUBSTRAIT_SOURCES "${SUBSTRAIT_PROTO_GEN}.cc")
endforeach()
message("SOURCE DIR2 IS ${SOURCE_DIR} AND ${CMAKE_SOURCE_DIR} AND ${ARROW_SUBSTRAIT_PROTOS_DIR}"
)
foreach(ARROW_SUBSTRAIT_PROTO ${ARROW_SUBSTRAIT_PROTOS})
set(ARROW_SUBSTRAIT_PROTO_GEN
"${SUBSTRAIT_CPP_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.pb")
foreach(EXT h cc)
set_source_files_properties("${ARROW_SUBSTRAIT_PROTO_GEN}.${EXT}"
PROPERTIES COMPILE_OPTIONS
"${SUBSTRAIT_SUPPRESSED_FLAGS}"
GENERATED TRUE
SKIP_UNITY_BUILD_INCLUSION TRUE)
list(APPEND SUBSTRAIT_PROTO_GEN_ALL "${ARROW_SUBSTRAIT_PROTO_GEN}.${EXT}")
endforeach()
add_custom_command(OUTPUT "${ARROW_SUBSTRAIT_PROTO_GEN}.cc"
"${ARROW_SUBSTRAIT_PROTO_GEN}.h"
COMMAND ${ARROW_PROTOBUF_PROTOC} "-I${SUBSTRAIT_LOCAL_DIR}/proto"
"-I${ARROW_SUBSTRAIT_PROTOS_DIR}"
"--cpp_out=${SUBSTRAIT_CPP_DIR}"
"${ARROW_SUBSTRAIT_PROTOS_DIR}/substrait/${ARROW_SUBSTRAIT_PROTO}.proto"
DEPENDS ${PROTO_DEPENDS} substrait_ep)

list(APPEND SUBSTRAIT_SOURCES "${ARROW_SUBSTRAIT_PROTO_GEN}.cc")
endforeach()

add_custom_target(substrait_gen ALL DEPENDS ${SUBSTRAIT_PROTO_GEN_ALL})

Expand Down
36 changes: 36 additions & 0 deletions cpp/proto/substrait/extension_rels.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
syntax = "proto3";

package arrow.substrait_ext;

import "substrait/algebra.proto";

option csharp_namespace = "Arrow.Substrait";
option go_package = "github.com/apache/arrow/substrait";
option java_multiple_files = true;
option java_package = "io.arrow.substrait";

message AsOfJoinRel {
repeated AsOfJoinKeys input_keys = 1;
int64 tolerance = 2;

message AsOfJoinKeys {
.substrait.Expression on = 1;
repeated .substrait.Expression by = 2;
}
}
12 changes: 11 additions & 1 deletion cpp/src/arrow/compute/exec/asof_join_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,19 @@ static void TableJoinOverhead(benchmark::State& state,
benchmark::Counter(static_cast<double>(default_memory_pool()->max_memory()));
}

AsofJoinNodeOptions GetRepeatedOptions(size_t repeat, FieldRef on_key,
std::vector<FieldRef> by_key, int64_t tolerance) {
std::vector<AsofJoinNodeOptions::Keys> input_keys(repeat);
for (size_t i = 0; i < repeat; i++) {
input_keys[i] = {on_key, by_key};
}
return AsofJoinNodeOptions(input_keys, tolerance);
}

static void AsOfJoinOverhead(benchmark::State& state) {
int64_t tolerance = 0;
AsofJoinNodeOptions options = AsofJoinNodeOptions(kTimeCol, {kKeyCol}, tolerance);
AsofJoinNodeOptions options =
GetRepeatedOptions(int(state.range(4)), kTimeCol, {kKeyCol}, tolerance);
TableJoinOverhead(
state,
TableGenerationProperties{int(state.range(0)), int(state.range(1)),
Expand Down
107 changes: 85 additions & 22 deletions cpp/src/arrow/compute/exec/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

#include "arrow/compute/exec/asof_join_node.h"

#include <condition_variable>
#include <memory>
#include <mutex>
Expand Down Expand Up @@ -951,28 +953,27 @@ class AsofJoinNode : public ExecNode {
}

static arrow::Result<std::shared_ptr<Schema>> MakeOutputSchema(
const std::vector<ExecNode*>& inputs,
const std::vector<std::shared_ptr<Schema>> input_schema,
const std::vector<col_index_t>& indices_of_on_key,
const std::vector<std::vector<col_index_t>>& indices_of_by_key) {
std::vector<std::shared_ptr<arrow::Field>> fields;

size_t n_by = indices_of_by_key[0].size();
size_t n_by = indices_of_by_key.size() == 0 ? 0 : indices_of_by_key[0].size();
const DataType* on_key_type = NULLPTR;
std::vector<const DataType*> by_key_type(n_by, NULLPTR);
// Take all non-key, non-time RHS fields
for (size_t j = 0; j < inputs.size(); ++j) {
const auto& input_schema = inputs[j]->output_schema();
for (size_t j = 0; j < input_schema.size(); ++j) {
const auto& on_field_ix = indices_of_on_key[j];
const auto& by_field_ix = indices_of_by_key[j];

if ((on_field_ix == -1) || std_has(by_field_ix, -1)) {
return Status::Invalid("Missing join key on table ", j);
}

const auto& on_field = input_schema->fields()[on_field_ix];
const auto& on_field = input_schema[j]->fields()[on_field_ix];
std::vector<const Field*> by_field(n_by);
for (size_t k = 0; k < n_by; k++) {
by_field[k] = input_schema->fields()[by_field_ix[k]].get();
by_field[k] = input_schema[j]->fields()[by_field_ix[k]].get();
}

if (on_key_type == NULLPTR) {
Expand All @@ -992,8 +993,8 @@ class AsofJoinNode : public ExecNode {
}
}

for (int i = 0; i < input_schema->num_fields(); ++i) {
const auto field = input_schema->field(i);
for (int i = 0; i < input_schema[j]->num_fields(); ++i) {
const auto field = input_schema[j]->field(i);
if (i == on_field_ix) {
ARROW_RETURN_NOT_OK(is_valid_on_field(field));
// Only add on field from the left table
Expand Down Expand Up @@ -1030,6 +1031,56 @@ class AsofJoinNode : public ExecNode {
return match.indices()[0];
}

static Result<size_t> GetByKeySize(
const std::vector<asofjoin::AsofJoinKeys>& input_keys) {
size_t n_by = 0;
for (size_t i = 0; i < input_keys.size(); ++i) {
const auto& by_key = input_keys[i].by_key;
if (i == 0) {
n_by = by_key.size();
} else if (n_by != by_key.size()) {
return Status::Invalid("inconsistent size of by-key across inputs");
}
}
return n_by;
}

static Result<std::vector<col_index_t>> GetIndicesOfOnKey(
const std::vector<std::shared_ptr<Schema>>& input_schema,
const std::vector<asofjoin::AsofJoinKeys>& input_keys) {
if (input_schema.size() != input_keys.size()) {
return Status::Invalid("mismatching number of input schema and keys");
}
size_t n_input = input_schema.size();
std::vector<col_index_t> indices_of_on_key(n_input);
for (size_t i = 0; i < n_input; ++i) {
const auto& on_key = input_keys[i].on_key;
ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i],
FindColIndex(*input_schema[i], on_key, "on"));
}
return indices_of_on_key;
}

static Result<std::vector<std::vector<col_index_t>>> GetIndicesOfByKey(
const std::vector<std::shared_ptr<Schema>>& input_schema,
const std::vector<asofjoin::AsofJoinKeys>& input_keys) {
if (input_schema.size() != input_keys.size()) {
return Status::Invalid("mismatching number of input schema and keys");
}
ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(input_keys));
size_t n_input = input_schema.size();
std::vector<std::vector<col_index_t>> indices_of_by_key(
n_input, std::vector<col_index_t>(n_by));
for (size_t i = 0; i < n_input; ++i) {
for (size_t k = 0; k < n_by; k++) {
const auto& by_key = input_keys[i].by_key;
ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k],
FindColIndex(*input_schema[i], by_key[k], "by"));
}
}
return indices_of_by_key;
}

static arrow::Result<ExecNode*> Make(ExecPlan* plan, std::vector<ExecNode*> inputs,
const ExecNodeOptions& options) {
DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs";
Expand All @@ -1040,24 +1091,21 @@ class AsofJoinNode : public ExecNode {
join_options.tolerance);
}

size_t n_input = inputs.size(), n_by = join_options.by_key.size();
ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys));
size_t n_input = inputs.size();
std::vector<std::string> input_labels(n_input);
std::vector<col_index_t> indices_of_on_key(n_input);
std::vector<std::vector<col_index_t>> indices_of_by_key(
n_input, std::vector<col_index_t>(n_by));
std::vector<std::shared_ptr<Schema>> input_schema(n_input);
for (size_t i = 0; i < n_input; ++i) {
input_labels[i] = i == 0 ? "left" : "right_" + std::to_string(i);
const Schema& input_schema = *inputs[i]->output_schema();
ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i],
FindColIndex(input_schema, join_options.on_key, "on"));
for (size_t k = 0; k < n_by; k++) {
ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k],
FindColIndex(input_schema, join_options.by_key[k], "by"));
}
input_schema[i] = inputs[i]->output_schema();
}

ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Schema> output_schema,
MakeOutputSchema(inputs, indices_of_on_key, indices_of_by_key));
ARROW_ASSIGN_OR_RAISE(std::vector<col_index_t> indices_of_on_key,
GetIndicesOfOnKey(input_schema, join_options.input_keys));
ARROW_ASSIGN_OR_RAISE(std::vector<std::vector<col_index_t>> indices_of_by_key,
GetIndicesOfByKey(input_schema, join_options.input_keys));
ARROW_ASSIGN_OR_RAISE(
std::shared_ptr<Schema> output_schema,
MakeOutputSchema(input_schema, indices_of_on_key, indices_of_by_key));

std::vector<std::unique_ptr<KeyHasher>> key_hashers;
for (size_t i = 0; i < n_input; i++) {
Expand Down Expand Up @@ -1173,5 +1221,20 @@ void RegisterAsofJoinNode(ExecFactoryRegistry* registry) {
}
} // namespace internal

namespace asofjoin {

Result<std::shared_ptr<Schema>> MakeOutputSchema(
const std::vector<std::shared_ptr<Schema>>& input_schema,
const std::vector<AsofJoinKeys>& input_keys) {
ARROW_ASSIGN_OR_RAISE(std::vector<col_index_t> indices_of_on_key,
AsofJoinNode::GetIndicesOfOnKey(input_schema, input_keys));
ARROW_ASSIGN_OR_RAISE(std::vector<std::vector<col_index_t>> indices_of_by_key,
AsofJoinNode::GetIndicesOfByKey(input_schema, input_keys));
return AsofJoinNode::MakeOutputSchema(input_schema, indices_of_on_key,
indices_of_by_key);
}

} // namespace asofjoin

} // namespace compute
} // namespace arrow
37 changes: 37 additions & 0 deletions cpp/src/arrow/compute/exec/asof_join_node.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <vector>

#include "arrow/compute/exec.h"
#include "arrow/compute/exec/options.h"
#include "arrow/type.h"
#include "arrow/util/visibility.h"

namespace arrow {
namespace compute {
namespace asofjoin {

using AsofJoinKeys = AsofJoinNodeOptions::Keys;

ARROW_EXPORT Result<std::shared_ptr<Schema>> MakeOutputSchema(
const std::vector<std::shared_ptr<Schema>>& input_schema,
const std::vector<AsofJoinKeys>& input_keys);

} // namespace asofjoin
} // namespace compute
} // namespace arrow
27 changes: 18 additions & 9 deletions cpp/src/arrow/compute/exec/asof_join_node_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ void BuildZeroBaseBinaryArray(std::shared_ptr<Array>& empty, int64_t length) {
ASSERT_OK(builder.Finish(&empty));
}

AsofJoinNodeOptions GetRepeatedOptions(size_t repeat, FieldRef on_key,
std::vector<FieldRef> by_key, int64_t tolerance) {
std::vector<AsofJoinNodeOptions::Keys> input_keys(repeat);
for (size_t i = 0; i < repeat; i++) {
input_keys[i] = {on_key, by_key};
}
return AsofJoinNodeOptions(input_keys, tolerance);
}

// mutates by copying from_key into to_key and changing from_key to zero
Result<BatchesWithSchema> MutateByKey(BatchesWithSchema& batches, std::string from_key,
std::string to_key, bool replace_key = false,
Expand Down Expand Up @@ -246,7 +255,7 @@ void CheckRunOutput(const BatchesWithSchema& l_batches,
const BatchesWithSchema& r1_batches, const BatchesWithSchema& exp_batches, \
const FieldRef time, by_key_type key, const int64_t tolerance) { \
CheckRunOutput(l_batches, r0_batches, r1_batches, exp_batches, \
AsofJoinNodeOptions(time, {key}, tolerance)); \
GetRepeatedOptions(3, time, {key}, tolerance)); \
}

EXPAND_BY_KEY_TYPE(CHECK_RUN_OUTPUT)
Expand Down Expand Up @@ -298,7 +307,7 @@ void DoRunInvalidPlanTest(const std::shared_ptr<Schema>& l_schema,
const std::shared_ptr<Schema>& r_schema, int64_t tolerance,
const std::string& expected_error_str) {
DoRunInvalidPlanTest(l_schema, r_schema,
AsofJoinNodeOptions("time", {"key"}, tolerance),
GetRepeatedOptions(2, "time", {"key"}, tolerance),
expected_error_str);
}

Expand All @@ -321,27 +330,27 @@ void DoRunMissingKeysTest(const std::shared_ptr<Schema>& l_schema,
void DoRunMissingOnKeyTest(const std::shared_ptr<Schema>& l_schema,
const std::shared_ptr<Schema>& r_schema) {
DoRunInvalidPlanTest(l_schema, r_schema,
AsofJoinNodeOptions("invalid_time", {"key"}, 0),
GetRepeatedOptions(2, "invalid_time", {"key"}, 0),
"Bad join key on table : No match");
}

void DoRunMissingByKeyTest(const std::shared_ptr<Schema>& l_schema,
const std::shared_ptr<Schema>& r_schema) {
DoRunInvalidPlanTest(l_schema, r_schema,
AsofJoinNodeOptions("time", {"invalid_key"}, 0),
GetRepeatedOptions(2, "time", {"invalid_key"}, 0),
"Bad join key on table : No match");
}

void DoRunNestedOnKeyTest(const std::shared_ptr<Schema>& l_schema,
const std::shared_ptr<Schema>& r_schema) {
DoRunInvalidPlanTest(l_schema, r_schema, AsofJoinNodeOptions({0, "time"}, {"key"}, 0),
DoRunInvalidPlanTest(l_schema, r_schema, GetRepeatedOptions(2, {0, "time"}, {"key"}, 0),
"Bad join key on table : No match");
}

void DoRunNestedByKeyTest(const std::shared_ptr<Schema>& l_schema,
const std::shared_ptr<Schema>& r_schema) {
DoRunInvalidPlanTest(l_schema, r_schema,
AsofJoinNodeOptions("time", {FieldRef{0, 1}}, 0),
GetRepeatedOptions(2, "time", {FieldRef{0, 1}}, 0),
"Bad join key on table : No match");
}

Expand Down Expand Up @@ -402,7 +411,7 @@ void DoRunUnorderedPlanTest(bool l_unordered, bool r_unordered,
const std::shared_ptr<Schema>& l_schema,
const std::shared_ptr<Schema>& r_schema) {
DoRunUnorderedPlanTest(l_unordered, r_unordered, l_schema, r_schema,
AsofJoinNodeOptions("time", {"key"}, 1000),
GetRepeatedOptions(2, "time", {"key"}, 1000),
"out-of-order on-key values");
}

Expand Down Expand Up @@ -499,7 +508,7 @@ struct BasicTest {
ASSERT_OK_AND_ASSIGN(exp_nokey_batches,
MutateByKey(exp_nokey_batches, "key", "key2", true, true));
CheckRunOutput(l_batches, r0_batches, r1_batches, exp_nokey_batches,
AsofJoinNodeOptions("time", {"key2"}, tolerance));
GetRepeatedOptions(3, "time", {"key2"}, tolerance));
});
}
static void DoMutateNullKey(BasicTest& basic_tests) { basic_tests.RunMutateNullKey(); }
Expand All @@ -512,7 +521,7 @@ struct BasicTest {
ASSERT_OK_AND_ASSIGN(r1_batches,
MutateByKey(r1_batches, "key", "key", false, false, true));
CheckRunOutput(l_batches, r0_batches, r1_batches, exp_emptykey_batches,
AsofJoinNodeOptions("time", {}, tolerance));
GetRepeatedOptions(3, "time", {}, tolerance));
});
}
static void DoMutateEmptyKey(BasicTest& basic_tests) {
Expand Down
Loading

0 comments on commit 25d83a3

Please sign in to comment.