From ebb73ad83465a33538d45d6692feffaee0e793a6 Mon Sep 17 00:00:00 2001 From: praveenbingo Date: Sat, 16 Jun 2018 11:41:29 +0530 Subject: [PATCH] GDV-55: [C++] Added validation to projector build. (#33) Validating the input schema and expressions during the projector build. --- src/gandiva/src/cpp/include/gandiva/status.h | 25 +- .../cpp/include/gandiva/tree_expr_builder.h | 7 +- src/gandiva/src/cpp/integ/CMakeLists.txt | 1 + .../integ/projector_build_validation_test.cc | 260 ++++++++++++++++++ src/gandiva/src/cpp/integ/projector_test.cc | 1 - .../src/cpp/src/codegen/CMakeLists.txt | 1 + .../src/cpp/src/codegen/expr_decomposer.cc | 12 +- .../src/cpp/src/codegen/expr_decomposer.h | 8 +- .../src/cpp/src/codegen/expr_validator.cc | 132 +++++++++ .../src/cpp/src/codegen/expr_validator.h | 74 +++++ .../src/cpp/src/codegen/llvm_generator.h | 3 +- src/gandiva/src/cpp/src/codegen/node.h | 19 +- .../src/cpp/src/codegen/node_visitor.h | 9 +- src/gandiva/src/cpp/src/codegen/projector.cc | 23 +- src/gandiva/src/cpp/src/codegen/status.cc | 3 + .../src/cpp/src/codegen/tree_expr_builder.cc | 14 +- 16 files changed, 561 insertions(+), 31 deletions(-) create mode 100644 src/gandiva/src/cpp/integ/projector_build_validation_test.cc create mode 100644 src/gandiva/src/cpp/src/codegen/expr_validator.cc create mode 100644 src/gandiva/src/cpp/src/codegen/expr_validator.h diff --git a/src/gandiva/src/cpp/include/gandiva/status.h b/src/gandiva/src/cpp/include/gandiva/status.h index bfb76cb2425b3..64e56300b3ae0 100644 --- a/src/gandiva/src/cpp/include/gandiva/status.h +++ b/src/gandiva/src/cpp/include/gandiva/status.h @@ -19,6 +19,7 @@ #define GANDIVA_STATUS_H #include +#include #include #define GANDIVA_RETURN_NOT_OK(status) \ @@ -26,7 +27,8 @@ Status _status = (status); \ if (!_status.ok()) { \ std::stringstream ss; \ - ss << __FILE__ << ":" << __LINE__ << " code: " << #status << "\n" << _status.message(); \ + ss << __FILE__ << ":" << __LINE__ << " code: " << _status.CodeAsString() \ + << " \n " << _status.message(); \ return Status(_status.code(), ss.str()); \ } \ } while (0) @@ -36,7 +38,8 @@ do { if (!condition) { \ Status _status = (status); \ std::stringstream ss; \ - ss << __FILE__ << ":" << __LINE__ << " code: " << #status << "\n" << _status.message(); \ + ss << __FILE__ << ":" << __LINE__ << " code: " << _status.CodeAsString() \ + << " \n " << _status.message(); \ return Status(_status.code(), ss.str()); \ } \ } while (0) @@ -56,6 +59,7 @@ enum class StatusCode : char { Invalid = 1, CodeGenError = 2, ArrowError = 3, + ExpressionValidationError = 4, }; class Status { @@ -92,11 +96,26 @@ class Status { return Status(StatusCode::Invalid, msg); } + static Status ArrowError(const std::string& msg) { + return Status(StatusCode::ArrowError, msg); + } + + static Status ExpressionValidationError(const std::string& msg) { + return Status(StatusCode::ExpressionValidationError, msg); + } + + // Returns true if the status indicates success. bool ok() const { return (state_ == NULL); } bool IsCodeGenError() const { return code() == StatusCode::CodeGenError; } + bool IsInvalid() const { return code() == StatusCode::Invalid; } + + bool IsArrowError() const {return code() == StatusCode::ArrowError; } + + bool IsExpressionValidationError() const {return code() == StatusCode::ExpressionValidationError; } + // Return a string representation of this status suitable for printing. // Returns the string "OK" for success. std::string ToString() const; @@ -177,4 +196,4 @@ inline Status& Status::operator&=(Status&& s) { } } // namespace gandiva -#endif // GANDIVA_STATUS_H +#endif // GANDIVA_STATUS_H \ No newline at end of file diff --git a/src/gandiva/src/cpp/include/gandiva/tree_expr_builder.h b/src/gandiva/src/cpp/include/gandiva/tree_expr_builder.h index b6a4aa59820fb..8be241114da17 100644 --- a/src/gandiva/src/cpp/include/gandiva/tree_expr_builder.h +++ b/src/gandiva/src/cpp/include/gandiva/tree_expr_builder.h @@ -34,14 +34,17 @@ class TreeExprBuilder { static NodePtr MakeLiteral(double value); /// \brief create a node on arrow field. + /// returns null if input is null. static NodePtr MakeField(FieldPtr field); /// \brief create a node with a function. + /// returns null if return_type is null static NodePtr MakeFunction(const std::string &name, const NodeVector &children, DataTypePtr return_type); - /// \brief Create a node with an if-else expression. + /// \brief create a node with an if-else expression. + /// returns null if any of the inputs is null. static NodePtr MakeIf(NodePtr condition, NodePtr this_node, NodePtr else_node, @@ -49,10 +52,12 @@ class TreeExprBuilder { /// \brief create an expression with the specified root_node, and the /// result written to result_field. + /// returns null if the result_field is null. static ExpressionPtr MakeExpression(NodePtr root_node, FieldPtr result_field); /// \brief convenience function for simple function expressions. + /// returns null if the out_field is null. static ExpressionPtr MakeExpression(const std::string &function, const FieldVector &in_fields, FieldPtr out_field); diff --git a/src/gandiva/src/cpp/integ/CMakeLists.txt b/src/gandiva/src/cpp/integ/CMakeLists.txt index 0e54804eff085..2dfce0ebc0173 100644 --- a/src/gandiva/src/cpp/integ/CMakeLists.txt +++ b/src/gandiva/src/cpp/integ/CMakeLists.txt @@ -17,3 +17,4 @@ project(gandiva) add_gandiva_integ_test(projector_test.cc) add_gandiva_integ_test(if_expr_test.cc) add_gandiva_integ_test(literal_test.cc) +add_gandiva_integ_test(projector_build_validation_test.cc) diff --git a/src/gandiva/src/cpp/integ/projector_build_validation_test.cc b/src/gandiva/src/cpp/integ/projector_build_validation_test.cc new file mode 100644 index 0000000000000..0a802cafdb769 --- /dev/null +++ b/src/gandiva/src/cpp/integ/projector_build_validation_test.cc @@ -0,0 +1,260 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "arrow/memory_pool.h" +#include "integ/test_util.h" +#include "gandiva/projector.h" +#include "gandiva/tree_expr_builder.h" + +namespace gandiva { + +using arrow::int32; +using arrow::float32; +using arrow::boolean; + +class TestProjector : public ::testing::Test { + public: + void SetUp() { pool_ = arrow::default_memory_pool(); } + + protected: + arrow::MemoryPool* pool_; +}; + +TEST_F(TestProjector, TestNonExistentFunction) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto schema = arrow::schema({field0, field1}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = TreeExprBuilder::MakeExpression("non_existent_function", + {field0, field1}, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Function bool non_existent_function(float, float) not supported yet."; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestNotMatchingDataType) { + // schema for input fields + auto field0 = field("f0", float32()); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto lt_expr = TreeExprBuilder::MakeExpression(node_f0, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Return type of root node float does not match that of expression bool"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); + } + +TEST_F(TestProjector, TestNotSupportedDataType) { + // schema for input fields + auto field0 = field("f0", list(int32())); + auto schema = arrow::schema({field0}); + + // output fields + auto field_result = field("res", list(int32())); + + // Build expression + auto node_f0 = TreeExprBuilder::MakeField(field0); + auto lt_expr = TreeExprBuilder::MakeExpression(node_f0, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Field f0 has unsupported data type list"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestIncorrectSchemaMissingField) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto schema = arrow::schema({field0, field0}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = TreeExprBuilder::MakeExpression("less_than", + {field0, field1}, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = "Field f2 not in schema"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestIncorrectSchemaTypeNotMatching) { + // schema for input fields + auto field0 = field("f0", float32()); + auto field1 = field("f2", float32()); + auto field2 = field("f2", int32()); + auto schema = arrow::schema({field0, field2}); + + // output fields + auto field_result = field("res", boolean()); + + // Build expression + auto lt_expr = TreeExprBuilder::MakeExpression("less_than", + {field0, field1}, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {lt_expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::cout< b) + // a + // else + // b + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = TreeExprBuilder::MakeFunction("non_existent_function", + {node_a, node_b}, + boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); +} + +TEST_F(TestProjector, TestIfNotMatchingReturnType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto condition = TreeExprBuilder::MakeFunction("less_than", + {node_a, node_b}, + boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_b, boolean()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Return type of if bool and then int32 not matching."; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestElseNotMatchingReturnType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", boolean()); + auto schema = arrow::schema({fielda, fieldb, fieldc}); + + // output fields + auto field_result = field("res", int32()); + + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto condition = TreeExprBuilder::MakeFunction("less_than", + {node_a, node_b}, + boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_c, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Return type of if int32 and else bool not matching."; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +TEST_F(TestProjector, TestElseNotSupportedType) { + // schema for input fields + auto fielda = field("a", int32()); + auto fieldb = field("b", int32()); + auto fieldc = field("c", list(int32())); + auto schema = arrow::schema({fielda, fieldb}); + + // output fields + auto field_result = field("res", int32()); + + + auto node_a = TreeExprBuilder::MakeField(fielda); + auto node_b = TreeExprBuilder::MakeField(fieldb); + auto node_c = TreeExprBuilder::MakeField(fieldc); + auto condition = TreeExprBuilder::MakeFunction("less_than", + {node_a, node_b}, + boolean()); + auto if_node = TreeExprBuilder::MakeIf(condition, node_a, node_c, int32()); + + auto expr = TreeExprBuilder::MakeExpression(if_node, field_result); + + // Build a projector for the expressions. + std::shared_ptr projector; + Status status = Projector::Make(schema, {expr}, pool_, &projector); + EXPECT_TRUE(status.IsExpressionValidationError()); + std::string expected_error = + "Field c has unsupported data type list"; + EXPECT_TRUE(status.message().find(expected_error) != std::string::npos); +} + +} // namespace gandiva diff --git a/src/gandiva/src/cpp/integ/projector_test.cc b/src/gandiva/src/cpp/integ/projector_test.cc index c95e0cab88550..e53b7d3e6ba17 100644 --- a/src/gandiva/src/cpp/integ/projector_test.cc +++ b/src/gandiva/src/cpp/integ/projector_test.cc @@ -433,5 +433,4 @@ TEST_F(TestProjector, TestZeroCopyNegative) { status = projector->Evaluate(*in_batch, {bad_array_data3}); EXPECT_EQ(status.code(), StatusCode::Invalid); } - } // namespace gandiva diff --git a/src/gandiva/src/cpp/src/codegen/CMakeLists.txt b/src/gandiva/src/cpp/src/codegen/CMakeLists.txt index 213d0e4291899..a0ae1548ed00c 100644 --- a/src/gandiva/src/cpp/src/codegen/CMakeLists.txt +++ b/src/gandiva/src/cpp/src/codegen/CMakeLists.txt @@ -32,6 +32,7 @@ add_library(gandiva SHARED projector.cc status.cc tree_expr_builder.cc + expr_validator.cc ${BC_FILE_PATH_CC}) # For users of gandiva library (including integ tests), include-dir is : diff --git a/src/gandiva/src/cpp/src/codegen/expr_decomposer.cc b/src/gandiva/src/cpp/src/codegen/expr_decomposer.cc index 527a0da67a2c8..7e7fc2023df12 100644 --- a/src/gandiva/src/cpp/src/codegen/expr_decomposer.cc +++ b/src/gandiva/src/cpp/src/codegen/expr_decomposer.cc @@ -28,17 +28,18 @@ namespace gandiva { // Decompose a field node - simply seperate out validity & value arrays. -void ExprDecomposer::Visit(const FieldNode &node) { +Status ExprDecomposer::Visit(const FieldNode &node) { auto desc = annotator_.CheckAndAddInputFieldDescriptor(node.field()); DexPtr validity_dex = std::make_shared(desc); DexPtr value_dex = std::make_shared(desc); result_ = std::make_shared(validity_dex, value_dex); + return Status::OK(); } // Decompose a field node - wherever possible, merge the validity vectors of the // child nodes. -void ExprDecomposer::Visit(const FunctionNode &node) { +Status ExprDecomposer::Visit(const FunctionNode &node) { auto desc = node.descriptor(); FunctionSignature signature(desc->name(), desc->params(), @@ -84,10 +85,11 @@ void ExprDecomposer::Visit(const FunctionNode &node) { local_bitmap_idx); result_ = std::make_shared(validity_dex, value_dex); } + return Status::OK(); } // Decompose an IfNode -void ExprDecomposer::Visit(const IfNode &node) { +Status ExprDecomposer::Visit(const IfNode &node) { // Add a local bitmap to track the output validity. node.condition()->Accept(*this); auto condition_vv = result(); @@ -111,11 +113,13 @@ void ExprDecomposer::Visit(const IfNode &node) { is_terminal_else); result_ = std::make_shared(validity_dex, value_dex); + return Status::OK(); } -void ExprDecomposer::Visit(const LiteralNode &node) { +Status ExprDecomposer::Visit(const LiteralNode &node) { auto value_dex = std::make_shared(node.return_type(), node.holder()); result_ = std::make_shared(value_dex); + return Status::OK(); } // The bolow functions use a stack to detect : diff --git a/src/gandiva/src/cpp/src/codegen/expr_decomposer.h b/src/gandiva/src/cpp/src/codegen/expr_decomposer.h index 6c21c7605ea75..e6ffaf6d8b9c7 100644 --- a/src/gandiva/src/cpp/src/codegen/expr_decomposer.h +++ b/src/gandiva/src/cpp/src/codegen/expr_decomposer.h @@ -49,10 +49,10 @@ class ExprDecomposer : public NodeVisitor { FRIEND_TEST(TestExprDecomposer, TestInternalIf); FRIEND_TEST(TestExprDecomposer, TestParallelIf); - void Visit(const FieldNode &node) override; - void Visit(const FunctionNode &node) override; - void Visit(const IfNode &node) override; - void Visit(const LiteralNode &node) override; + Status Visit(const FieldNode &node) override; + Status Visit(const FunctionNode &node) override; + Status Visit(const IfNode &node) override; + Status Visit(const LiteralNode &node) override; // stack of if nodes. class IfStackEntry { diff --git a/src/gandiva/src/cpp/src/codegen/expr_validator.cc b/src/gandiva/src/cpp/src/codegen/expr_validator.cc new file mode 100644 index 0000000000000..2b280abff02bb --- /dev/null +++ b/src/gandiva/src/cpp/src/codegen/expr_validator.cc @@ -0,0 +1,132 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "codegen/expr_validator.h" + +namespace gandiva { + +Status ExprValidator::Validate(const ExpressionPtr &expr) { + if (expr == nullptr) { + return Status::ExpressionValidationError("Expression cannot be null."); + } + Node &root = *expr->root(); + Status status = root.Accept(*this); + if (!status.ok()) { + return status; + } + // validate return type matches + // no need to check if type is supported + // since root type has been validated. + if (!root.return_type()->Equals(*expr->result()->type())) { + std::stringstream ss; + ss << "Return type of root node " << root.return_type()->name() + << " does not match that of expression " << *expr->result()->type(); + return Status::ExpressionValidationError(ss.str()); + } + return Status::OK(); +} + +Status ExprValidator::Visit(const FieldNode &node) { + auto llvm_type = types_->IRType(node.return_type()->id()); + if (llvm_type == nullptr) { + std::stringstream ss; + ss << "Field "<< node.field()->name() << " has unsupported data type " + << node.return_type()->name(); + return Status::ExpressionValidationError(ss.str()); + } + + auto field_in_schema_entry = field_map_.find(node.field()->name()); + + // validate that field is in schema. + if (field_in_schema_entry == field_map_.end()) { + std::stringstream ss; + ss << "Field " << node.field()->name() << " not in schema."; + return Status::ExpressionValidationError(ss.str()); + } + + FieldPtr field_in_schema = field_in_schema_entry->second; + // validate that field matches the definition in schema. + if (!field_in_schema->Equals(node.field())) { + std::stringstream ss; + ss << "Field definition in schema " << field_in_schema->ToString() + << " different from field in expression " << node.field()->ToString(); + return Status::ExpressionValidationError(ss.str()); + } + return Status::OK(); +} + +Status ExprValidator::Visit(const FunctionNode &node) { + auto desc = node.descriptor(); + FunctionSignature signature(desc->name(), + desc->params(), + desc->return_type()); + const NativeFunction *native_function = registry_.LookupSignature(signature); + if (native_function == nullptr) { + std::stringstream ss; + ss << "Function "<< signature.ToString() << " not supported yet. "; + return Status::ExpressionValidationError(ss.str()); + } + + for (auto &child : node.children()) { + Status status = child->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + } + return Status::OK(); +} + +Status ExprValidator::Visit(const IfNode &node) { + Status status = node.condition()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + status = node.then_node()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + status = node.else_node()->Accept(*this); + GANDIVA_RETURN_NOT_OK(status); + + auto if_node_ret_type = node.return_type(); + auto then_node_ret_type = node.then_node()->return_type(); + auto else_node_ret_type = node.else_node()->return_type(); + + if (if_node_ret_type != then_node_ret_type) { + std::stringstream ss; + ss << "Return type of if "<< *if_node_ret_type << " and then " + << then_node_ret_type->name() << " not matching."; + return Status::ExpressionValidationError(ss.str()); + } + + if (if_node_ret_type != else_node_ret_type) { + std::stringstream ss; + ss << "Return type of if "<< *if_node_ret_type << " and else " + << else_node_ret_type->name() << " not matching."; + return Status::ExpressionValidationError(ss.str()); + } + + return Status::OK(); +} + +Status ExprValidator::Visit(const LiteralNode &node) { + auto llvm_type = types_->IRType(node.return_type()->id()); + if (llvm_type == nullptr) { + std::stringstream ss; + ss << "Value "<< node.holder() << " has unsupported data type " + << node.return_type()->name(); + return Status::ExpressionValidationError(ss.str()); + } + return Status::OK(); +} + +} // namespace gandiva diff --git a/src/gandiva/src/cpp/src/codegen/expr_validator.h b/src/gandiva/src/cpp/src/codegen/expr_validator.h new file mode 100644 index 0000000000000..ea43f93ddeb65 --- /dev/null +++ b/src/gandiva/src/cpp/src/codegen/expr_validator.h @@ -0,0 +1,74 @@ +// Copyright (C) 2017-2018 Dremio Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GANDIVA_EXPR_VALIDATOR_H +#define GANDIVA_EXPR_VALIDATOR_H + +#include +#include + +#include "boost/functional/hash.hpp" +#include "codegen/function_registry.h" +#include "codegen/node_visitor.h" +#include "codegen/node.h" +#include "codegen/llvm_types.h" +#include "gandiva/arrow.h" +#include "gandiva/expression.h" +#include "gandiva/status.h" + +namespace gandiva { + +class FunctionRegistry; + +/// \brief Validates the entire expression tree including +/// data types, signatures and return types +class ExprValidator : public NodeVisitor { + public: + explicit ExprValidator(LLVMTypes * types, SchemaPtr schema) + : types_(types), + schema_(schema) { + for (auto &field : schema_->fields()) { + field_map_[field->name()] = field; + } + } + + /// \brief Validates the root node + /// of an expression. + /// 1. Data type of fields and literals. + /// 2. Function signature is supported. + /// 3. For if nodes that return types match + /// for if, then and else nodes. + Status Validate(const ExpressionPtr &expr); + + private: + Status Visit(const FieldNode &node) override; + Status Visit(const FunctionNode &node) override; + Status Visit(const IfNode &node) override; + Status Visit(const LiteralNode &node) override; + + FunctionRegistry registry_; + + LLVMTypes *types_; + + SchemaPtr schema_; + + using FieldMap = std::unordered_map>; + FieldMap field_map_; +}; + +} // namespace gandiva + +#endif //GANDIVA_EXPR_VALIDATOR_H diff --git a/src/gandiva/src/cpp/src/codegen/llvm_generator.h b/src/gandiva/src/cpp/src/codegen/llvm_generator.h index 3e84308c4f348..f370092447581 100644 --- a/src/gandiva/src/cpp/src/codegen/llvm_generator.h +++ b/src/gandiva/src/cpp/src/codegen/llvm_generator.h @@ -49,6 +49,8 @@ class LLVMGenerator { Status Execute(const arrow::RecordBatch &record_batch, const ArrayDataVector &output_vector); + LLVMTypes *types() { return types_; } + private: LLVMGenerator(); @@ -59,7 +61,6 @@ class LLVMGenerator { llvm::Module *module() { return engine_->module(); } llvm::LLVMContext &context() { return *(engine_->context()); } llvm::IRBuilder<> &ir_builder() { return engine_->ir_builder(); } - LLVMTypes *types() { return types_; } /// Visitor to generate the code for a decomposed expression. class Visitor : public DexVisitor { diff --git a/src/gandiva/src/cpp/src/codegen/node.h b/src/gandiva/src/cpp/src/codegen/node.h index ddfbc490aac20..aded3f6f17aa0 100644 --- a/src/gandiva/src/cpp/src/codegen/node.h +++ b/src/gandiva/src/cpp/src/codegen/node.h @@ -23,6 +23,7 @@ #include "codegen/node_visitor.h" #include "gandiva/arrow.h" #include "gandiva/gandiva_aliases.h" +#include "gandiva/status.h" namespace gandiva { @@ -36,7 +37,7 @@ class Node { const DataTypePtr &return_type() const { return return_type_; } /// Derived classes should simply invoke the Visit api of the visitor. - virtual void Accept(NodeVisitor &visitor) const = 0; + virtual Status Accept(NodeVisitor &visitor) const = 0; protected: DataTypePtr return_type_; @@ -49,8 +50,8 @@ class LiteralNode : public Node { : Node(type), holder_(holder) {} - void Accept(NodeVisitor &visitor) const override { - visitor.Visit(*this); + Status Accept(NodeVisitor &visitor) const override { + return visitor.Visit(*this); } const LiteralHolder &holder() const { return holder_; } @@ -65,8 +66,8 @@ class FieldNode : public Node { explicit FieldNode(FieldPtr field) : Node(field->type()), field_(field) {} - void Accept(NodeVisitor &visitor) const override { - visitor.Visit(*this); + Status Accept(NodeVisitor &visitor) const override { + return visitor.Visit(*this); } const FieldPtr &field() const { return field_; } @@ -83,8 +84,8 @@ class FunctionNode : public Node { DataTypePtr retType) : Node(retType), descriptor_(descriptor), children_(children) { } - void Accept(NodeVisitor &visitor) const override { - visitor.Visit(*this); + Status Accept(NodeVisitor &visitor) const override { + return visitor.Visit(*this); } const FuncDescriptorPtr &descriptor() const { return descriptor_; } @@ -125,8 +126,8 @@ class IfNode : public Node { then_node_(then_node), else_node_(else_node) {} - void Accept(NodeVisitor &visitor) const override { - visitor.Visit(*this); + Status Accept(NodeVisitor &visitor) const override { + return visitor.Visit(*this); } const NodePtr &condition() const { return condition_; } diff --git a/src/gandiva/src/cpp/src/codegen/node_visitor.h b/src/gandiva/src/cpp/src/codegen/node_visitor.h index 6ba442a477e37..3252cc799f1b8 100644 --- a/src/gandiva/src/cpp/src/codegen/node_visitor.h +++ b/src/gandiva/src/cpp/src/codegen/node_visitor.h @@ -16,6 +16,7 @@ #define GANDIVA_NODE_VISITOR_H #include "gandiva/logging.h" +#include "gandiva/status.h" namespace gandiva { @@ -27,10 +28,10 @@ class LiteralNode; /// \brief Visitor for nodes in the expression tree. class NodeVisitor { public: - virtual void Visit(const FieldNode &node) = 0; - virtual void Visit(const FunctionNode &node) = 0; - virtual void Visit(const IfNode &node) = 0; - virtual void Visit(const LiteralNode &node) = 0; + virtual Status Visit(const FieldNode &node) = 0; + virtual Status Visit(const FunctionNode &node) = 0; + virtual Status Visit(const IfNode &node) = 0; + virtual Status Visit(const LiteralNode &node) = 0; }; } // namespace gandiva diff --git a/src/gandiva/src/cpp/src/codegen/projector.cc b/src/gandiva/src/cpp/src/codegen/projector.cc index 084c6f3c3a1d4..88dda81990ad0 100644 --- a/src/gandiva/src/cpp/src/codegen/projector.cc +++ b/src/gandiva/src/cpp/src/codegen/projector.cc @@ -18,7 +18,9 @@ #include #include +#include "codegen/expr_validator.h" #include "codegen/llvm_generator.h" +#include "gandiva/status.h" namespace gandiva { @@ -35,13 +37,24 @@ Status Projector::Make(SchemaPtr schema, const ExpressionVector &exprs, arrow::MemoryPool *pool, std::shared_ptr *projector) { - // TODO: validate schema - // TODO : validate expressions (fields, function signatures, output types, ..) - + GANDIVA_RETURN_FAILURE_IF_FALSE((schema != nullptr), + Status::Invalid("schema cannot be null")); + GANDIVA_RETURN_FAILURE_IF_FALSE(!exprs.empty(), + Status::Invalid("expressions need to be non-empty")); // Build LLVM generator, and generate code for the specified expressions std::unique_ptr llvm_gen; Status status = LLVMGenerator::Make(&llvm_gen); GANDIVA_RETURN_NOT_OK(status); + + // Run the validation on the expressions. + // Return if any of the expression is invalid since + // we will not be able to process further. + ExprValidator expr_validator(llvm_gen->types(), schema); + for (auto &expr : exprs) { + status = expr_validator.Validate(expr); + GANDIVA_RETURN_NOT_OK(status); + } + llvm_gen->Build(exprs); // save the output field types. Used for validation at Evaluate() time. @@ -96,6 +109,10 @@ Status Projector::Evaluate(const arrow::RecordBatch &batch, return Status::Invalid("output must be non-null."); } + if (pool_ == nullptr) { + return Status::Invalid("memory pool must be non-null."); + } + // Allocate the output data vecs. ArrayDataVector output_data_vecs; for (auto &field : output_fields_) { diff --git a/src/gandiva/src/cpp/src/codegen/status.cc b/src/gandiva/src/cpp/src/codegen/status.cc index 2a36b9e482a03..987eefbd50209 100644 --- a/src/gandiva/src/cpp/src/codegen/status.cc +++ b/src/gandiva/src/cpp/src/codegen/status.cc @@ -53,6 +53,9 @@ std::string Status::CodeAsString() const { case StatusCode::Invalid: type = "Invalid"; break; + case StatusCode::ExpressionValidationError: + type = "ExpressionValidationError"; + break; default: type = "Unknown"; break; diff --git a/src/gandiva/src/cpp/src/codegen/tree_expr_builder.cc b/src/gandiva/src/cpp/src/codegen/tree_expr_builder.cc index 5c4eba44399a8..fef670672d5d5 100644 --- a/src/gandiva/src/cpp/src/codegen/tree_expr_builder.cc +++ b/src/gandiva/src/cpp/src/codegen/tree_expr_builder.cc @@ -38,6 +38,9 @@ NodePtr TreeExprBuilder::MakeField(FieldPtr field) { NodePtr TreeExprBuilder::MakeFunction(const std::string &name, const NodeVector ¶ms, DataTypePtr result) { + if (result == nullptr) { + return nullptr; + } return FunctionNode::MakeFunction(name, params, result); } @@ -45,11 +48,18 @@ NodePtr TreeExprBuilder::MakeIf(NodePtr condition, NodePtr then_node, NodePtr else_node, DataTypePtr result_type) { + if (condition == nullptr || then_node == nullptr || + else_node == nullptr || result_type == nullptr) { + return nullptr; + } return std::make_shared(condition, then_node, else_node, result_type); } ExpressionPtr TreeExprBuilder::MakeExpression(NodePtr root_node, FieldPtr result_field) { + if (result_field == nullptr) { + return nullptr; + } return ExpressionPtr(new Expression(root_node, result_field)); } @@ -57,7 +67,9 @@ ExpressionPtr TreeExprBuilder::MakeExpression( const std::string &function, const FieldVector &in_fields, FieldPtr out_field) { - + if (out_field == nullptr) { + return nullptr; + } std::vector field_nodes; for (auto it = in_fields.begin(); it != in_fields.end(); ++it) { auto node = MakeField(*it);