Skip to content

Commit

Permalink
feat: support also filter expression
Browse files Browse the repository at this point in the history
  • Loading branch information
davisusanibar committed May 30, 2023
2 parents 3e540fe + f2b0f8a commit 4099ca6
Show file tree
Hide file tree
Showing 12 changed files with 332 additions and 39 deletions.
36 changes: 30 additions & 6 deletions cpp/src/arrow/engine/substrait/expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1167,12 +1167,36 @@ Result<std::unique_ptr<substrait::Expression>> ToProto(
}

// other expression types dive into extensions immediately
ARROW_ASSIGN_OR_RAISE(
ExtensionIdRegistry::ArrowToSubstraitCall converter,
ext_set->registry()->GetArrowToSubstraitCall(call->function_name));
ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression::ScalarFunction> scalar_fn,
EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
Result<ExtensionIdRegistry::ArrowToSubstraitCall> maybe_converter =
ext_set->registry()->GetArrowToSubstraitCall(call->function_name);

ExtensionIdRegistry::ArrowToSubstraitCall converter;
std::unique_ptr<substrait::Expression::ScalarFunction> scalar_fn;
if (maybe_converter.ok()) {
converter = *maybe_converter;
ARROW_ASSIGN_OR_RAISE(SubstraitCall substrait_call, converter(*call));
ARROW_ASSIGN_OR_RAISE(
scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
} else if (maybe_converter.status().IsNotImplemented() &&
conversion_options.allow_arrow_extensions) {
if (call->options) {
return Status::NotImplemented(
"The function ", call->function_name,
" has no Substrait mapping. Arrow extensions are enabled but the call "
"contains function options and there is no current mechanism to encode those.");
}
SubstraitCall substrait_call(
Id{kArrowSimpleExtensionFunctionsUri, call->function_name},
call->type.GetSharedPtr(),
/*nullable=*/true);
for (int i = 0; i < static_cast<int>(call->arguments.size()); i++) {
substrait_call.SetValueArg(i, call->arguments[i]);
}
ARROW_ASSIGN_OR_RAISE(
scalar_fn, EncodeSubstraitCall(substrait_call, ext_set, conversion_options));
} else {
return maybe_converter.status();
}
out->set_allocated_scalar_function(scalar_fn.release());
return std::move(out);
}
Expand Down
35 changes: 21 additions & 14 deletions cpp/src/arrow/engine/substrait/extended_expression_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -109,25 +109,25 @@ Result<NamedExpression> ExpressionFromProto(
output_type.ToString(), " which doesn't have enough fields");
}
if (expression.output_names_size() == 0) {
// This is potentially invalid substrait but we can handle it
named_expr.name = "";
} else {
named_expr.name = expression.output_names(expression.output_names_size() - 1);
}
named_expr.name = expression.output_names(0);
return named_expr;
}

