Skip to content

Commit

Permalink
GDV-55: [C++] Added validation to projector build. (apache#33)
Browse files Browse the repository at this point in the history
Validating the input schema and expressions during the projector build.
  • Loading branch information
praveenbingo committed Sep 10, 2018
1 parent ae10571 commit 30eab61
Show file tree
Hide file tree
Showing 16 changed files with 561 additions and 31 deletions.
1 change: 1 addition & 0 deletions cpp/src/gandiva/codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ add_library(gandiva SHARED
projector.cc
status.cc
tree_expr_builder.cc
expr_validator.cc
${BC_FILE_PATH_CC})

# For users of gandiva library (including integ tests), include-dir is :
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/gandiva/codegen/expr_decomposer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,18 @@
namespace gandiva {

// Decompose a field node - simply separate out validity & value arrays.
void ExprDecomposer::Visit(const FieldNode &node) {
Status ExprDecomposer::Visit(const FieldNode &node) {
auto desc = annotator_.CheckAndAddInputFieldDescriptor(node.field());

DexPtr validity_dex = std::make_shared<VectorReadValidityDex>(desc);
DexPtr value_dex = std::make_shared<VectorReadValueDex>(desc);
result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
return Status::OK();
}

// Decompose a function node - wherever possible, merge the validity vectors of the
// child nodes.
void ExprDecomposer::Visit(const FunctionNode &node) {
Status ExprDecomposer::Visit(const FunctionNode &node) {
auto desc = node.descriptor();
FunctionSignature signature(desc->name(),
desc->params(),
Expand Down Expand Up @@ -84,10 +85,11 @@ void ExprDecomposer::Visit(const FunctionNode &node) {
local_bitmap_idx);
result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
}
return Status::OK();
}

// Decompose an IfNode
void ExprDecomposer::Visit(const IfNode &node) {
Status ExprDecomposer::Visit(const IfNode &node) {
// Add a local bitmap to track the output validity.
node.condition()->Accept(*this);
auto condition_vv = result();
Expand All @@ -111,11 +113,13 @@ void ExprDecomposer::Visit(const IfNode &node) {
is_terminal_else);

result_ = std::make_shared<ValueValidityPair>(validity_dex, value_dex);
return Status::OK();
}

void ExprDecomposer::Visit(const LiteralNode &node) {
Status ExprDecomposer::Visit(const LiteralNode &node) {
auto value_dex = std::make_shared<LiteralDex>(node.return_type(), node.holder());
result_ = std::make_shared<ValueValidityPair>(value_dex);
return Status::OK();
}

// The below functions use a stack to detect :
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/gandiva/codegen/expr_decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ class ExprDecomposer : public NodeVisitor {
FRIEND_TEST(TestExprDecomposer, TestInternalIf);
FRIEND_TEST(TestExprDecomposer, TestParallelIf);

void Visit(const FieldNode &node) override;
void Visit(const FunctionNode &node) override;
void Visit(const IfNode &node) override;
void Visit(const LiteralNode &node) override;
Status Visit(const FieldNode &node) override;
Status Visit(const FunctionNode &node) override;
Status Visit(const IfNode &node) override;
Status Visit(const LiteralNode &node) override;

// stack of if nodes.
class IfStackEntry {
Expand Down
132 changes: 132 additions & 0 deletions cpp/src/gandiva/codegen/expr_validator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright (C) 2017-2018 Dremio Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <sstream>
#include <vector>

#include "codegen/expr_validator.h"

namespace gandiva {

// Validates a complete expression.
//
// Walks the tree rooted at expr->root() with this visitor, then checks that
// the return type of the root node equals the type of expr->result().
//
// Returns Status::ExpressionValidationError when:
//   - the expression, its root node or its result is null,
//   - any node in the tree fails its own validation, or
//   - the root return type differs from the result type.
// Returns Status::OK() otherwise.
Status ExprValidator::Validate(const ExpressionPtr &expr) {
  if (expr == nullptr) {
    return Status::ExpressionValidationError("Expression cannot be null.");
  }
  // Both root() and result() are dereferenced below; report a malformed
  // expression as a validation error instead of dereferencing a null pointer.
  if (expr->root() == nullptr) {
    return Status::ExpressionValidationError(
        "Root node of expression cannot be null.");
  }
  if (expr->result() == nullptr) {
    return Status::ExpressionValidationError(
        "Result of expression cannot be null.");
  }

  Node &root = *expr->root();
  Status status = root.Accept(*this);
  if (!status.ok()) {
    return status;
  }
  // validate return type matches
  // no need to check if type is supported
  // since root type has been validated.
  if (!root.return_type()->Equals(*expr->result()->type())) {
    std::stringstream ss;
    ss << "Return type of root node " << root.return_type()->name()
       << " does not match that of expression " << *expr->result()->type();
    return Status::ExpressionValidationError(ss.str());
  }
  return Status::OK();
}

// Validates a field reference. Checks are made in this order, and the first
// failure is reported as an ExpressionValidationError:
//   1. types_->IRType() yields a non-null type for the field's data type,
//   2. a field with the same name exists in field_map_,
//   3. that schema field Equals() the field referenced by the node.
Status ExprValidator::Visit(const FieldNode &node) {
  const auto &expr_field = node.field();

  if (types_->IRType(node.return_type()->id()) == nullptr) {
    std::stringstream msg;
    msg << "Field " << expr_field->name() << " has unsupported data type "
        << node.return_type()->name();
    return Status::ExpressionValidationError(msg.str());
  }

  // validate that field is in schema.
  auto found = field_map_.find(expr_field->name());
  if (found == field_map_.end()) {
    std::stringstream msg;
    msg << "Field " << expr_field->name() << " not in schema.";
    return Status::ExpressionValidationError(msg.str());
  }

  // validate that field matches the definition in schema.
  const FieldPtr &schema_field = found->second;
  if (schema_field->Equals(expr_field)) {
    return Status::OK();
  }

  std::stringstream msg;
  msg << "Field definition in schema " << schema_field->ToString()
      << " different from field in expression " << expr_field->ToString();
  return Status::ExpressionValidationError(msg.str());
}

// Validates a function call: the signature built from the node's descriptor
// (name, parameter types, return type) must be found in registry_, and every
// child node must validate in turn. The signature is checked before the
// children, so an unsupported function is reported first.
Status ExprValidator::Visit(const FunctionNode &node) {
  auto descriptor = node.descriptor();
  FunctionSignature signature(descriptor->name(),
                              descriptor->params(),
                              descriptor->return_type());

  if (registry_.LookupSignature(signature) == nullptr) {
    std::stringstream msg;
    msg << "Function " << signature.ToString() << " not supported yet. ";
    return Status::ExpressionValidationError(msg.str());
  }

  // Stop at the first child that fails validation.
  for (const auto &child : node.children()) {
    Status child_status = child->Accept(*this);
    GANDIVA_RETURN_NOT_OK(child_status);
  }
  return Status::OK();
}

// Validates an if-node.
//
// The condition, then and else sub-trees are validated first (in that order);
// the first failure is returned as-is. Afterwards the return type of the
// if-node must equal the return types of both its then-node and its else-node.
//
// Types are compared with DataType::Equals() rather than by comparing the
// pointers themselves: two pointers may refer to distinct objects that
// describe the same type, and such expressions must not be rejected.
//
// Returns Status::ExpressionValidationError on a type mismatch,
// Status::OK() otherwise.
Status ExprValidator::Visit(const IfNode &node) {
  Status status = node.condition()->Accept(*this);
  GANDIVA_RETURN_NOT_OK(status);
  status = node.then_node()->Accept(*this);
  GANDIVA_RETURN_NOT_OK(status);
  status = node.else_node()->Accept(*this);
  GANDIVA_RETURN_NOT_OK(status);

  auto if_node_ret_type = node.return_type();
  auto then_node_ret_type = node.then_node()->return_type();
  auto else_node_ret_type = node.else_node()->return_type();

  if (!if_node_ret_type->Equals(*then_node_ret_type)) {
    std::stringstream ss;
    ss << "Return type of if " << *if_node_ret_type << " and then "
       << then_node_ret_type->name() << " not matching.";
    return Status::ExpressionValidationError(ss.str());
  }

  if (!if_node_ret_type->Equals(*else_node_ret_type)) {
    std::stringstream ss;
    ss << "Return type of if " << *if_node_ret_type << " and else "
       << else_node_ret_type->name() << " not matching.";
    return Status::ExpressionValidationError(ss.str());
  }

  return Status::OK();
}

// Validates a literal: types_->IRType() must yield a non-null type for the
// literal's data type, otherwise an ExpressionValidationError is returned.
Status ExprValidator::Visit(const LiteralNode &node) {
  if (types_->IRType(node.return_type()->id()) != nullptr) {
    return Status::OK();
  }

  std::stringstream msg;
  msg << "Value " << node.holder() << " has unsupported data type "
      << node.return_type()->name();
  return Status::ExpressionValidationError(msg.str());
}

} // namespace gandiva
74 changes: 74 additions & 0 deletions cpp/src/gandiva/codegen/expr_validator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright (C) 2017-2018 Dremio Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GANDIVA_EXPR_VALIDATOR_H
#define GANDIVA_EXPR_VALIDATOR_H

#include <string>
#include <unordered_map>

#include "boost/functional/hash.hpp"
#include "codegen/function_registry.h"
#include "codegen/node_visitor.h"
#include "codegen/node.h"
#include "codegen/llvm_types.h"
#include "gandiva/arrow.h"
#include "gandiva/expression.h"
#include "gandiva/status.h"

namespace gandiva {

class FunctionRegistry;

/// \brief Node visitor that validates a whole expression tree:
/// data types, function signatures and return types.
class ExprValidator : public NodeVisitor {
 public:
  /// \brief Stores \p types and \p schema, and indexes every field of the
  /// schema by name so that field nodes can be looked up during validation.
  /// If two schema fields share a name, the later one replaces the earlier.
  explicit ExprValidator(LLVMTypes *types, SchemaPtr schema)
      : types_(types), schema_(schema) {
    for (const auto &schema_field : schema_->fields()) {
      field_map_[schema_field->name()] = schema_field;
    }
  }

  /// \brief Validates an expression, starting at its root node.
  /// The following is checked:
  /// 1. Data types of fields and literals.
  /// 2. Function signatures are supported.
  /// 3. For if nodes, the return types of the if, then
  ///    and else nodes match.
  Status Validate(const ExpressionPtr &expr);

 private:
  // One validation routine per node kind, dispatched via Node::Accept().
  Status Visit(const FieldNode &node) override;
  Status Visit(const FunctionNode &node) override;
  Status Visit(const IfNode &node) override;
  Status Visit(const LiteralNode &node) override;

  // Used to look up function signatures.
  FunctionRegistry registry_;

  // Used to map a data type id to an IR type; not owned by this class as far
  // as this header shows (raw pointer, never deleted here).
  LLVMTypes *types_;

  // Schema the expression's fields are validated against.
  SchemaPtr schema_;

  // Field name -> schema field, filled in by the constructor.
  using FieldMap = std::unordered_map<std::string,
                                      FieldPtr,
                                      boost::hash<std::string>>;
  FieldMap field_map_;
};

} // namespace gandiva

#endif //GANDIVA_EXPR_VALIDATOR_H
3 changes: 2 additions & 1 deletion cpp/src/gandiva/codegen/llvm_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class LLVMGenerator {
Status Execute(const arrow::RecordBatch &record_batch,
const ArrayDataVector &output_vector);

LLVMTypes *types() { return types_; }

private:
LLVMGenerator();

Expand All @@ -59,7 +61,6 @@ class LLVMGenerator {
llvm::Module *module() { return engine_->module(); }
llvm::LLVMContext &context() { return *(engine_->context()); }
llvm::IRBuilder<> &ir_builder() { return engine_->ir_builder(); }
LLVMTypes *types() { return types_; }

/// Visitor to generate the code for a decomposed expression.
class Visitor : public DexVisitor {
Expand Down
19 changes: 10 additions & 9 deletions cpp/src/gandiva/codegen/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "codegen/node_visitor.h"
#include "gandiva/arrow.h"
#include "gandiva/gandiva_aliases.h"
#include "gandiva/status.h"

namespace gandiva {

Expand All @@ -36,7 +37,7 @@ class Node {
const DataTypePtr &return_type() const { return return_type_; }

/// Derived classes should simply invoke the Visit api of the visitor.
virtual void Accept(NodeVisitor &visitor) const = 0;
virtual Status Accept(NodeVisitor &visitor) const = 0;

protected:
DataTypePtr return_type_;
Expand All @@ -49,8 +50,8 @@ class LiteralNode : public Node {
: Node(type),
holder_(holder) {}

void Accept(NodeVisitor &visitor) const override {
visitor.Visit(*this);
Status Accept(NodeVisitor &visitor) const override {
return visitor.Visit(*this);
}

const LiteralHolder &holder() const { return holder_; }
Expand All @@ -65,8 +66,8 @@ class FieldNode : public Node {
explicit FieldNode(FieldPtr field)
: Node(field->type()), field_(field) {}

void Accept(NodeVisitor &visitor) const override {
visitor.Visit(*this);
Status Accept(NodeVisitor &visitor) const override {
return visitor.Visit(*this);
}

const FieldPtr &field() const { return field_; }
Expand All @@ -83,8 +84,8 @@ class FunctionNode : public Node {
DataTypePtr retType)
: Node(retType), descriptor_(descriptor), children_(children) { }

void Accept(NodeVisitor &visitor) const override {
visitor.Visit(*this);
Status Accept(NodeVisitor &visitor) const override {
return visitor.Visit(*this);
}

const FuncDescriptorPtr &descriptor() const { return descriptor_; }
Expand Down Expand Up @@ -125,8 +126,8 @@ class IfNode : public Node {
then_node_(then_node),
else_node_(else_node) {}

void Accept(NodeVisitor &visitor) const override {
visitor.Visit(*this);
Status Accept(NodeVisitor &visitor) const override {
return visitor.Visit(*this);
}

const NodePtr &condition() const { return condition_; }
Expand Down
9 changes: 5 additions & 4 deletions cpp/src/gandiva/codegen/node_visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#define GANDIVA_NODE_VISITOR_H

#include "gandiva/logging.h"
#include "gandiva/status.h"

namespace gandiva {

Expand All @@ -27,10 +28,10 @@ class LiteralNode;
/// \brief Visitor for nodes in the expression tree.
class NodeVisitor {
public:
virtual void Visit(const FieldNode &node) = 0;
virtual void Visit(const FunctionNode &node) = 0;
virtual void Visit(const IfNode &node) = 0;
virtual void Visit(const LiteralNode &node) = 0;
virtual Status Visit(const FieldNode &node) = 0;
virtual Status Visit(const FunctionNode &node) = 0;
virtual Status Visit(const IfNode &node) = 0;
virtual Status Visit(const LiteralNode &node) = 0;
};

} // namespace gandiva
Expand Down
Loading

0 comments on commit 30eab61

Please sign in to comment.