Skip to content

Commit

Permalink
Fix rewriting of single term exprs
Browse files Browse the repository at this point in the history
  • Loading branch information
tsandall committed Mar 22, 2017
1 parent f3fc6c1 commit c953d62
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 11 deletions.
49 changes: 38 additions & 11 deletions rego/rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,16 +175,7 @@ func (r *Rego) Eval(ctx context.Context) (ResultSet, error) {
return nil, err
}

// If the query contains expressions that consist of a single term, rewrite
// those expressions so that we capture the value of the term in a variable
// that can be included in the result.
for i := range query {
if !query[i].Negated {
if term, ok := query[i].Terms.(*ast.Term); ok {
query[i].Terms = ast.Equality.Expr(term, r.generateTermVar()).Terms
}
}
}
query = r.captureTerms(query)

// Compile inputs
compiled, err := r.compile(parsed, query)
Expand Down Expand Up @@ -301,7 +292,10 @@ func (r *Rego) eval(ctx context.Context, compiled ast.Body, txn storage.Transact
}
}
for _, expr := range compiled {
if _, ok := exprs[expr]; !ok {
// Don't include expressions without locations. Lack of location
// indicates it was not parsed and so the caller should not be
// shown it.
if _, ok := exprs[expr]; !ok && expr.Location != nil {
result.Expressions = append(result.Expressions, newExpressionValue(expr, true))
}
}
Expand All @@ -320,6 +314,39 @@ func (r *Rego) eval(ctx context.Context, compiled ast.Body, txn storage.Transact
return rs, nil
}

func (r *Rego) captureTerms(query ast.Body) ast.Body {

// If the query contains expressions that consist of a single term, rewrite
// those expressions so that we capture the value of the term in a variable
// that can be included in the result.
extras := map[*ast.Expr]struct{}{}

for i := range query {
if !query[i].Negated {
if term, ok := query[i].Terms.(*ast.Term); ok {

// If len(query) > 1 we must still test that evaluated value is
// not false.
if len(query) > 1 {
cpy := query[i].Copy()
// Unset location so that this expression is not included
// in the results.
cpy.Location = nil
extras[cpy] = struct{}{}
}

query[i].Terms = ast.Equality.Expr(term, r.generateTermVar()).Terms
}
}
}

for expr := range extras {
query.Append(expr)
}

return query
}

func (r *Rego) generateTermVar() *ast.Term {
r.termVarID++
return ast.VarTerm(ast.WildcardPrefix + fmt.Sprintf("term%v", r.termVarID))
Expand Down
47 changes: 47 additions & 0 deletions rego/rego_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright 2017 The OPA Authors. All rights reserved.
// Use of this source code is governed by an Apache2
// license that can be found in the LICENSE file.

package rego

import (
"context"
"encoding/json"
"reflect"
"testing"
)

func TestRegoCaptureTermsRewrite(t *testing.T) {

ctx := context.Background()

r := New(
Query(`x; deadbeef = 1; y; z`),
Package(`test`),
Module("", `
package test
x = 1
y = 2
z = 3
`),
)

rs, err := r.Eval(ctx)

if len(rs) != 1 || len(rs[0].Expressions) != 4 || err != nil {
t.Fatalf("Unexpected result set: %v (err: %v)", rs, err)
}

expected := map[string]interface{}{
"x": json.Number("1"),
"y": json.Number("2"),
"z": json.Number("3"),
"deadbeef = 1": true,
}

for _, ev := range rs[0].Expressions {
if !reflect.DeepEqual(expected[ev.Text], ev.Value) {
t.Fatalf("Expected %v == %v but got: %v", ev.Text, expected[ev.Text], ev.Value)
}
}
}

0 comments on commit c953d62

Please sign in to comment.