/// \brief Encode one named Arrow expression as a Substrait ExpressionReference.
///
/// The reference's output_names are populated with the name of every nested field
/// reported by VisitNestedFields for the expression's type, followed by `name` as
/// the final entry. The expression itself is encoded with ToProto.
///
/// \param[in] name the name appended as the last output name
/// \param[in] expr the expression to encode; `expr.type()` is dereferenced here, so
///   presumably the expression must already be bound (NOTE(review): confirm callers
///   always bind before calling)
/// \param[in,out] ext_set extension set forwarded to ToProto
/// \param[in] conversion_options options forwarded to ToProto
/// \return the populated ExpressionReference, or the first error raised while
///   visiting the nested fields or encoding the expression
Result<std::unique_ptr<substrait::ExpressionReference>> CreateExpressionReference(
    const std::string& name, const Expression& expr, ExtensionSet* ext_set,
    const ConversionOptions& conversion_options) {
  auto expr_ref = std::make_unique<substrait::ExpressionReference>();
  // Nested field names are emitted first; the expression's own name goes last.
  ARROW_RETURN_NOT_OK(VisitNestedFields(*expr.type(), [&](const Field& field) {
    expr_ref->add_output_names(field.name());
    return Status::OK();
  }));
  expr_ref->add_output_names(name);
  ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::Expression> expression,
                        ToProto(expr, ext_set, conversion_options));
  // The protobuf message takes ownership of the released pointer.
  expr_ref->set_allocated_expression(expression.release());
  return std::move(expr_ref);
}
Expand Down Expand Up @@ -176,9 +176,16 @@ Result<std::unique_ptr<substrait::ExtendedExpression>> ToProto(
ToProto(*bound_expressions.schema, ext_set, conversion_options));
expression->set_allocated_base_schema(base_schema.release());
for (const auto& named_expression : bound_expressions.named_expressions) {
ARROW_ASSIGN_OR_RAISE(
std::unique_ptr<substrait::ExpressionReference> expr_ref,
CreateExpressionReference(named_expression, ext_set, conversion_options));
Expression bound_expr = named_expression.expression;
if (!bound_expr.IsBound()) {
// This will use the default function registry. Most of the time that will be fine.
// In the cases where this is not what the user wants then the user should make sure
// to pass in bound expressions.
ARROW_ASSIGN_OR_RAISE(bound_expr, bound_expr.Bind(*bound_expressions.schema));
}
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<substrait::ExpressionReference> expr_ref,
CreateExpressionReference(named_expression.name, bound_expr,
ext_set, conversion_options));
expression->mutable_referred_expr()->AddAllocated(expr_ref.release());
}
RETURN_NOT_OK(AddExtensionSetToExtendedExpression(*ext_set, expression.get()));
Expand Down
8 changes: 7 additions & 1 deletion cpp/src/arrow/engine/substrait/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,8 @@ struct ARROW_ENGINE_EXPORT ConversionOptions {
: strictness(ConversionStrictness::BEST_EFFORT),
named_table_provider(kDefaultNamedTableProvider),
named_tap_provider(default_named_tap_provider()),
extension_provider(default_extension_provider()) {}
extension_provider(default_extension_provider()),
allow_arrow_extensions(true) {}

/// \brief How strictly the converter should adhere to the structure of the input.
ConversionStrictness strictness;
Expand All @@ -123,6 +124,11 @@ struct ARROW_ENGINE_EXPORT ConversionOptions {
///
/// The default behavior will provide for relations known to Arrow.
std::shared_ptr<ExtensionProvider> extension_provider;
/// \brief If true, when serializing, Arrow-specific types and functions will be allowed
///
/// Set to false to create plans that are more likely to be compatible with non-Arrow
/// engines
bool allow_arrow_extensions;
};

} // namespace engine
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/engine/substrait/relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "arrow/compute/api_aggregate.h"
//#include "arrow/compute/exec/exec_plan.h"
#include "arrow/acero/exec_plan.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/engine/substrait/visibility.h"
#include "arrow/type_fwd.h"

Expand Down
3 changes: 0 additions & 3 deletions cpp/src/arrow/engine/substrait/serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,6 @@ DeclarationFactory MakeWriteDeclarationFactory(
};
}

constexpr uint32_t kMinimumMajorVersion = 0;
constexpr uint32_t kMinimumMinorVersion = 20;

Result<std::vector<acero::Declaration>> DeserializePlans(
const Buffer& buf, DeclarationFactory declaration_factory,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out,
Expand Down
70 changes: 69 additions & 1 deletion cpp/src/arrow/engine/substrait/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6079,9 +6079,77 @@ void CheckExpressionRoundTrip(const Schema& schema,

// Round-trips several expressions through the extended-expression serialization
// using CheckExpressionRoundTrip.
TEST(Substrait, ExtendedExpressionSerialization) {
  // Four top-level columns. The last one is a struct so that nested references
  // (field paths {3, 0} and {3, 1}) and a struct-typed result can be exercised.
  std::shared_ptr<Schema> test_schema =
      schema({field("a", int32()), field("b", int32()), field("c", float32()),
              field("nested", struct_({field("x", float32()), field("y", float32())}))});
  // Basic a + b
  CheckExpressionRoundTrip(
      *test_schema, compute::call("add", {compute::field_ref(0), compute::field_ref(1)}));
  // Nested struct reference
  CheckExpressionRoundTrip(*test_schema, compute::field_ref(FieldPath{3, 0}));
  // Struct return type
  CheckExpressionRoundTrip(*test_schema, compute::field_ref(3));
  // c + nested.y
  CheckExpressionRoundTrip(
      *test_schema,
      compute::call("add", {compute::field_ref(2), compute::field_ref(FieldPath{3, 1})}));
}

// Deserializing an extended expression must fail when its output names do not
// line up with the field names of the type it produces.
TEST(Substrait, ExtendedExpressionInvalidPlans) {
  // The expression selects root field 3, which the base schema declares as the
  // struct "nested" with children "x" and "y". The output names, however, list
  // "a" and "y" ahead of the expression's own name "some_name".
  constexpr std::string_view kMismatchedOutputNames = R"(
{
"referredExpr":[
{
"expression":{
"selection":{
"directReference":{
"structField":{
"field":3
}
},
"rootReference":{}
}
},
"outputNames":["a", "y", "some_name"]
}
],
"baseSchema":{
"names":["a","b","c","nested","x","y"],
"struct":{
"types":[
{
"i32":{"nullability":"NULLABILITY_NULLABLE"}
},
{
"i32":{"nullability":"NULLABILITY_NULLABLE"}
},
{
"fp32":{"nullability":"NULLABILITY_NULLABLE"}
},
{
"struct":{
"types":[
{
"fp32":{"nullability":"NULLABILITY_NULLABLE"}
},
{
"fp32":{"nullability":"NULLABILITY_NULLABLE"}
}
],
"nullability":"NULLABILITY_NULLABLE"
}
}
]
}
},
"version":{"majorNumber":9999}
}
)";

  const auto plan_buffer = std::make_shared<Buffer>(kMismatchedOutputNames);

  // The failure is expected to surface as an Invalid status mentioning "Ambiguous plan".
  ASSERT_THAT(DeserializeExpressions(*plan_buffer),
              Raises(StatusCode::Invalid, testing::HasSubstr("Ambiguous plan")));
}

} // namespace engine
Expand Down
43 changes: 30 additions & 13 deletions java/dataset/src/main/cpp/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,10 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
* Signature: (J[Ljava/lang/String;JJ)J
*/
JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columnsSubset, jobject columnsToProduce, jlong batch_size,
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns_subset, jobject columns_to_produce_or_filter, jlong batch_size,
jlong memory_pool_id) {
JNI_METHOD_START
std::cout << "Inicio createScanner" << std::endl;
arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
if (pool == nullptr) {
JniThrow("Memory pool does not exist or has been closed");
Expand All @@ -466,27 +467,43 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann
std::shared_ptr<arrow::dataset::ScannerBuilder> scanner_builder =
JniGetOrThrow(dataset->NewScan());
JniAssertOkOrThrow(scanner_builder->Pool(pool));
if (columnsSubset != nullptr) {
std::vector<std::string> column_vector = ToStringVector(env, columnsSubset);
if (columns_subset != nullptr) {
std::vector<std::string> column_vector = ToStringVector(env, columns_subset);
JniAssertOkOrThrow(scanner_builder->Project(column_vector));
}
if (columnsToProduce != nullptr) {
auto *buff = reinterpret_cast<jbyte*>(env->GetDirectBufferAddress(columnsToProduce));
int length = env->GetDirectBufferCapacity(columnsToProduce);
if (columns_to_produce_or_filter != nullptr) {
std::cout << "Inicio columns_to_produce_or_filter" << std::endl;
auto *buff = reinterpret_cast<jbyte*>(env->GetDirectBufferAddress(columns_to_produce_or_filter));
int length = env->GetDirectBufferCapacity(columns_to_produce_or_filter);
std::shared_ptr<arrow::Buffer> buffer = JniGetOrThrow(arrow::AllocateBuffer(length));
std::memcpy(buffer->mutable_data(), buff, length);
// execute expression
arrow::engine::BoundExpressions round_tripped =
std::cout << "Call DeserializeExpressions" << std::endl;
arrow::engine::BoundExpressions bounded_expression =
JniGetOrThrow(arrow::engine::DeserializeExpressions(*buffer));
// validate result
// create exprs / names
std::vector<arrow::compute::Expression> exprs;
std::vector<std::string> names;
for(arrow::engine::NamedExpression named_expression : round_tripped.named_expressions) {
exprs.push_back(named_expression.expression);
names.push_back(named_expression.name);
std::vector<arrow::compute::Expression> project_exprs;
std::vector<std::string> project_names;
arrow::compute::Expression filter_expr;
int filter_count = 0;
std::cout << "Iterate bounded_expression.named_expressions" << std::endl;
for(arrow::engine::NamedExpression named_expression : bounded_expression.named_expressions) {
if (named_expression.expression.type()->id() == arrow::Type::BOOL) {
std::cout << "Filter: " + named_expression.expression.ToString() << std::endl;
if (filter_count > 1) {
std::cout << "Error! Only one filter expression is supported" << std::endl;
}
filter_expr = named_expression.expression;
filter_count++;
} else {
std::cout << "Project: " + named_expression.expression.ToString() << std::endl;
project_exprs.push_back(named_expression.expression);
project_names.push_back(named_expression.name);
}
}
JniAssertOkOrThrow(scanner_builder->Project(exprs, names));
JniAssertOkOrThrow(scanner_builder->Project(project_exprs, project_names));
JniAssertOkOrThrow(scanner_builder->Filter(filter_expr));
}
JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,8 @@ public void testDeserializeExtendedExpressions() {
// Extended Expression 01: n_nationkey + 7
// Extended Expression 02: n_nationkey > 23
// OK generado por POJO: String binaryExtendedExpressions = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwKHggCEhovZnVuY3Rpb25zX2NvbXBhcmlzb24ueWFtbBIRGg8IARoLYWRkOmkzMl9pMzISEhoQCAIQARoKZ3Q6YW55X2FueRopChoaGBoEKgIQASIIGgYSBAoCEgAiBhoECgIoAhoLcHJvamVjdF9vbmUaKwocGhoIARoEKgIQASIIGgYSBAoCEgAiBhoECgIoChoLcHJvamVjdF90d28iGgoCSUQKBE5BTUUSDgoEKgIQAQoEYgIQARgC";
String binaryExtendedExpressions = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwKHggCEhovZnVuY3Rpb25zX2NvbXBhcmlzb24ueWFtbBIRGg8IARoLYWRkOmkzMl9pMzISEhoQCAIQARoKZ3Q6YW55X2FueRISGhAIAhACGgpsdDphbnlfYW55GikKGhoYGgQqAhABIggaBhIECgISACIGGgQKAigCGgtwcm9qZWN0X29uZRorChwaGggBGgQqAhABIggaBhIECgISACIGGgQKAigKGgtwcm9qZWN0X3R3bxoqChwaGggCGgQKAhABIggaBhIECgISACIGGgQKAigUGgpmaWx0ZXJfb25lIhoKAklECgROQU1FEg4KBCoCEAEKBGICEAEYAg==";
// String binaryExtendedExpressions = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwKHggCEhovZnVuY3Rpb25zX2NvbXBhcmlzb24ueWFtbBIRGg8IARoLYWRkOmkzMl9pMzISEhoQCAIQARoKZ3Q6YW55X2FueRISGhAIAhACGgpsdDphbnlfYW55GikKGhoYGgQqAhABIggaBhIECgISACIGGgQKAigCGgtwcm9qZWN0X29uZRorChwaGggBGgQqAhABIggaBhIECgISACIGGgQKAigKGgtwcm9qZWN0X3R3bxoqChwaGggCGgQKAhABIggaBhIECgISACIGGgQKAigUGgpmaWx0ZXJfb25lIhoKAklECgROQU1FEg4KBCoCEAEKBGICEAEYAg==";
String binaryExtendedExpressions = "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwKHggCEhovZnVuY3Rpb25zX2NvbXBhcmlzb24ueWFtbBIRGg8IARoLYWRkOmkzMl9pMzISFBoSCAIQARoMY29uY2F0OnZjaGFyEhIaEAgCEAIaCmx0OmFueV9hbnkaMQoaGhgaBCoCEAEiCBoGEgQKAhIAIgYaBAoCKAIaE2FkZF90d29fdG9fY29sdW1uX2EaOwoiGiAIARoEYgIQASIKGggSBgoEEgIIASIKGggSBgoEEgIIAhoVY29uY2F0X2NvbHVtbl9hX2FuZF9iGioKHBoaCAIaBAoCEAEiCBoGEgQKAhIAIgYaBAoCKBQaCmZpbHRlcl9vbmUiKgoCSUQKBE5BTUUKCExBU1ROQU1FEhQKBCoCEAEKBGICEAEKBGICEAEYAg==";
// String binaryExtendedExpressions =
// "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwKHggCEhovZnVuY3Rpb25zX2NvbXBhcmlzb24ueWFtbBIRGg8IA" +
// "RoLYWRkOmkzMl9pMzISEhoQCAIQARoKZ3Q6YW55X2FueRooChwaGhoEKgIQAiIKGggSBgoCEgAiACIGGgQKAigCGg" +
Expand Down Expand Up @@ -271,4 +272,45 @@ public void testBaseParquetReadWithExtendedExpressions() throws Exception {
}
}
}

/**
 * Scans a Parquet file using a serialized Substrait extended expression that carries both
 * projection and filter expressions.
 *
 * <p>The scan is expected to fail with a {@link RuntimeException}. According to the notes
 * recorded with this test, expressions parsed with positional references such as
 * {@code add(FieldPath(0), 2)} fail with "Inferring column projection from FieldRef
 * FieldRef.FieldPath(0)", while the same expressions with named references such as
 * {@code add(FieldPath("id"), 2)} succeed.
 */
@Test(expected = RuntimeException.class)
public void testBaseParquetReadWithExtendedExpressionsProjectAndFilter() throws Exception {
  // Extended Expression 01: id + 2
  // Extended Expression 02: id > 10
  // NOTE(review): the expression names embedded in the serialized plan below appear to be
  // add_two_to_column_a, concat_column_a_and_b and filter_one; confirm that the two-line
  // summary above still describes the plan.

  // Column layout the plan is written against. NOTE(review): this local is not referenced
  // by the rest of the test; it is kept as a description of the expected base schema.
  final Schema schema = new Schema(Arrays.asList(
      Field.nullable("ID", new ArrowType.Int(32, true)),
      Field.nullable("NAME", new ArrowType.Utf8())
  ), Collections.emptyMap());

  // Produced with Base64.getEncoder().encodeToString(plan.toByteArray()).
  String binaryExtendedExpressions =
      "Ch4IARIaL2Z1bmN0aW9uc19hcml0aG1ldGljLnlhbWwKHggCEhovZnVuY3Rpb25zX2NvbXBhcmlzb24ueWFtbBIRGg8IARoLYWRkOmkzMl9pMzISFBoSCAIQARoMY29uY2F0OnZjaGFyEhIaEAgCEAIaCmx0OmFueV9hbnkaMQoaGhgaBCoCEAEiCBoGEgQKAhIAIgYaBAoCKAIaE2FkZF90d29fdG9fY29sdW1uX2EaOwoiGiAIARoEYgIQASIKGggSBgoEEgIIASIKGggSBgoEEgIIARoVY29uY2F0X2NvbHVtbl9hX2FuZF9iGioKHBoaCAIaBAoCEAEiCBoGEgQKAhIAIgYaBAoCKBQaCmZpbHRlcl9vbmUiGgoCSUQKBE5BTUUSDgoEKgIQAQoEYgIQARgC";

  // Decode the plan and copy it into a direct buffer, which is what ScanOptions is given.
  byte[] extendedExpressions = Base64.getDecoder().decode(binaryExtendedExpressions);
  ByteBuffer substraitExtendedExpressions = ByteBuffer.allocateDirect(extendedExpressions.length);
  substraitExtendedExpressions.put(extendedExpressions);

  ParquetWriteSupport writeSupport = ParquetWriteSupport
      .writeTempFile(AVRO_SCHEMA_USER, TMP.newFolder(), 1, "a", 11, "b", 21, "c");
  ScanOptions options = new ScanOptions(/*batchSize*/ 32768, Optional.empty(),
      Optional.of(substraitExtendedExpressions));
  try (
      DatasetFactory datasetFactory = new FileSystemDatasetFactory(rootAllocator(), NativeMemoryPool.getDefault(),
          FileFormat.PARQUET, writeSupport.getOutputURI());
      Dataset dataset = datasetFactory.finish();
      Scanner scanner = dataset.newScan(options);
      ArrowReader reader = scanner.scanBatches()
  ) {
    while (reader.loadNextBatch()) {
      try (VectorSchemaRoot root = reader.getVectorSchemaRoot()) {
        System.out.print(root.contentToTSVString());
      }
    }
  }
}
}
Loading

0 comments on commit 4099ca6

Please sign in to comment.